competitive/data_structure/
binary_indexed_tree.rs1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree<M>
5where
6 M: Monoid,
7{
8 n: usize,
9 bit: Vec<M::T>,
10}
11
12impl<M> Clone for BinaryIndexedTree<M>
13where
14 M: Monoid,
15{
16 fn clone(&self) -> Self {
17 Self {
18 n: self.n,
19 bit: self.bit.clone(),
20 }
21 }
22}
23
24impl<M> Debug for BinaryIndexedTree<M>
25where
26 M: Monoid,
27 M::T: Debug,
28{
29 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
30 f.debug_struct("BinaryIndexedTree")
31 .field("n", &self.n)
32 .field("bit", &self.bit)
33 .finish()
34 }
35}
36
37impl<M> BinaryIndexedTree<M>
38where
39 M: Monoid,
40{
41 #[inline]
42 pub fn new(n: usize) -> Self {
43 let bit = vec![M::unit(); n + 1];
44 Self { n, bit }
45 }
46 #[inline]
47 pub fn from_slice(slice: &[M::T]) -> Self {
48 let n = slice.len();
49 let mut bit = vec![M::unit(); n + 1];
50 for (i, x) in slice.iter().enumerate() {
51 let k = i + 1;
52 M::operate_assign(&mut bit[k], x);
53 let j = k + (k & (!k + 1));
54 if j <= n {
55 bit[j] = M::operate(&bit[j], &bit[k]);
56 }
57 }
58 Self { n, bit }
59 }
60 #[inline]
61 pub fn accumulate0(&self, mut k: usize) -> M::T {
63 debug_assert!(k <= self.n);
64 let mut res = M::unit();
65 while k > 0 {
66 res = M::operate(&res, &self.bit[k]);
67 k -= k & (!k + 1);
68 }
69 res
70 }
71 #[inline]
72 pub fn accumulate(&self, k: usize) -> M::T {
74 self.accumulate0(k + 1)
75 }
76 #[inline]
77 pub fn update(&mut self, k: usize, x: M::T) {
78 debug_assert!(k < self.n);
79 let mut k = k + 1;
80 while k <= self.n {
81 self.bit[k] = M::operate(&self.bit[k], &x);
82 k += k & (!k + 1);
83 }
84 }
85}
86
87impl<G: Group> BinaryIndexedTree<G> {
88 #[inline]
89 pub fn fold(&self, l: usize, r: usize) -> G::T {
90 debug_assert!(l <= self.n && r <= self.n);
91 G::operate(&G::inverse(&self.accumulate0(l)), &self.accumulate0(r))
92 }
93 #[inline]
94 pub fn get(&self, k: usize) -> G::T {
95 self.fold(k, k + 1)
96 }
97 #[inline]
98 pub fn set(&mut self, k: usize, x: G::T) {
99 self.update(k, G::operate(&G::inverse(&self.get(k)), &x));
100 }
101}
102
103impl<M: Monoid> BinaryIndexedTree<M>
104where
105 M::T: Ord,
106{
107 #[inline]
108 pub fn lower_bound(&self, x: M::T) -> usize {
109 let n = self.n;
110 let mut acc = M::unit();
111 let mut pos = 0;
112 let mut k = n.next_power_of_two();
113 while k > 0 {
114 if k + pos <= n && M::operate(&acc, &self.bit[k + pos]) < x {
115 pos += k;
116 acc = M::operate(&acc, &self.bit[pos]);
117 }
118 k >>= 1;
119 }
120 pos
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        tools::Xorshift,
    };

    const N: usize = 10_000;
    const Q: usize = 100_000;
    const A: u64 = 1_000_000_000;
    const B: i64 = 1_000_000_000;

    /// Random point updates, then every prefix fold is compared against a naive
    /// prefix array: first for addition, then for max.
    #[test]
    fn test_binary_indexed_tree() {
        let mut rng = Xorshift::new();

        // Additive monoid: `update(k, v)` adds `v` to element `k`.
        let mut naive: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&naive);
        for (k, v) in rng.random_iter((..N, ..A)).take(Q) {
            tree.update(k, v);
            naive[k] += v;
        }
        for i in 1..N {
            naive[i] += naive[i - 1];
        }
        for (i, expected) in naive.iter().copied().enumerate() {
            assert_eq!(tree.accumulate(i), expected);
        }

        // Max monoid: `update(k, v)` raises element `k` to at least `v`.
        let mut naive: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut tree = BinaryIndexedTree::<MaxOperation<_>>::from_slice(&naive);
        for (k, v) in rng.random_iter((..N, ..A)).take(Q) {
            tree.update(k, v);
            naive[k] = std::cmp::max(naive[k], v);
        }
        for i in 1..N {
            naive[i] = std::cmp::max(naive[i - 1], naive[i]);
        }
        for (i, expected) in naive.iter().copied().enumerate() {
            assert_eq!(tree.accumulate(i), expected);
        }
    }

    /// Random point assignments, then every non-empty half-open range `[i, j)`
    /// is compared with a difference of naive prefix sums.
    #[test]
    fn test_group_binary_indexed_tree() {
        const N: usize = 2_000;
        let mut rng = Xorshift::new();
        let mut naive: Vec<_> = rng.random_iter(-B..B).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&naive);
        for (k, v) in rng.random_iter((..N, -B..B)).take(Q) {
            tree.set(k, v);
            naive[k] = v;
        }
        // prefix[j] is the sum of naive[..j], so prefix[0] == 0.
        let mut prefix = vec![0_i64; N + 1];
        for (j, &a) in naive.iter().enumerate() {
            prefix[j + 1] = prefix[j] + a;
        }
        for i in 0..N {
            for j in i + 1..=N {
                assert_eq!(tree.fold(i, j), prefix[j] - prefix[i]);
            }
        }
    }

    /// Strictly positive elements give strictly increasing prefix sums, so
    /// `lower_bound` must agree with a bisection over the naive prefix sums.
    #[test]
    fn test_binary_indexed_tree_lower_bound() {
        let mut rng = Xorshift::new();
        let mut naive: Vec<_> = rng.random_iter(1..B).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&naive);
        for (k, v) in rng.random_iter((..N, 1..B)).take(Q) {
            tree.set(k, v);
            naive[k] = v;
        }
        for i in 1..N {
            naive[i] += naive[i - 1];
        }
        for x in rng.random_iter(1..B * N as i64).take(Q) {
            assert_eq!(tree.lower_bound(x), naive.position_bisect(|&a| a >= x));
        }
    }
}