competitive/data_structure/
compressed_binary_indexed_tree.rs1use super::{Monoid, SliceBisectExt};
2use std::{
3 fmt::{self, Debug},
4 marker::PhantomData,
5 ops::{Bound, RangeBounds},
6};
7
8pub struct CompressedBinaryIndexedTree<M, X, Inner>
9where
10 M: Monoid,
11{
12 compress: Vec<X>,
13 bits: Vec<Inner>,
14 _marker: PhantomData<fn() -> M>,
15}
16impl<M, X, Inner> Debug for CompressedBinaryIndexedTree<M, X, Inner>
17where
18 M: Monoid,
19 X: Debug,
20 Inner: Debug,
21{
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 f.debug_struct("CompressedBinaryIndexedTree")
24 .field("compress", &self.compress)
25 .field("bits", &self.bits)
26 .finish()
27 }
28}
29impl<M, X, Inner> Clone for CompressedBinaryIndexedTree<M, X, Inner>
30where
31 M: Monoid,
32 X: Clone,
33 Inner: Clone,
34{
35 fn clone(&self) -> Self {
36 Self {
37 compress: self.compress.clone(),
38 bits: self.bits.clone(),
39 _marker: self._marker,
40 }
41 }
42}
43impl<M, X, Inner> Default for CompressedBinaryIndexedTree<M, X, Inner>
44where
45 M: Monoid,
46{
47 fn default() -> Self {
48 Self {
49 compress: Default::default(),
50 bits: Default::default(),
51 _marker: Default::default(),
52 }
53 }
54}
55#[repr(transparent)]
56pub struct Tag<M>(M::T)
57where
58 M: Monoid;
59impl<M> Debug for Tag<M>
60where
61 M: Monoid<T: Debug>,
62{
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 self.0.fmt(f)
65 }
66}
67impl<M> Clone for Tag<M>
68where
69 M: Monoid,
70{
71 fn clone(&self) -> Self {
72 Self(self.0.clone())
73 }
74}
75
/// Generates the constructor, prefix query and point update for nested
/// [`CompressedBinaryIndexedTree`]s of every dimension listed, plus one `pub type` alias per
/// dimension.
///
/// The public input is a flat list of `Alias Coord Query` identifier triples. The `k`-th
/// triple produces the `k`-dimensional impl, using the first `k` coordinate type names and
/// query type names. Everything after an optional `;` is accepted and ignored. Arms starting
/// with `@` are internal helpers.
macro_rules! impl_compressed_binary_indexed_tree {
    // Right-nested tuple type `(T0, (T1, (..., (Tn,))))` used for points and query ranges.
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident) => {
        ($($l)* $T $($r)*,)
    };
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident $($Rest:ident)+) => {
        ($($l)* $T $($r)*, impl_compressed_binary_indexed_tree!(@tuple ($($l)*) ($($r)*) $($Rest)+))
    };
    // Concrete tree type: one `CompressedBinaryIndexedTree` layer per axis, `Tag<M>` at the leaves.
    (@cst $M:ident) => {
        Tag<$M>
    };
    (@cst $M:ident $T:ident $($Rest:ident)*) => {
        CompressedBinaryIndexedTree<$M, $T, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>
    };
    // Constructor body for the last axis: every node starts at the monoid unit.
    (@from_iter $M:ident $points:ident $T:ident) => {{
        let mut compress: Vec<_> = $points.into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        Self {
            compress,
            bits: vec![Tag($M::unit()); n + 1],
            _marker: PhantomData,
        }
    }};
    // Constructor body for an outer axis: node `i` owns an inner tree built from every point
    // whose first coordinate lies in the Fenwick range covered by `i`.
    (@from_iter $M:ident $points:ident $T:ident $U:ident $($Rest:ident)*) => {{
        let mut compress: Vec<_> = $points.clone().into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        let mut bits = vec![CompressedBinaryIndexedTree::default(); n + 1];
        // Bucket the remaining coordinates by the 1-based rank of the first coordinate.
        let mut ps = vec![vec![]; n + 1];
        for (x, q) in $points {
            let i = compress.position_bisect(|c| x <= c);
            ps[i + 1].push(q);
        }
        for i in 1..=n {
            bits[i] = CompressedBinaryIndexedTree::<_, _, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
            // Node `j = i + lowbit(i)` covers node `i`'s range, so it inherits all of `i`'s
            // points. `j > i`, so `j` is built later in this loop.
            let j = i + (i & i.wrapping_neg());
            if j <= n {
                let (lower, upper) = ps.split_at_mut(j);
                upper[0].append(&mut lower[i]);
            }
        }
        Self {
            compress,
            bits,
            _marker: PhantomData,
        }
    }};
    // Value of one Fenwick node, restricted to the remaining axes of the query `$rng`.
    (@acc $e:expr, $rng:ident $T:ident) => {
        $e.0
    };
    (@acc $e:expr, $rng:ident $T:ident $($Rest:ident)+) => {
        $e.accumulate(&$rng.1)
    };
    // Fold `$x` into one Fenwick node, recursing into the remaining axes of `$key`.
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident) => {
        $M::operate_assign(&mut $e.0, $x);
    };
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident $($Rest:ident)+) => {
        $e.update(&$key.1, $x);
    };
    // The impl and the type alias for one concrete dimension.
    (@impl $C:ident $($T:ident)*, $($Q:ident)*) => {
        impl<M, $($T,)*> impl_compressed_binary_indexed_tree!(@cst M $($T)*)
        where
            M: Monoid,
            $($T: Clone + Ord,)*
        {
            /// Builds a tree whose node set is `points` (duplicates allowed), with every
            /// node holding `M::unit()`.
            pub fn new(points: &[impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)]) -> Self {
                Self::from_iter(points)
            }

            /// Shared constructor used by `new` and by the next-outer axis, which passes
            /// borrowed sub-tuples of its own points.
            fn from_iter<'a, Iter>(points: Iter) -> Self
            where
                $($T: 'a,)*
                Iter: IntoIterator<Item = &'a impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)> + Clone,
            {
                impl_compressed_binary_indexed_tree!(@from_iter M points $($T)*)
            }

            /// Folds every value updated at a point whose coordinates all lie in the matching
            /// component of `range`.
            ///
            /// Nodes are combined as `acc ⊕ node` from the highest covering node downwards,
            /// so the result is only in index order when `M` is commutative.
            ///
            /// # Panics
            ///
            /// Panics with "expected `Bound::Unbounded`" if any component has a start bound.
            pub fn accumulate<$($Q,)*>(&self, range: &impl_compressed_binary_indexed_tree!(@tuple () () $($Q)*)) -> M::T
            where
                $($Q: RangeBounds<$T>,)*
            {
                assert!(
                    matches!(range.0.start_bound(), Bound::Unbounded),
                    "expected `Bound::Unbounded`"
                );
                // `end_bound` already yields `&T`, the same type the bisect predicate receives.
                let mut k = match range.0.end_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|c| c > index),
                    Bound::Excluded(index) => self.compress.position_bisect(|c| c >= index),
                    Bound::Unbounded => self.compress.len(),
                };
                let mut acc = M::unit();
                while k > 0 {
                    acc = M::operate(&acc, &impl_compressed_binary_indexed_tree!(@acc self.bits[k], range $($T)*));
                    k -= k & k.wrapping_neg();
                }
                acc
            }

            /// Combines `x` into every node covering `key`, using `M::operate_assign` at the
            /// last axis.
            ///
            /// # Panics
            ///
            /// Panics with "not exist key" when a coordinate of `key` is missing from a
            /// visited node's coordinate table. `key` is expected to be one of the points
            /// passed to `new`.
            pub fn update(&mut self, key: &impl_compressed_binary_indexed_tree!(@tuple () () $($T)*), x: &M::T) {
                let mut k = self.compress.binary_search(&key.0).expect("not exist key") + 1;
                while k < self.bits.len() {
                    impl_compressed_binary_indexed_tree!(@update self.bits[k], M key x $($T)*);
                    k += k & k.wrapping_neg();
                }
            }
        }

        pub type $C<M, $($T),*> = impl_compressed_binary_indexed_tree!(@cst M $($T)*);
    };
    // Emit the impl for the current prefix of axes, then grow the prefix by one triple.
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
    };
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][$D:ident $U:ident $R:ident $($Rest:ident)*]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
        impl_compressed_binary_indexed_tree!(@inner [$D][$($T)* $U][$($Q)* $R][$($Rest)*]);
    };
    // Public entry: triples up to an optional `;`; tokens after the `;` are discarded.
    ($C:ident $T:ident $Q:ident $($Rest:ident)* $(;$($t:tt)*)?) => {
        impl_compressed_binary_indexed_tree!(@inner [$C][$T][$Q][$($Rest)*]);
    };
    // Anything else is a usage error; report it with the offending tokens.
    ($($t:tt)*) => {
        compile_error! {
            concat!(
                "invalid input to `impl_compressed_binary_indexed_tree!`: ",
                stringify!($($t)*)
            )
        }
    };
}

