competitive/data_structure/
binary_indexed_tree_2d.rs1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree2D<M>
5where
6 M: Monoid,
7{
8 h: usize,
9 w: usize,
10 bit: Vec<Vec<M::T>>,
11}
12
13impl<M> Clone for BinaryIndexedTree2D<M>
14where
15 M: Monoid,
16{
17 fn clone(&self) -> Self {
18 Self {
19 h: self.h,
20 w: self.w,
21 bit: self.bit.clone(),
22 }
23 }
24}
25
26impl<M> Debug for BinaryIndexedTree2D<M>
27where
28 M: Monoid<T: Debug>,
29{
30 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
31 f.debug_struct("BinaryIndexedTree2D")
32 .field("h", &self.h)
33 .field("w", &self.w)
34 .field("bit", &self.bit)
35 .finish()
36 }
37}
38
39impl<M> BinaryIndexedTree2D<M>
40where
41 M: Monoid,
42{
43 #[inline]
44 pub fn new(h: usize, w: usize) -> Self {
45 let bit = vec![vec![M::unit(); w + 1]; h + 1];
46 Self { h, w, bit }
47 }
48 #[inline]
49 pub fn accumulate0(&self, i: usize, j: usize) -> M::T {
51 let mut res = M::unit();
52 let mut a = i;
53 while a > 0 {
54 let mut b = j;
55 while b > 0 {
56 res = M::operate(&res, &self.bit[a][b]);
57 b -= b & (!b + 1);
58 }
59 a -= a & (!a + 1);
60 }
61 res
62 }
63 #[inline]
64 pub fn accumulate(&self, i: usize, j: usize) -> M::T {
66 self.accumulate0(i + 1, j + 1)
67 }
68 #[inline]
69 pub fn update(&mut self, i: usize, j: usize, x: M::T) {
70 let mut a = i + 1;
71 while a <= self.h {
72 let mut b = j + 1;
73 while b <= self.w {
74 self.bit[a][b] = M::operate(&self.bit[a][b], &x);
75 b += b & (!b + 1);
76 }
77 a += a & (!a + 1);
78 }
79 }
80}
81
82impl<G> BinaryIndexedTree2D<G>
83where
84 G: Group,
85{
86 #[inline]
87 pub fn fold(&self, i1: usize, j1: usize, i2: usize, j2: usize) -> G::T {
89 let mut res = G::unit();
90 res = G::operate(&res, &self.accumulate0(i1, j1));
91 res = G::rinv_operate(&res, &self.accumulate0(i1, j2));
92 res = G::rinv_operate(&res, &self.accumulate0(i2, j1));
93 res = G::operate(&res, &self.accumulate0(i2, j2));
94 res
95 }
96 #[inline]
97 pub fn get(&self, i: usize, j: usize) -> G::T {
98 self.fold(i, j, i + 1, j + 1)
99 }
100 #[inline]
101 pub fn set(&mut self, i: usize, j: usize, x: G::T) {
102 let y = G::inverse(&self.get(i, j));
103 let z = G::operate(&y, &x);
104 self.update(i, j, z);
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        tools::Xorshift,
    };

    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    /// Turns `grid` in place into its inclusive 2D prefix table under `op`:
    /// first along every row, then down every column.
    fn prefix_in_place<T: Copy>(grid: &mut [Vec<T>], op: impl Fn(T, T) -> T) {
        for row in grid.iter_mut() {
            for j in 1..row.len() {
                row[j] = op(row[j], row[j - 1]);
            }
        }
        for i in 1..grid.len() {
            let (above, rest) = grid.split_at_mut(i);
            for (cur, prev) in rest[0].iter_mut().zip(&above[i - 1]) {
                *cur = op(*cur, *prev);
            }
        }
    }

    #[test]
    fn test_binary_indexed_tree_2d() {
        let mut rng = Xorshift::default();
        const H: usize = 150;
        const W: usize = 250;

        // Sum monoid: random point additions checked against a brute-force prefix table.
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut expected = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            expected[i][j] += v;
        }
        prefix_in_place(&mut expected, |x, y| x + y);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }

        // Max monoid: same shape of check with `max` as the combining operation.
        let mut bit = BinaryIndexedTree2D::<MaxOperation<_>>::new(H, W);
        let mut expected = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            expected[i][j] = std::cmp::max(expected[i][j], v);
        }
        prefix_in_place(&mut expected, std::cmp::max);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }
    }

    #[test]
    fn test_group_binary_indexed_tree2d() {
        let mut rng = Xorshift::default();
        const H: usize = 15;
        const W: usize = 25;
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        // Zero-bordered, 1-indexed table: after prefixing, `pre[i][j]` folds [0, i) x [0, j).
        let mut pre = vec![vec![0; W + 1]; H + 1];
        for (i, j, v) in rng.random_iter((..H, ..W, -B..B)).take(Q) {
            bit.set(i, j, v);
            pre[i + 1][j + 1] = v;
        }
        prefix_in_place(&mut pre, |x, y| x + y);
        for i1 in 0..H {
            for i2 in i1 + 1..=H {
                for j1 in 0..W {
                    for j2 in j1 + 1..=W {
                        let want = pre[i2][j2] - pre[i2][j1] - pre[i1][j2] + pre[i1][j1];
                        assert_eq!(bit.fold(i1, j1, i2, j2), want);
                    }
                }
            }
        }
    }
}