competitive/data_structure/
binary_indexed_tree_2d.rs

1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree2D<M>
5where
6    M: Monoid,
7{
8    h: usize,
9    w: usize,
10    bit: Vec<Vec<M::T>>,
11}
12
13impl<M> Clone for BinaryIndexedTree2D<M>
14where
15    M: Monoid,
16{
17    fn clone(&self) -> Self {
18        Self {
19            h: self.h,
20            w: self.w,
21            bit: self.bit.clone(),
22        }
23    }
24}
25
26impl<M> Debug for BinaryIndexedTree2D<M>
27where
28    M: Monoid,
29    M::T: Debug,
30{
31    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
32        f.debug_struct("BinaryIndexedTree2D")
33            .field("h", &self.h)
34            .field("w", &self.w)
35            .field("bit", &self.bit)
36            .finish()
37    }
38}
39
40impl<M> BinaryIndexedTree2D<M>
41where
42    M: Monoid,
43{
44    #[inline]
45    pub fn new(h: usize, w: usize) -> Self {
46        let bit = vec![vec![M::unit(); w + 1]; h + 1];
47        Self { h, w, bit }
48    }
49    #[inline]
50    /// fold [0, i) x [0, j)
51    pub fn accumulate0(&self, i: usize, j: usize) -> M::T {
52        let mut res = M::unit();
53        let mut a = i;
54        while a > 0 {
55            let mut b = j;
56            while b > 0 {
57                res = M::operate(&res, &self.bit[a][b]);
58                b -= b & (!b + 1);
59            }
60            a -= a & (!a + 1);
61        }
62        res
63    }
64    #[inline]
65    /// fold [0, i] x [0, j]
66    pub fn accumulate(&self, i: usize, j: usize) -> M::T {
67        self.accumulate0(i + 1, j + 1)
68    }
69    #[inline]
70    pub fn update(&mut self, i: usize, j: usize, x: M::T) {
71        let mut a = i + 1;
72        while a <= self.h {
73            let mut b = j + 1;
74            while b <= self.w {
75                self.bit[a][b] = M::operate(&self.bit[a][b], &x);
76                b += b & (!b + 1);
77            }
78            a += a & (!a + 1);
79        }
80    }
81}
82
83impl<G> BinaryIndexedTree2D<G>
84where
85    G: Group,
86{
87    #[inline]
88    /// 0-indexed [i1, i2) x [j1, j2)
89    pub fn fold(&self, i1: usize, j1: usize, i2: usize, j2: usize) -> G::T {
90        let mut res = G::unit();
91        res = G::operate(&res, &self.accumulate0(i1, j1));
92        res = G::rinv_operate(&res, &self.accumulate0(i1, j2));
93        res = G::rinv_operate(&res, &self.accumulate0(i2, j1));
94        res = G::operate(&res, &self.accumulate0(i2, j2));
95        res
96    }
97    #[inline]
98    pub fn get(&self, i: usize, j: usize) -> G::T {
99        self.fold(i, j, i + 1, j + 1)
100    }
101    #[inline]
102    pub fn set(&mut self, i: usize, j: usize, x: G::T) {
103        let y = G::inverse(&self.get(i, j));
104        let z = G::operate(&y, &x);
105        self.update(i, j, z);
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        tools::Xorshift,
    };

    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    /// Checks inclusive prefix folds against a brute-force 2D prefix table,
    /// first for sums and then for maxima (sharing one RNG stream).
    #[test]
    fn test_binary_indexed_tree_2d() {
        const H: usize = 150;
        const W: usize = 250;
        let mut rng = Xorshift::new();

        // --- additive monoid ---
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut table = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            table[i][j] += v;
        }
        for row in table.iter_mut() {
            for j in 1..W {
                row[j] += row[j - 1];
            }
        }
        for i in 1..H {
            for j in 0..W {
                table[i][j] += table[i - 1][j];
            }
        }
        for (i, row) in table.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }

        // --- max monoid ---
        let mut bit = BinaryIndexedTree2D::<MaxOperation<_>>::new(H, W);
        let mut table = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            table[i][j] = std::cmp::max(table[i][j], v);
        }
        for row in table.iter_mut() {
            for j in 1..W {
                row[j] = std::cmp::max(row[j], row[j - 1]);
            }
        }
        for i in 1..H {
            for j in 0..W {
                table[i][j] = std::cmp::max(table[i][j], table[i - 1][j]);
            }
        }
        for (i, row) in table.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }
    }

    /// Checks `set` + rectangle `fold` for the additive group against a
    /// 1-indexed prefix table with a zero border row and column.
    #[test]
    fn test_group_binary_indexed_tree2d() {
        const H: usize = 15;
        const W: usize = 25;
        let mut rng = Xorshift::new();
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut table = vec![vec![0; W + 1]; H + 1];
        for (i, j, v) in rng.random_iter((..H, ..W, -B..B)).take(Q) {
            bit.set(i, j, v);
            table[i + 1][j + 1] = v;
        }
        for row in table.iter_mut() {
            for j in 1..=W {
                row[j] += row[j - 1];
            }
        }
        for i in 1..=H {
            for j in 0..=W {
                table[i][j] += table[i - 1][j];
            }
        }
        for i1 in 0..H {
            for i2 in (i1 + 1)..=H {
                for j1 in 0..W {
                    for j2 in (j1 + 1)..=W {
                        let want =
                            table[i2][j2] - table[i2][j1] - table[i1][j2] + table[i1][j1];
                        assert_eq!(bit.fold(i1, j1, i2, j2), want);
                    }
                }
            }
        }
    }
}