198impl_compressed_binary_indexed_tree!(
199 CompressedBinaryIndexedTree1d A QA
200 CompressedBinaryIndexedTree2d B QB
201 CompressedBinaryIndexedTree3d C QC
202 CompressedBinaryIndexedTree4d D QD;
203 CompressedBinaryIndexedTree5d E QE
204 CompressedBinaryIndexedTree6d F QF
205 CompressedBinaryIndexedTree7d G QG
206 CompressedBinaryIndexedTree8d H QH
207 CompressedBinaryIndexedTree9d I QI
208);
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AdditiveOperation, tools::Xorshift};
    use std::{collections::HashMap, ops::RangeTo};

    /// Randomised cross-check of the 4-dimensional tree against a brute-force sum over all
    /// point weights, using point updates and prefix queries with random end bounds.
    #[test]
    fn test_bit4d() {
        const N: usize = 100;
        const Q: usize = 5000;
        const A: RangeTo<u64> = ..1_000;

        let mut rng = Xorshift::default();
        let mut points: Vec<_> = rng.random_iter((A, (A, (A, (A,))))).take(N).collect();
        points.sort();
        points.dedup();

        let mut weights: HashMap<_, u64> = points.iter().map(|point| (point, 0)).collect();
        let mut bit =
            CompressedBinaryIndexedTree4d::<AdditiveOperation<u64>, _, _, _, _>::new(&points);

        for _ in 0..Q {
            let point = &points[rng.random(0..points.len())];
            let delta = rng.random(A);
            *weights.get_mut(point).unwrap() += delta;
            bit.update(point, &delta);

            // One query axis: the start is always unbounded (the only form the tree
            // accepts); the end is excluded, included or unbounded at random.
            let mut random_axis = || {
                let end = match rng.rand(3) {
                    0 => Bound::Excluded(rng.random(A)),
                    1 => Bound::Included(rng.random(A)),
                    _ => Bound::Unbounded,
                };
                (Bound::Unbounded, end)
            };
            let range = (
                random_axis(),
                (random_axis(), (random_axis(), (random_axis(),))),
            );

            let (r0, (r1, (r2, (r3,)))) = range;
            let expected: u64 = weights
                .iter()
                .filter(|((c0, (c1, (c2, (c3,)))), _)| {
                    RangeBounds::contains(&r0, c0)
                        && RangeBounds::contains(&r1, c1)
                        && RangeBounds::contains(&r2, c2)
                        && RangeBounds::contains(&r3, c3)
                })
                .map(|(_, weight)| *weight)
                .sum();
            assert_eq!(expected, bit.accumulate(&range));
        }
    }
}