competitive/data_structure/binary_indexed_tree_2d.rs

1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree2D<M>
5where
6    M: Monoid,
7{
8    h: usize,
9    w: usize,
10    bit: Vec<Vec<M::T>>,
11}
12
13impl<M> Clone for BinaryIndexedTree2D<M>
14where
15    M: Monoid,
16{
17    fn clone(&self) -> Self {
18        Self {
19            h: self.h,
20            w: self.w,
21            bit: self.bit.clone(),
22        }
23    }
24}
25
26impl<M> Debug for BinaryIndexedTree2D<M>
27where
28    M: Monoid<T: Debug>,
29{
30    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
31        f.debug_struct("BinaryIndexedTree2D")
32            .field("h", &self.h)
33            .field("w", &self.w)
34            .field("bit", &self.bit)
35            .finish()
36    }
37}
38
39impl<M> BinaryIndexedTree2D<M>
40where
41    M: Monoid,
42{
43    #[inline]
44    pub fn new(h: usize, w: usize) -> Self {
45        let bit = vec![vec![M::unit(); w + 1]; h + 1];
46        Self { h, w, bit }
47    }
48    #[inline]
49    /// fold [0, i) x [0, j)
50    pub fn accumulate0(&self, i: usize, j: usize) -> M::T {
51        let mut res = M::unit();
52        let mut a = i;
53        while a > 0 {
54            let mut b = j;
55            while b > 0 {
56                res = M::operate(&res, &self.bit[a][b]);
57                b -= b & (!b + 1);
58            }
59            a -= a & (!a + 1);
60        }
61        res
62    }
63    #[inline]
64    /// fold [0, i] x [0, j]
65    pub fn accumulate(&self, i: usize, j: usize) -> M::T {
66        self.accumulate0(i + 1, j + 1)
67    }
68    #[inline]
69    pub fn update(&mut self, i: usize, j: usize, x: M::T) {
70        let mut a = i + 1;
71        while a <= self.h {
72            let mut b = j + 1;
73            while b <= self.w {
74                self.bit[a][b] = M::operate(&self.bit[a][b], &x);
75                b += b & (!b + 1);
76            }
77            a += a & (!a + 1);
78        }
79    }
80}
81
82impl<G> BinaryIndexedTree2D<G>
83where
84    G: Group,
85{
86    #[inline]
87    /// 0-indexed [i1, i2) x [j1, j2)
88    pub fn fold(&self, i1: usize, j1: usize, i2: usize, j2: usize) -> G::T {
89        let mut res = G::unit();
90        res = G::operate(&res, &self.accumulate0(i1, j1));
91        res = G::rinv_operate(&res, &self.accumulate0(i1, j2));
92        res = G::rinv_operate(&res, &self.accumulate0(i2, j1));
93        res = G::operate(&res, &self.accumulate0(i2, j2));
94        res
95    }
96    #[inline]
97    pub fn get(&self, i: usize, j: usize) -> G::T {
98        self.fold(i, j, i + 1, j + 1)
99    }
100    #[inline]
101    pub fn set(&mut self, i: usize, j: usize, x: G::T) {
102        let y = G::inverse(&self.get(i, j));
103        let z = G::operate(&y, &x);
104        self.update(i, j, z);
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        tools::Xorshift,
    };

    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    /// Rewrites `grid` in place into its inclusive 2D prefix fold under `f`:
    /// every row is folded left-to-right, then every column top-to-bottom.
    fn prefix_fold_2d<T: Copy>(grid: &mut [Vec<T>], f: impl Fn(T, T) -> T) {
        for row in grid.iter_mut() {
            for j in 1..row.len() {
                row[j] = f(row[j], row[j - 1]);
            }
        }
        for i in 1..grid.len() {
            let (above, below) = grid.split_at_mut(i);
            for (cell, &prev) in below[0].iter_mut().zip(&above[i - 1]) {
                *cell = f(*cell, prev);
            }
        }
    }

    #[test]
    fn test_binary_indexed_tree_2d() {
        let mut rng = Xorshift::default();
        const H: usize = 150;
        const W: usize = 250;

        // Sum monoid: every closed prefix must match a brute-force table.
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        let mut expected = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            expected[i][j] += v;
        }
        prefix_fold_2d(&mut expected, |x, y| x + y);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }

        // Max monoid: same check with a running maximum instead of a sum.
        let mut bit = BinaryIndexedTree2D::<MaxOperation<_>>::new(H, W);
        let mut expected = vec![vec![0; W]; H];
        for (i, j, v) in rng.random_iter((..H, ..W, ..A)).take(Q) {
            bit.update(i, j, v);
            expected[i][j] = std::cmp::max(expected[i][j], v);
        }
        prefix_fold_2d(&mut expected, std::cmp::max);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(bit.accumulate(i, j), want);
            }
        }
    }

    #[test]
    fn test_group_binary_indexed_tree2d() {
        let mut rng = Xorshift::default();
        const H: usize = 15;
        const W: usize = 25;
        let mut bit = BinaryIndexedTree2D::<AdditiveOperation<_>>::new(H, W);
        // Padded table: after folding, `prefix[i][j]` is the sum over [0, i) x [0, j).
        let mut prefix = vec![vec![0; W + 1]; H + 1];
        for (i, j, v) in rng.random_iter((..H, ..W, -B..B)).take(Q) {
            bit.set(i, j, v);
            prefix[i + 1][j + 1] = v;
        }
        prefix_fold_2d(&mut prefix, |x, y| x + y);
        for i1 in 0..H {
            for i2 in i1 + 1..=H {
                for j1 in 0..W {
                    for j2 in j1 + 1..=W {
                        let want =
                            prefix[i2][j2] - prefix[i2][j1] - prefix[i1][j2] + prefix[i1][j1];
                        assert_eq!(bit.fold(i1, j1, i2, j2), want);
                    }
                }
            }
        }
    }
}
207}