// competitive/data_structure/binary_indexed_tree_2d.rs

1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree2D<M>
5where
6    M: Monoid,
7{
8    h: usize,
9    w: usize,
10    bit: Vec<Vec<M::T>>,
11}
12
13impl<M> Clone for BinaryIndexedTree2D<M>
14where
15    M: Monoid,
16{
17    fn clone(&self) -> Self {
18        Self {
19            h: self.h,
20            w: self.w,
21            bit: self.bit.clone(),
22        }
23    }
24}
25
26impl<M> Debug for BinaryIndexedTree2D<M>
27where
28    M: Monoid,
29    M::T: Debug,
30{
31    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
32        f.debug_struct("BinaryIndexedTree2D")
33            .field("h", &self.h)
34            .field("w", &self.w)
35            .field("bit", &self.bit)
36            .finish()
37    }
38}
39
40impl<M> BinaryIndexedTree2D<M>
41where
42    M: Monoid,
43{
44    #[inline]
45    pub fn new(h: usize, w: usize) -> Self {
46        let bit = vec![vec![M::unit(); w + 1]; h + 1];
47        Self { h, w, bit }
48    }
49    #[inline]
50    /// fold [0, i) x [0, j)
51    pub fn accumulate0(&self, i: usize, j: usize) -> M::T {
52        let mut res = M::unit();
53        let mut a = i;
54        while a > 0 {
55            let mut b = j;
56            while b > 0 {
57                res = M::operate(&res, &self.bit[a][b]);
58                b -= b & (!b + 1);
59            }
60            a -= a & (!a + 1);
61        }
62        res
63    }
64    #[inline]
65    /// fold [0, i] x [0, j]
66    pub fn accumulate(&self, i: usize, j: usize) -> M::T {
67        self.accumulate0(i + 1, j + 1)
68    }
69    #[inline]
70    pub fn update(&mut self, i: usize, j: usize, x: M::T) {
71        let mut a = i + 1;
72        while a <= self.h {
73            let mut b = j + 1;
74            while b <= self.w {
75                self.bit[a][b] = M::operate(&self.bit[a][b], &x);
76                b += b & (!b + 1);
77            }
78            a += a & (!a + 1);
79        }
80    }
81}
82
83impl<G> BinaryIndexedTree2D<G>
84where
85    G: Group,
86{
87    #[inline]
88    /// 0-indexed [i1, i2) x [j1, j2)
89    pub fn fold(&self, i1: usize, j1: usize, i2: usize, j2: usize) -> G::T {
90        let mut res = G::unit();
91        res = G::operate(&res, &self.accumulate0(i1, j1));
92        res = G::rinv_operate(&res, &self.accumulate0(i1, j2));
93        res = G::rinv_operate(&res, &self.accumulate0(i2, j1));
94        res = G::operate(&res, &self.accumulate0(i2, j2));
95        res
96    }
97    #[inline]
98    pub fn get(&self, i: usize, j: usize) -> G::T {
99        self.fold(i, j, i + 1, j + 1)
100    }
101    #[inline]
102    pub fn set(&mut self, i: usize, j: usize, x: G::T) {
103        let y = G::inverse(&self.get(i, j));
104        let z = G::operate(&y, &x);
105        self.update(i, j, z);
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        tools::Xorshift,
    };

    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    /// Rewrites `grid` in place into its inclusive 2D prefix table under `op`.
    ///
    /// First every row is scanned left to right. Then each row, top to bottom, absorbs the
    /// already-finished row above it.
    fn prefix_2d<T: Copy>(grid: &mut [Vec<T>], op: impl Fn(T, T) -> T) {
        for row in grid.iter_mut() {
            for j in 1..row.len() {
                row[j] = op(row[j], row[j - 1]);
            }
        }
        for i in 1..grid.len() {
            let (above, below) = grid.split_at_mut(i);
            for (cur, prev) in below[0].iter_mut().zip(&above[i - 1]) {
                *cur = op(*cur, *prev);
            }
        }
    }

    #[test]
    fn test_binary_indexed_tree_2d() {
        let mut rng = Xorshift::new();
        const H: usize = 150;
        const W: usize = 250;

        // Sum monoid: every inclusive prefix must match a naive prefix-sum table.
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut expected = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            expected[i][j] += v;
        }
        prefix_2d(&mut expected, |x, y| x + y);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }

        // Max monoid: the same check with `max` as the combining operation.
        let mut bit = BinaryIndexedTree2D::<MaxOperation<_>>::new(H, W);
        let mut expected = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            expected[i][j] = std::cmp::max(expected[i][j], v);
        }
        prefix_2d(&mut expected, std::cmp::max);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }
    }

    #[test]
    fn test_group_binary_indexed_tree2d() {
        let mut rng = Xorshift::new();
        const H: usize = 15;
        const W: usize = 25;
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        // Padded with a zero row and column. After the prefix pass, expected[i][j] holds the
        // sum over [0, i) x [0, j).
        let mut expected = vec![vec![0; W + 1]; H + 1];
        for (i, j, v) in rng.random_iter((..H, ..W, -B..B)).take(Q) {
            bit.set(i, j, v);
            expected[i + 1][j + 1] = v;
        }
        prefix_2d(&mut expected, |x, y| x + y);
        for i1 in 0..H {
            for i2 in i1 + 1..=H {
                for j1 in 0..W {
                    for j2 in j1 + 1..=W {
                        let want = expected[i2][j2] - expected[i2][j1] - expected[i1][j2]
                            + expected[i1][j1];
                        assert_eq!(bit.fold(i1, j1, i2, j2), want);
                    }
                }
            }
        }
    }
}