competitive/data_structure/
binary_indexed_tree.rs1use super::{Group, Monoid};
2use std::fmt::{self, Debug, Formatter};
3
4pub struct BinaryIndexedTree<M>
5where
6 M: Monoid,
7{
8 n: usize,
9 bit: Vec<M::T>,
10}
11
12impl<M> Clone for BinaryIndexedTree<M>
13where
14 M: Monoid,
15{
16 fn clone(&self) -> Self {
17 Self {
18 n: self.n,
19 bit: self.bit.clone(),
20 }
21 }
22}
23
24impl<M> Debug for BinaryIndexedTree<M>
25where
26 M: Monoid<T: Debug>,
27{
28 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
29 f.debug_struct("BinaryIndexedTree")
30 .field("n", &self.n)
31 .field("bit", &self.bit)
32 .finish()
33 }
34}
35
36impl<M> BinaryIndexedTree<M>
37where
38 M: Monoid,
39{
40 #[inline]
41 pub fn new(n: usize) -> Self {
42 let bit = vec![M::unit(); n + 1];
43 Self { n, bit }
44 }
45 #[inline]
46 pub fn from_slice(slice: &[M::T]) -> Self {
47 let n = slice.len();
48 let mut bit = vec![M::unit(); n + 1];
49 for (i, x) in slice.iter().enumerate() {
50 let k = i + 1;
51 M::operate_assign(&mut bit[k], x);
52 let j = k + (k & (!k + 1));
53 if j <= n {
54 bit[j] = M::operate(&bit[j], &bit[k]);
55 }
56 }
57 Self { n, bit }
58 }
59 #[inline]
60 pub fn accumulate0(&self, mut k: usize) -> M::T {
62 debug_assert!(k <= self.n);
63 let mut res = M::unit();
64 while k > 0 {
65 res = M::operate(&res, &self.bit[k]);
66 k -= k & (!k + 1);
67 }
68 res
69 }
70 #[inline]
71 pub fn accumulate(&self, k: usize) -> M::T {
73 self.accumulate0(k + 1)
74 }
75 #[inline]
76 pub fn update(&mut self, k: usize, x: M::T) {
77 debug_assert!(k < self.n);
78 let mut k = k + 1;
79 while k <= self.n {
80 self.bit[k] = M::operate(&self.bit[k], &x);
81 k += k & (!k + 1);
82 }
83 }
84}
85
86impl<G: Group> BinaryIndexedTree<G> {
87 #[inline]
88 pub fn fold(&self, l: usize, r: usize) -> G::T {
89 debug_assert!(l <= self.n && r <= self.n);
90 G::operate(&G::inverse(&self.accumulate0(l)), &self.accumulate0(r))
91 }
92 #[inline]
93 pub fn get(&self, k: usize) -> G::T {
94 self.fold(k, k + 1)
95 }
96 #[inline]
97 pub fn set(&mut self, k: usize, x: G::T) {
98 self.update(k, G::operate(&G::inverse(&self.get(k)), &x));
99 }
100}
101
102impl<M> BinaryIndexedTree<M>
103where
104 M: Monoid<T: Ord>,
105{
106 #[inline]
107 pub fn lower_bound(&self, x: M::T) -> usize {
108 let n = self.n;
109 let mut acc = M::unit();
110 let mut pos = 0;
111 let mut k = n.next_power_of_two();
112 while k > 0 {
113 if k + pos <= n && M::operate(&acc, &self.bit[k + pos]) < x {
114 pos += k;
115 acc = M::operate(&acc, &self.bit[pos]);
116 }
117 k >>= 1;
118 }
119 pos
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        tools::Xorshift,
    };

    /// Length of the randomised arrays.
    const N: usize = 10_000;
    /// Number of random point operations / queries per check.
    const Q: usize = 100_000;
    /// Exclusive upper bound for unsigned random values.
    const A: u64 = 1_000_000_000;
    /// Bound for signed random values.
    const B: i64 = 1_000_000_000;

    /// Random point updates on a sum tree and a max tree, mirrored on a plain
    /// array, followed by a check of every inclusive prefix fold.
    #[test]
    fn test_binary_indexed_tree() {
        let mut rng = Xorshift::default();

        let mut naive: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&naive);
        for (idx, delta) in rng.random_iter((..N, ..A)).take(Q) {
            tree.update(idx, delta);
            naive[idx] += delta;
        }
        // Turn the mirror into inclusive prefix sums.
        for i in 1..N {
            naive[i] += naive[i - 1];
        }
        for (i, expected) in naive.iter().cloned().enumerate() {
            assert_eq!(tree.accumulate(i), expected);
        }

        let mut naive: Vec<_> = rng.random_iter(..A).take(N).collect();
        let mut tree = BinaryIndexedTree::<MaxOperation<_>>::from_slice(&naive);
        for (idx, value) in rng.random_iter((..N, ..A)).take(Q) {
            tree.update(idx, value);
            naive[idx] = std::cmp::max(naive[idx], value);
        }
        // Turn the mirror into inclusive prefix maxima.
        for i in 1..N {
            naive[i] = std::cmp::max(naive[i - 1], naive[i]);
        }
        for (i, expected) in naive.iter().cloned().enumerate() {
            assert_eq!(tree.accumulate(i), expected);
        }
    }

    /// Random `set` calls on a sum tree, then every range `[l, r)` is compared
    /// with a difference of naive prefix sums.
    #[test]
    fn test_group_binary_indexed_tree() {
        const N: usize = 2_000;
        let mut rng = Xorshift::default();
        let mut naive: Vec<_> = rng.random_iter(-B..B).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&naive);
        for (idx, value) in rng.random_iter((..N, -B..B)).take(Q) {
            tree.set(idx, value);
            naive[idx] = value;
        }
        for i in 1..N {
            naive[i] += naive[i - 1];
        }
        // naive[r - 1] now holds the sum of [0, r).
        for l in 0..N {
            let before = if l == 0 { 0 } else { naive[l - 1] };
            for r in l + 1..=N {
                assert_eq!(tree.fold(l, r), naive[r - 1] - before);
            }
        }
    }

    /// With strictly positive values the prefix sums are increasing, so
    /// `lower_bound` must agree with a bisection over them.
    #[test]
    fn test_binary_indexed_tree_lower_bound() {
        let mut rng = Xorshift::default();
        let mut naive: Vec<_> = rng.random_iter(1..B).take(N).collect();
        let mut tree = BinaryIndexedTree::<AdditiveOperation<_>>::from_slice(&naive);
        for (idx, value) in rng.random_iter((..N, 1..B)).take(Q) {
            tree.set(idx, value);
            naive[idx] = value;
        }
        for i in 1..N {
            naive[i] += naive[i - 1];
        }
        for x in rng.random_iter(1..B * N as i64).take(Q) {
            assert_eq!(tree.lower_bound(x), naive.position_bisect(|&a| a >= x));
        }
    }
}