competitive/data_structure/
binary_indexed_tree.rs

1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree<M>
5where
6    M: Monoid,
7{
8    n: usize,
9    bit: Vec<M::T>,
10}
11
12impl<M> Clone for BinaryIndexedTree<M>
13where
14    M: Monoid,
15{
16    fn clone(&self) -> Self {
17        Self {
18            n: self.n,
19            bit: self.bit.clone(),
20        }
21    }
22}
23
24impl<M> Debug for BinaryIndexedTree<M>
25where
26    M: Monoid<T: Debug>,
27{
28    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
29        f.debug_struct("BinaryIndexedTree")
30            .field("n", &self.n)
31            .field("bit", &self.bit)
32            .finish()
33    }
34}
35
36impl<M> BinaryIndexedTree<M>
37where
38    M: Monoid,
39{
40    #[inline]
41    pub fn new(n: usize) -> Self {
42        let bit = vec![M::unit(); n + 1];
43        Self { n, bit }
44    }
45    #[inline]
46    pub fn from_slice(slice: &[M::T]) -> Self {
47        let n = slice.len();
48        let mut bit = vec![M::unit(); n + 1];
49        for (i, x) in slice.iter().enumerate() {
50            let k = i + 1;
51            M::operate_assign(&mut bit[k], x);
52            let j = k + (k & (!k + 1));
53            if j <= n {
54                bit[j] = M::operate(&bit[j], &bit[k]);
55            }
56        }
57        Self { n, bit }
58    }
59    #[inline]
60    /// fold [0, k)
61    pub fn accumulate0(&self, mut k: usize) -> M::T {
62        debug_assert!(k <= self.n);
63        let mut res = M::unit();
64        while k > 0 {
65            res = M::operate(&res, &self.bit[k]);
66            k -= k & (!k + 1);
67        }
68        res
69    }
70    #[inline]
71    /// fold [0, k]
72    pub fn accumulate(&self, k: usize) -> M::T {
73        self.accumulate0(k + 1)
74    }
75    #[inline]
76    pub fn update(&mut self, k: usize, x: M::T) {
77        debug_assert!(k < self.n);
78        let mut k = k + 1;
79        while k <= self.n {
80            self.bit[k] = M::operate(&self.bit[k], &x);
81            k += k & (!k + 1);
82        }
83    }
84}
85
86impl<G: Group> BinaryIndexedTree<G> {
87    #[inline]
88    pub fn fold(&self, l: usize, r: usize) -> G::T {
89        debug_assert!(l <= self.n && r <= self.n);
90        G::operate(&G::inverse(&self.accumulate0(l)), &self.accumulate0(r))
91    }
92    #[inline]
93    pub fn get(&self, k: usize) -> G::T {
94        self.fold(k, k + 1)
95    }
96    #[inline]
97    pub fn set(&mut self, k: usize, x: G::T) {
98        self.update(k, G::operate(&G::inverse(&self.get(k)), &x));
99    }
100}
101
102impl<M> BinaryIndexedTree<M>
103where
104    M: Monoid<T: Ord>,
105{
106    #[inline]
107    pub fn lower_bound(&self, x: M::T) -> usize {
108        let n = self.n;
109        let mut acc = M::unit();
110        let mut pos = 0;
111        let mut k = n.next_power_of_two();
112        while k > 0 {
113            if k + pos <= n && M::operate(&acc, &self.bit[k + pos]) < x {
114                pos += k;
115                acc = M::operate(&acc, &self.bit[pos]);
116            }
117            k >>= 1;
118        }
119        pos
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        tools::Xorshift,
    };

    const N: usize = 10_000;
    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    // Prefix folds after random point updates must match a plain array that
    // receives the same updates, for both addition and maximum.
    #[test]
    fn test_binary_indexed_tree() {
        let mut rng = Xorshift::default();
        let mut arr: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut bit = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&arr);
        for (k, v) in rng.random_iter((..N, ..A)).take(Q) {
            bit.update(k, v);
            arr[k] += v;
        }
        // Turn `arr` into its own prefix sums for comparison.
        for i in 0..N - 1 {
            arr[i + 1] += arr[i];
        }
        for (i, a) in arr.iter().cloned().enumerate() {
            assert_eq!(bit.accumulate(i), a);
        }

        let mut arr: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut bit = BinaryIndexedTree::<MaxOperation<_>>::from_slice(&arr);
        for (k, v) in rng.random_iter((..N, ..A)).take(Q) {
            bit.update(k, v);
            arr[k] = std::cmp::max(arr[k], v);
        }
        // Turn `arr` into its own prefix maxima for comparison.
        for i in 0..N - 1 {
            arr[i + 1] = std::cmp::max(arr[i], arr[i + 1]);
        }
        for (i, a) in arr.iter().cloned().enumerate() {
            assert_eq!(bit.accumulate(i), a);
        }
    }

    // After random `set` calls, every range fold must equal the difference of
    // the reference prefix sums.
    #[test]
    fn test_group_binary_indexed_tree() {
        const N: usize = 2_000;
        let mut rng = Xorshift::default();
        let mut arr: Vec<_> = rng.random_iter(-B..B).take(N).collect();
        let mut bit = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&arr);
        for (k, v) in rng.random_iter((..N, -B..B)).take(Q) {
            bit.set(k, v);
            arr[k] = v;
        }
        for i in 0..N - 1 {
            arr[i + 1] += arr[i];
        }
        for i in 0..N {
            for j in i + 1..N + 1 {
                assert_eq!(
                    bit.fold(i, j),
                    arr[j - 1] - if i == 0 { 0 } else { arr[i - 1] }
                );
            }
        }
    }

    // With strictly positive elements the prefix sums are increasing, so
    // `lower_bound` must agree with a bisection over the prefix-sum array.
    #[test]
    fn test_binary_indexed_tree_lower_bound() {
        let mut rng = Xorshift::default();
        let mut arr: Vec<_> = rng.random_iter(1..B).take(N).collect();
        let mut bit = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&arr);
        for (k, v) in rng.random_iter((..N, 1..B)).take(Q) {
            bit.set(k, v);
            arr[k] = v;
        }
        for i in 0..N - 1 {
            arr[i + 1] += arr[i];
        }
        for x in rng.random_iter(1..B * N as i64).take(Q) {
            assert_eq!(bit.lower_bound(x), arr.position_bisect(|&a| a >= x));
        }
    }
}