// competitive/algorithm/sqrt_decomposition.rs

use super::{Magma, Monoid, RangeBoundsExt, Unital};
use std::{marker::PhantomData, ops::RangeBounds};

4pub trait SqrtDecomposition: Sized {
5    type M: Monoid;
6    type B;
7    fn bucket(bsize: usize) -> Self::B;
8    fn update_bucket(bucket: &mut Self::B, x: &<Self::M as Magma>::T);
9    fn update_cell(
10        bucket: &mut Self::B,
11        cell: &mut <Self::M as Magma>::T,
12        x: &<Self::M as Magma>::T,
13    );
14    fn fold_bucket(bucket: &Self::B) -> <Self::M as Magma>::T;
15    fn fold_cell(bucket: &Self::B, cell: &<Self::M as Magma>::T) -> <Self::M as Magma>::T;
16    fn sqrt_decomposition(n: usize, bucket_size: Option<usize>) -> SqrtDecompositionBuckets<Self> {
17        let bucket_size = bucket_size
18            .unwrap_or((n as f64).sqrt().ceil() as usize)
19            .max(1);
20        let mut buckets = vec![];
21        for l in (0..n).step_by(bucket_size) {
22            let bsize = (l + bucket_size).min(n) - l;
23            let bucket = Self::bucket(bsize);
24            buckets.push(Bucket {
25                cells: vec![Self::M::unit(); bsize],
26                bucket,
27            });
28        }
29        SqrtDecompositionBuckets {
30            n,
31            bucket_size,
32            buckets,
33            _marker: PhantomData,
34        }
35    }
36}
37
/// One bucket of a [`SqrtDecompositionBuckets`]: its raw cells plus the strategy's summary.
struct Bucket<T, B> {
    /// Per-position values of this bucket, in sequence order.
    cells: Vec<T>,
    /// Strategy-defined summary covering the whole bucket.
    bucket: B,
}

43pub struct SqrtDecompositionBuckets<S>
44where
45    S: SqrtDecomposition,
46{
47    n: usize,
48    bucket_size: usize,
49    buckets: Vec<Bucket<<S::M as Magma>::T, S::B>>,
50    _marker: PhantomData<fn() -> S>,
51}
52impl<S> SqrtDecompositionBuckets<S>
53where
54    S: SqrtDecomposition,
55{
56    pub fn update_cell(&mut self, i: usize, x: <S::M as Magma>::T) {
57        let Bucket { cells, bucket } = &mut self.buckets[i / self.bucket_size];
58        let j = i % self.bucket_size;
59        S::update_cell(bucket, &mut cells[j], &x);
60    }
61    pub fn update<R>(&mut self, range: R, x: <S::M as Magma>::T)
62    where
63        R: RangeBounds<usize>,
64    {
65        let range = range.to_range_bounded(0, self.n).expect("invalid range");
66        for (i, Bucket { cells, bucket }) in self.buckets.iter_mut().enumerate() {
67            let s = i * self.bucket_size;
68            let t = s + cells.len();
69            if t <= range.start || range.end <= s {
70            } else if range.start <= s && t <= range.end {
71                S::update_bucket(bucket, &x);
72            } else {
73                for cell in &mut cells[range.start.max(s) - s..range.end.min(t) - s] {
74                    S::update_cell(bucket, cell, &x);
75                }
76            }
77        }
78    }
79    pub fn get(&self, i: usize) -> <S::M as Magma>::T {
80        let Bucket { cells, bucket } = &self.buckets[i / self.bucket_size];
81        let j = i % self.bucket_size;
82        S::fold_cell(bucket, &cells[j])
83    }
84    pub fn fold<R>(&self, range: R) -> <S::M as Magma>::T
85    where
86        R: RangeBounds<usize>,
87    {
88        let range = range.to_range_bounded(0, self.n).expect("invalid range");
89        let mut res = S::M::unit();
90        for (i, Bucket { cells, bucket }) in self.buckets.iter().enumerate() {
91            let s = i * self.bucket_size;
92            let t = s + cells.len();
93            if t <= range.start || range.end <= s {
94            } else if range.start <= s && t <= range.end {
95                <S::M as Magma>::operate_assign(&mut res, &S::fold_bucket(bucket));
96            } else {
97                for cell in &cells[range.start.max(s) - s..range.end.min(t) - s] {
98                    <S::M as Magma>::operate_assign(&mut res, &S::fold_cell(bucket, cell));
99                }
100            }
101        }
102        res
103    }
104}
105
106pub struct RangeUpdateRangeFoldSqrtDecomposition<M>
107where
108    M: Monoid,
109{
110    _marker: PhantomData<fn() -> M>,
111}
112
113impl<M> SqrtDecomposition for RangeUpdateRangeFoldSqrtDecomposition<M>
114where
115    M: Monoid,
116{
117    type M = M;
118    // fold, lazy, size
119    type B = (M::T, M::T, usize);
120    fn bucket(bsize: usize) -> Self::B {
121        (M::unit(), M::unit(), bsize)
122    }
123    fn update_bucket(bucket: &mut Self::B, x: &<Self::M as Magma>::T) {
124        M::operate_assign(&mut bucket.1, x);
125    }
126    fn update_cell(
127        bucket: &mut Self::B,
128        cell: &mut <Self::M as Magma>::T,
129        x: &<Self::M as Magma>::T,
130    ) {
131        M::operate_assign(&mut bucket.0, x);
132        M::operate_assign(cell, x);
133    }
134    fn fold_bucket(bucket: &Self::B) -> <Self::M as Magma>::T {
135        M::operate(&bucket.0, &M::pow(bucket.1.clone(), bucket.2))
136    }
137    fn fold_cell(bucket: &Self::B, cell: &<Self::M as Magma>::T) -> <Self::M as Magma>::T {
138        M::operate(cell, &bucket.1)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AdditiveOperation,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    /// Randomised cross-check of the range-add / range-sum strategy against a plain `Vec`.
    #[test]
    fn test_sqrt_decomposition() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..100, mut a: [0i64..1000; n]);
            let mut sd =
                RangeUpdateRangeFoldSqrtDecomposition::<AdditiveOperation<i64>>::sqrt_decomposition(
                    n, None,
                );
            // Every cell starts at the additive unit 0, so adding `a[i]` stores the initial value.
            a.iter()
                .enumerate()
                .for_each(|(i, &v)| sd.update_cell(i, v));
            for _ in 0..100 {
                rand!(rng, ty: 0..3, (l, r): Nes(n), x: 0i64..1000);
                if ty == 0 {
                    // Range add, mirrored on the naive array.
                    sd.update(l..r, x);
                    a[l..r].iter_mut().for_each(|v| *v += x);
                } else if ty == 1 {
                    // Range sum.
                    assert_eq!(sd.fold(l..r), a[l..r].iter().sum::<i64>())
                } else {
                    // Point query.
                    assert_eq!(sd.get(l), a[l]);
                }
            }
        }
    }
}