competitive/data_structure/
binary_indexed_tree.rs

1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree<M>
5where
6    M: Monoid,
7{
8    n: usize,
9    bit: Vec<M::T>,
10}
11
12impl<M> Clone for BinaryIndexedTree<M>
13where
14    M: Monoid,
15{
16    fn clone(&self) -> Self {
17        Self {
18            n: self.n,
19            bit: self.bit.clone(),
20        }
21    }
22}
23
24impl<M> Debug for BinaryIndexedTree<M>
25where
26    M: Monoid,
27    M::T: Debug,
28{
29    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
30        f.debug_struct("BinaryIndexedTree")
31            .field("n", &self.n)
32            .field("bit", &self.bit)
33            .finish()
34    }
35}
36
37impl<M> BinaryIndexedTree<M>
38where
39    M: Monoid,
40{
41    #[inline]
42    pub fn new(n: usize) -> Self {
43        let bit = vec![M::unit(); n + 1];
44        Self { n, bit }
45    }
46    #[inline]
47    pub fn from_slice(slice: &[M::T]) -> Self {
48        let n = slice.len();
49        let mut bit = vec![M::unit(); n + 1];
50        for (i, x) in slice.iter().enumerate() {
51            let k = i + 1;
52            M::operate_assign(&mut bit[k], x);
53            let j = k + (k & (!k + 1));
54            if j <= n {
55                bit[j] = M::operate(&bit[j], &bit[k]);
56            }
57        }
58        Self { n, bit }
59    }
60    #[inline]
61    /// fold [0, k)
62    pub fn accumulate0(&self, mut k: usize) -> M::T {
63        debug_assert!(k <= self.n);
64        let mut res = M::unit();
65        while k > 0 {
66            res = M::operate(&res, &self.bit[k]);
67            k -= k & (!k + 1);
68        }
69        res
70    }
71    #[inline]
72    /// fold [0, k]
73    pub fn accumulate(&self, k: usize) -> M::T {
74        self.accumulate0(k + 1)
75    }
76    #[inline]
77    pub fn update(&mut self, k: usize, x: M::T) {
78        debug_assert!(k < self.n);
79        let mut k = k + 1;
80        while k <= self.n {
81            self.bit[k] = M::operate(&self.bit[k], &x);
82            k += k & (!k + 1);
83        }
84    }
85}
86
87impl<G: Group> BinaryIndexedTree<G> {
88    #[inline]
89    pub fn fold(&self, l: usize, r: usize) -> G::T {
90        debug_assert!(l <= self.n && r <= self.n);
91        G::operate(&G::inverse(&self.accumulate0(l)), &self.accumulate0(r))
92    }
93    #[inline]
94    pub fn get(&self, k: usize) -> G::T {
95        self.fold(k, k + 1)
96    }
97    #[inline]
98    pub fn set(&mut self, k: usize, x: G::T) {
99        self.update(k, G::operate(&G::inverse(&self.get(k)), &x));
100    }
101}
102
103impl<M: Monoid> BinaryIndexedTree<M>
104where
105    M::T: Ord,
106{
107    #[inline]
108    pub fn lower_bound(&self, x: M::T) -> usize {
109        let n = self.n;
110        let mut acc = M::unit();
111        let mut pos = 0;
112        let mut k = n.next_power_of_two();
113        while k > 0 {
114            if k + pos <= n && M::operate(&acc, &self.bit[k + pos]) < x {
115                pos += k;
116                acc = M::operate(&acc, &self.bit[pos]);
117            }
118            k >>= 1;
119        }
120        pos
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        tools::Xorshift,
    };

    const N: usize = 10_000;
    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    /// After random point updates, inclusive prefix folds must agree with a
    /// plain reference array, both for sums and for running maxima.
    #[test]
    fn test_binary_indexed_tree() {
        let mut rng = Xorshift::new();

        // Additive operation: mirror every update in the reference array.
        let mut expected: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&expected);
        for (pos, delta) in rng.random_iter((..N, ..A)).take(Q) {
            tree.update(pos, delta);
            expected[pos] += delta;
        }
        // Turn the reference array into inclusive prefix sums.
        for i in 1..N {
            expected[i] += expected[i - 1];
        }
        for (i, want) in expected.iter().cloned().enumerate() {
            assert_eq!(tree.accumulate(i), want);
        }

        // Max operation: an update can only raise the stored value.
        let mut expected: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut tree = BinaryIndexedTree::<MaxOperation<_>>::from_slice(&expected);
        for (pos, candidate) in rng.random_iter((..N, ..A)).take(Q) {
            tree.update(pos, candidate);
            expected[pos] = std::cmp::max(expected[pos], candidate);
        }
        // Turn the reference array into inclusive prefix maxima.
        for i in 1..N {
            expected[i] = std::cmp::max(expected[i - 1], expected[i]);
        }
        for (i, want) in expected.iter().cloned().enumerate() {
            assert_eq!(tree.accumulate(i), want);
        }
    }

    /// After random `set` calls, every range fold must equal the difference
    /// of two reference prefix sums.
    #[test]
    fn test_group_binary_indexed_tree() {
        const N: usize = 2_000;
        let mut rng = Xorshift::new();
        let mut prefix: Vec<_> = rng.random_iter(-B..B).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&prefix);
        for (pos, value) in rng.random_iter((..N, -B..B)).take(Q) {
            tree.set(pos, value);
            prefix[pos] = value;
        }
        for i in 1..N {
            prefix[i] += prefix[i - 1];
        }
        for l in 0..N {
            // Sum of everything strictly before `l`.
            let before = if l == 0 { 0 } else { prefix[l - 1] };
            for r in l + 1..=N {
                assert_eq!(tree.fold(l, r), prefix[r - 1] - before);
            }
        }
    }

    /// With strictly positive elements, `lower_bound` must agree with a
    /// bisection over the reference prefix sums.
    #[test]
    fn test_binary_indexed_tree_lower_bound() {
        let mut rng = Xorshift::new();
        let mut prefix: Vec<_> = rng.random_iter(1..B).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&prefix);
        for (pos, value) in rng.random_iter((..N, 1..B)).take(Q) {
            tree.set(pos, value);
            prefix[pos] = value;
        }
        for i in 1..N {
            prefix[i] += prefix[i - 1];
        }
        for target in rng.random_iter(1..B * N as i64).take(Q) {
            let want = prefix.position_bisect(|&sum| sum >= target);
            assert_eq!(tree.lower_bound(target), want);
        }
    }
}