// competitive/data_structure/binary_indexed_tree_2d.rs
use super::{Group, Monoid};
use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree2D<M>
5where
6 M: Monoid,
7{
8 h: usize,
9 w: usize,
10 bit: Vec<Vec<M::T>>,
11}
12
13impl<M> Clone for BinaryIndexedTree2D<M>
14where
15 M: Monoid,
16{
17 fn clone(&self) -> Self {
18 Self {
19 h: self.h,
20 w: self.w,
21 bit: self.bit.clone(),
22 }
23 }
24}
25
26impl<M> Debug for BinaryIndexedTree2D<M>
27where
28 M: Monoid,
29 M::T: Debug,
30{
31 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
32 f.debug_struct("BinaryIndexedTree2D")
33 .field("h", &self.h)
34 .field("w", &self.w)
35 .field("bit", &self.bit)
36 .finish()
37 }
38}
39
40impl<M> BinaryIndexedTree2D<M>
41where
42 M: Monoid,
43{
44 #[inline]
45 pub fn new(h: usize, w: usize) -> Self {
46 let bit = vec![vec![M::unit(); w + 1]; h + 1];
47 Self { h, w, bit }
48 }
49 #[inline]
50 pub fn accumulate0(&self, i: usize, j: usize) -> M::T {
52 let mut res = M::unit();
53 let mut a = i;
54 while a > 0 {
55 let mut b = j;
56 while b > 0 {
57 res = M::operate(&res, &self.bit[a][b]);
58 b -= b & (!b + 1);
59 }
60 a -= a & (!a + 1);
61 }
62 res
63 }
64 #[inline]
65 pub fn accumulate(&self, i: usize, j: usize) -> M::T {
67 self.accumulate0(i + 1, j + 1)
68 }
69 #[inline]
70 pub fn update(&mut self, i: usize, j: usize, x: M::T) {
71 let mut a = i + 1;
72 while a <= self.h {
73 let mut b = j + 1;
74 while b <= self.w {
75 self.bit[a][b] = M::operate(&self.bit[a][b], &x);
76 b += b & (!b + 1);
77 }
78 a += a & (!a + 1);
79 }
80 }
81}
82
83impl<G> BinaryIndexedTree2D<G>
84where
85 G: Group,
86{
87 #[inline]
88 pub fn fold(&self, i1: usize, j1: usize, i2: usize, j2: usize) -> G::T {
90 let mut res = G::unit();
91 res = G::operate(&res, &self.accumulate0(i1, j1));
92 res = G::rinv_operate(&res, &self.accumulate0(i1, j2));
93 res = G::rinv_operate(&res, &self.accumulate0(i2, j1));
94 res = G::operate(&res, &self.accumulate0(i2, j2));
95 res
96 }
97 #[inline]
98 pub fn get(&self, i: usize, j: usize) -> G::T {
99 self.fold(i, j, i + 1, j + 1)
100 }
101 #[inline]
102 pub fn set(&mut self, i: usize, j: usize, x: G::T) {
103 let y = G::inverse(&self.get(i, j));
104 let z = G::operate(&y, &x);
105 self.update(i, j, z);
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        tools::Xorshift,
    };

    /// Number of random operations applied per scenario.
    const Q: usize = 100_000;
    /// Exclusive upper bound for the unsigned random values.
    const A: u64 = 1_000_000_000;
    /// Bound for the signed random values, drawn from `-B..B`.
    const B: i64 = 1_000_000_000;

    /// Replaces `grid` in place with its inclusive 2D prefix accumulation
    /// under `op`: first along each row, then down each column.
    ///
    /// Every cell is combined as `op(current, predecessor)`, which is the
    /// operand order the reference computations in these tests rely on.
    fn prefix_accumulate<T: Copy>(grid: &mut [Vec<T>], op: impl Fn(T, T) -> T) {
        for row in grid.iter_mut() {
            for j in 1..row.len() {
                row[j] = op(row[j], row[j - 1]);
            }
        }
        for i in 1..grid.len() {
            // Split so the finished row `i - 1` can be read while row `i` is
            // mutated.
            let (done, rest) = grid.split_at_mut(i);
            for (cur, prev) in rest[0].iter_mut().zip(done[i - 1].iter()) {
                *cur = op(*cur, *prev);
            }
        }
    }

    /// Checks `update` + `accumulate` against a brute-force grid, for a sum
    /// monoid and then for a max monoid.
    #[test]
    fn test_binary_indexed_tree_2d() {
        let mut rng = Xorshift::new();
        const H: usize = 150;
        const W: usize = 250;

        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut arr = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            arr[i][j] += v;
        }
        prefix_accumulate(&mut arr, |x, y| x + y);
        for (i, row) in arr.iter().enumerate() {
            for (j, &expected) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), expected);
            }
        }

        let mut bit = BinaryIndexedTree2D::<MaxOperation<_>>::new(H, W);
        let mut arr = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            arr[i][j] = std::cmp::max(arr[i][j], v);
        }
        prefix_accumulate(&mut arr, std::cmp::max);
        for (i, row) in arr.iter().enumerate() {
            for (j, &expected) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), expected);
            }
        }
    }

    /// Checks `set` + `fold` against brute-force 2D prefix sums over every
    /// non-empty rectangle.
    #[test]
    fn test_group_binary_indexed_tree2d() {
        let mut rng = Xorshift::new();
        const H: usize = 15;
        const W: usize = 25;

        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        // One row/column of zero padding so `arr[i][j]` is the sum over rows
        // `0..i` and columns `0..j` after accumulation.
        let mut arr = vec![vec![0; W + 1]; H + 1];
        for (i, j, v) in rng.random_iter((..H, ..W, -B..B)).take(Q) {
            bit.set(i, j, v);
            arr[i + 1][j + 1] = v;
        }
        prefix_accumulate(&mut arr, |x, y| x + y);
        for i1 in 0..H {
            for i2 in i1 + 1..H + 1 {
                for j1 in 0..W {
                    for j2 in j1 + 1..W + 1 {
                        assert_eq!(
                            bit.fold(i1, j1, i2, j2),
                            arr[i2][j2] - arr[i2][j1] - arr[i1][j2] + arr[i1][j1]
                        );
                    }
                }
            }
        }
    }
}