competitive/data_structure/
binary_indexed_tree_2d.rs1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree2D<M>
5where
6 M: Monoid,
7{
8 h: usize,
9 w: usize,
10 bit: Vec<Vec<M::T>>,
11}
12
13impl<M> Clone for BinaryIndexedTree2D<M>
14where
15 M: Monoid,
16{
17 fn clone(&self) -> Self {
18 Self {
19 h: self.h,
20 w: self.w,
21 bit: self.bit.clone(),
22 }
23 }
24}
25
26impl<M> Debug for BinaryIndexedTree2D<M>
27where
28 M: Monoid,
29 M::T: Debug,
30{
31 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
32 f.debug_struct("BinaryIndexedTree2D")
33 .field("h", &self.h)
34 .field("w", &self.w)
35 .field("bit", &self.bit)
36 .finish()
37 }
38}
39
40impl<M> BinaryIndexedTree2D<M>
41where
42 M: Monoid,
43{
44 #[inline]
45 pub fn new(h: usize, w: usize) -> Self {
46 let bit = vec![vec![M::unit(); w + 1]; h + 1];
47 Self { h, w, bit }
48 }
49 #[inline]
50 pub fn accumulate0(&self, i: usize, j: usize) -> M::T {
52 let mut res = M::unit();
53 let mut a = i;
54 while a > 0 {
55 let mut b = j;
56 while b > 0 {
57 res = M::operate(&res, &self.bit[a][b]);
58 b -= b & (!b + 1);
59 }
60 a -= a & (!a + 1);
61 }
62 res
63 }
64 #[inline]
65 pub fn accumulate(&self, i: usize, j: usize) -> M::T {
67 self.accumulate0(i + 1, j + 1)
68 }
69 #[inline]
70 pub fn update(&mut self, i: usize, j: usize, x: M::T) {
71 let mut a = i + 1;
72 while a <= self.h {
73 let mut b = j + 1;
74 while b <= self.w {
75 self.bit[a][b] = M::operate(&self.bit[a][b], &x);
76 b += b & (!b + 1);
77 }
78 a += a & (!a + 1);
79 }
80 }
81}
82
83impl<G> BinaryIndexedTree2D<G>
84where
85 G: Group,
86{
87 #[inline]
88 pub fn fold(&self, i1: usize, j1: usize, i2: usize, j2: usize) -> G::T {
90 let mut res = G::unit();
91 res = G::operate(&res, &self.accumulate0(i1, j1));
92 res = G::rinv_operate(&res, &self.accumulate0(i1, j2));
93 res = G::rinv_operate(&res, &self.accumulate0(i2, j1));
94 res = G::operate(&res, &self.accumulate0(i2, j2));
95 res
96 }
97 #[inline]
98 pub fn get(&self, i: usize, j: usize) -> G::T {
99 self.fold(i, j, i + 1, j + 1)
100 }
101 #[inline]
102 pub fn set(&mut self, i: usize, j: usize, x: G::T) {
103 let y = G::inverse(&self.get(i, j));
104 let z = G::operate(&y, &x);
105 self.update(i, j, z);
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        tools::Xorshift,
    };

    /// Number of random operations per scenario.
    const Q: usize = 100_000;
    /// Values inserted into the monoid trees are drawn from `..A`.
    const A: u64 = 1_000_000_000;
    /// Values assigned into the group tree are drawn from `-B..B`.
    const B: i64 = 1_000_000_000;

    /// Checks `update` + `accumulate` against brute-force 2D prefix sums
    /// (additive monoid) and 2D prefix maxima (max monoid).
    #[test]
    fn test_binary_indexed_tree_2d() {
        let mut rng = Xorshift::new();
        const H: usize = 150;
        const W: usize = 250;
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut arr = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            arr[i][j] += v;
        }
        for arr in arr.iter_mut() {
            for j in 0..W - 1 {
                arr[j + 1] += arr[j];
            }
        }
        for i in 0..H - 1 {
            for j in 0..W {
                arr[i + 1][j] += arr[i][j];
            }
        }
        for (i, arr) in arr.iter().enumerate() {
            for (j, a) in arr.iter().cloned().enumerate() {
                assert_eq!(bit.accumulate(i, j), a);
            }
        }

        let mut bit = BinaryIndexedTree2D::<MaxOperation<_>>::new(H, W);
        let mut arr = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            arr[i][j] = std::cmp::max(arr[i][j], v);
        }
        for arr in arr.iter_mut() {
            for j in 0..W - 1 {
                arr[j + 1] = std::cmp::max(arr[j + 1], arr[j]);
            }
        }
        for i in 0..H - 1 {
            for j in 0..W {
                arr[i + 1][j] = std::cmp::max(arr[i + 1][j], arr[i][j]);
            }
        }
        for (i, arr) in arr.iter().enumerate() {
            for (j, a) in arr.iter().cloned().enumerate() {
                assert_eq!(bit.accumulate(i, j), a);
            }
        }
    }

    /// Checks `set`, `get` and `fold` for the additive group against a
    /// brute-force grid.
    #[test]
    fn test_group_binary_indexed_tree2d() {
        let mut rng = Xorshift::new();
        const H: usize = 15;
        const W: usize = 25;
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        // Padded with a leading zero row/column so the prefix table below is 1-indexed.
        let mut arr = vec![vec![0; W + 1]; H + 1];
        for (i, j, v) in rng.random_iter((..H, ..W, -B..B)).take(Q) {
            bit.set(i, j, v);
            arr[i + 1][j + 1] = v;
        }
        // `get` must return exactly the value most recently written by `set`.
        for i in 0..H {
            for j in 0..W {
                assert_eq!(bit.get(i, j), arr[i + 1][j + 1]);
            }
        }
        for arr in arr.iter_mut() {
            for j in 0..W {
                arr[j + 1] += arr[j];
            }
        }
        for i in 0..H {
            for j in 0..W + 1 {
                arr[i + 1][j] += arr[i][j];
            }
        }
        for i1 in 0..H {
            for i2 in i1 + 1..H + 1 {
                for j1 in 0..W {
                    for j2 in j1 + 1..W + 1 {
                        assert_eq!(
                            bit.fold(i1, j1, i2, j2),
                            arr[i2][j2] - arr[i2][j1] - arr[i1][j2] + arr[i1][j1]
                        );
                    }
                }
            }
        }
    }
}