Skip to main content

competitive/algorithm/
doubling.rs

1use super::{Group, LevelAncestor, Monoid, UndirectedSparseGraph};
2use std::collections::VecDeque;
3
4pub struct Doubling<M>
5where
6    M: Monoid,
7{
8    size: usize,
9    table: Vec<(usize, M::T)>,
10}
11
12impl<M> Doubling<M>
13where
14    M: Monoid,
15{
16    pub fn new(size: usize, f: impl Fn(usize) -> (usize, M::T)) -> Self {
17        let mut table = Vec::with_capacity(size);
18        for i in 0..size {
19            table.push(f(i));
20        }
21        Self { size, table }
22    }
23
24    pub fn double(&mut self) {
25        let base = self.table.len() - self.size;
26        for i in 0..self.size {
27            let &(to, ref val) = &self.table[base + i];
28            if to != !0 {
29                let &(to2, ref val2) = &self.table[base + to];
30                self.table.push((to2, M::operate(val, val2)));
31            } else {
32                self.table.push((!0, M::unit()));
33            }
34        }
35    }
36
37    pub fn kth(&mut self, mut pos: usize, mut k: usize) -> (usize, M::T) {
38        let mut x = M::unit();
39        for chunk in self.table.chunks_exact(self.size) {
40            if k & 1 == 1 {
41                let &(to, ref val) = &chunk[pos];
42                if to == !0 {
43                    return (!0, M::unit());
44                }
45                x = M::operate(&x, val);
46                pos = to;
47            }
48            k >>= 1;
49            if k == 0 {
50                break;
51            }
52        }
53        while k > 0 {
54            self.double();
55            if k & 1 == 1 {
56                let base = self.table.len() - self.size;
57                let &(to, ref val) = &self.table[base + pos];
58                if to == !0 {
59                    return (!0, M::unit());
60                }
61                x = M::operate(&x, val);
62                pos = to;
63            }
64            k >>= 1;
65        }
66        (pos, x)
67    }
68
69    /// queries: (pos, k)
70    /// Return: (pos, acc)
71    pub fn kth_multiple(
72        &self,
73        queries: impl IntoIterator<Item = (usize, usize)>,
74    ) -> Vec<(usize, M::T)> {
75        let (mut ks, mut results): (Vec<usize>, Vec<(usize, M::T)>) = queries
76            .into_iter()
77            .map(|(start, k)| (k, (start, M::unit())))
78            .unzip();
79        for chunk in self.table.chunks_exact(self.size) {
80            for (i, k) in ks.iter_mut().enumerate() {
81                if *k & 1 == 1 {
82                    let &(to, ref val) = &chunk[results[i].0];
83                    if to == !0 {
84                        results[i] = (!0, M::unit());
85                        *k = 0;
86                    } else {
87                        results[i].1 = M::operate(&results[i].1, val);
88                        results[i].0 = to;
89                    }
90                }
91                *k >>= 1;
92            }
93        }
94        if ks.iter().any(|&k| k > 0) {
95            let mut dp = self.table[self.table.len() - self.size..].to_vec();
96            while ks.iter().any(|&k| k > 0) {
97                let mut ndp = Vec::with_capacity(dp.len());
98                for i in 0..self.size {
99                    let &(to, ref val) = &dp[i];
100                    if to != !0 {
101                        let &(to2, ref val2) = &dp[to];
102                        ndp.push((to2, M::operate(val, val2)));
103                    } else {
104                        ndp.push((!0, M::unit()));
105                    }
106                }
107                dp = ndp;
108                for (i, k) in ks.iter_mut().enumerate() {
109                    if *k & 1 == 1 {
110                        let &(to, ref val) = &dp[results[i].0];
111                        if to == !0 {
112                            results[i] = (!0, M::unit());
113                            *k = 0;
114                        } else {
115                            results[i].1 = M::operate(&results[i].1, val);
116                            results[i].0 = to;
117                        }
118                    }
119                    *k >>= 1;
120                }
121            }
122        }
123        results
124    }
125
126    /// Return: (k, (pos, acc))
127    pub fn find_last(
128        &self,
129        mut pos: usize,
130        mut pred: impl FnMut(usize, &M::T) -> bool,
131    ) -> (usize, (usize, M::T)) {
132        let mut k = 0usize;
133        let mut x = M::unit();
134        assert!(pred(pos, &x));
135        for (i, chunk) in self.table.chunks_exact(self.size).enumerate().rev() {
136            let &(to, ref val) = &chunk[pos];
137            let nx = M::operate(&x, val);
138            if pred(to, &nx) {
139                x = nx;
140                pos = to;
141                k |= 1 << i;
142            }
143        }
144        (k, (pos, x))
145    }
146
147    /// Return: (k, (pos, acc))
148    pub fn find_first(
149        &self,
150        pos: usize,
151        mut pred: impl FnMut(usize, &M::T) -> bool,
152    ) -> Option<(usize, (usize, M::T))> {
153        let (mut k, (mut pos, mut x)) = self.find_last(pos, |k, x| !pred(k, x));
154        k += 1;
155        M::operate_assign(&mut x, &self.table[pos].1);
156        pos = self.table[pos].0;
157        if pred(pos, &x) {
158            Some((k, (pos, x)))
159        } else {
160            None
161        }
162    }
163}
164
165pub struct FunctionalGraphDoubling<M>
166where
167    M: Group,
168{
169    depth_to_cycle: Vec<usize>,
170    cycle_entry: Vec<usize>,
171    cycle_id: Vec<usize>,
172    cycle_pos: Vec<usize>,
173    cycles: Vec<Vec<usize>>,
174    cycle_prefix: Vec<Vec<M::T>>,
175    prefix_up: Vec<M::T>,
176    la: LevelAncestor,
177}
178
179impl<M> FunctionalGraphDoubling<M>
180where
181    M: Group,
182{
183    pub fn new(size: usize, f: impl Fn(usize) -> (usize, M::T)) -> Self {
184        let (next, value): (Vec<_>, Vec<_>) = (0..size).map(f).unzip();
185
186        let mut indeg = vec![0usize; size];
187        for &to in &next {
188            indeg[to] += 1;
189        }
190        let mut in_cycle = vec![true; size];
191        let mut deq = VecDeque::new();
192        for (u, &deg) in indeg.iter().enumerate() {
193            if deg == 0 {
194                deq.push_back(u);
195            }
196        }
197        while let Some(u) = deq.pop_front() {
198            in_cycle[u] = false;
199            indeg[next[u]] -= 1;
200            if indeg[next[u]] == 0 {
201                deq.push_back(next[u]);
202            }
203        }
204
205        let mut cycle_id = vec![!0; size];
206        let mut cycle_pos = vec![!0; size];
207        let mut cycles = Vec::new();
208        for i in 0..size {
209            if in_cycle[i] && cycle_id[i] == !0 {
210                let mut cycle = Vec::new();
211                let mut u = i;
212                loop {
213                    cycle_id[u] = cycles.len();
214                    cycle_pos[u] = cycle.len();
215                    cycle.push(u);
216                    u = next[u];
217                    if u == i {
218                        break;
219                    }
220                }
221                cycles.push(cycle);
222            }
223        }
224
225        let mut rev = vec![Vec::new(); size];
226        for u in 0..size {
227            rev[next[u]].push(u);
228        }
229
230        let mut depth_to_cycle = vec![0usize; size];
231        let mut cycle_entry = vec![!0; size];
232        let mut prefix_up = Vec::with_capacity(size);
233        prefix_up.resize_with(size, M::unit);
234        let mut q = VecDeque::new();
235        for i in 0..size {
236            if in_cycle[i] {
237                cycle_entry[i] = i;
238                prefix_up[i] = M::operate(&value[i], &M::unit());
239                q.push_back(i);
240            }
241        }
242        while let Some(u) = q.pop_front() {
243            for &v in &rev[u] {
244                if in_cycle[v] || cycle_entry[v] != !0 {
245                    continue;
246                }
247                cycle_entry[v] = cycle_entry[u];
248                depth_to_cycle[v] = depth_to_cycle[u] + 1;
249                cycle_id[v] = cycle_id[u];
250                prefix_up[v] = M::operate(&value[v], &prefix_up[u]);
251                q.push_back(v);
252            }
253        }
254
255        let mut cycle_prefix = Vec::with_capacity(cycles.len());
256        for cycle in &cycles {
257            let len = cycle.len();
258            let mut pref = Vec::with_capacity(2 * len + 1);
259            pref.push(M::unit());
260            for i in 0..2 * len {
261                let v = cycle[i % len];
262                let next_val = M::operate(pref.last().unwrap(), &value[v]);
263                pref.push(next_val);
264            }
265            cycle_prefix.push(pref);
266        }
267
268        let root = size;
269        let mut edges = Vec::with_capacity(size);
270        for u in 0..size {
271            if in_cycle[u] {
272                edges.push((u, root));
273            } else {
274                edges.push((u, next[u]));
275            }
276        }
277        let graph = UndirectedSparseGraph::from_edges(size + 1, edges);
278        let la = graph.level_ancestor(root);
279
280        Self {
281            depth_to_cycle,
282            cycle_entry,
283            cycle_id,
284            cycle_pos,
285            cycles,
286            cycle_prefix,
287            prefix_up,
288            la,
289        }
290    }
291
292    fn acc_to_ancestor(&self, u: usize, ancestor: usize) -> M::T {
293        let inv = M::inverse(&self.prefix_up[ancestor]);
294        M::operate(&self.prefix_up[u], &inv)
295    }
296
297    fn cycle_segment(&self, cycle_id: usize, start: usize, len: usize) -> M::T {
298        if len == 0 {
299            return M::unit();
300        }
301        let pref = &self.cycle_prefix[cycle_id];
302        let inv = M::inverse(&pref[start]);
303        M::operate(&inv, &pref[start + len])
304    }
305
306    fn cycle_acc_from(&self, entry: usize, steps: usize) -> M::T {
307        if steps == 0 {
308            return M::unit();
309        }
310        let cycle_id = self.cycle_id[entry];
311        let start = self.cycle_pos[entry];
312        let len = self.cycles[cycle_id].len();
313        let q = steps / len;
314        let r = steps % len;
315        let rem = self.cycle_segment(cycle_id, start, r);
316        if q == 0 {
317            return rem;
318        }
319        let full = self.cycle_segment(cycle_id, start, len);
320        let pow = M::pow(full, q);
321        M::operate(&pow, &rem)
322    }
323
324    fn cycle_jump_from(&self, entry: usize, steps: usize) -> usize {
325        if steps == 0 {
326            return entry;
327        }
328        let cycle_id = self.cycle_id[entry];
329        let start = self.cycle_pos[entry];
330        let len = self.cycles[cycle_id].len();
331        let idx = (start + steps % len) % len;
332        self.cycles[cycle_id][idx]
333    }
334
335    pub fn kth(&self, u: usize, k: usize) -> (usize, M::T) {
336        let depth = self.depth_to_cycle[u];
337        if k <= depth {
338            let ancestor = self.la.la(u, k).unwrap();
339            let acc = self.acc_to_ancestor(u, ancestor);
340            return (ancestor, acc);
341        }
342        let entry = self.cycle_entry[u];
343        let acc_tree = self.acc_to_ancestor(u, entry);
344        let steps = k - depth;
345        let pos = self.cycle_jump_from(entry, steps);
346        let acc_cycle = self.cycle_acc_from(entry, steps);
347        let acc = M::operate(&acc_tree, &acc_cycle);
348        (pos, acc)
349    }
350
351    /// queries: (pos, k)
352    /// Return: (pos, acc)
353    pub fn kth_multiple(
354        &self,
355        queries: impl IntoIterator<Item = (usize, usize)>,
356    ) -> Vec<(usize, M::T)> {
357        queries.into_iter().map(|(u, k)| self.kth(u, k)).collect()
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, LinearOperation, Magma as _, Unital as _},
        num::{Zero as _, mint_basic::MInt998244353},
        tools::Xorshift,
    };

    /// Random partial successor functions (`!0` = no successor): `kth` and `kth_multiple`
    /// must agree with a step-by-step walk.
    #[test]
    fn test_kth() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.random(1usize..100);
            // A drawn 0 becomes `!0` via `wrapping_sub(1)`: a vertex without a successor.
            let to: Vec<_> = rng
                .random_iter(0..=n)
                .take(n)
                .map(|x| x.wrapping_sub(1))
                .collect();
            let w: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let mut doubling = Doubling::<AdditiveOperation<_>>::new(n, |i| (to[i], w[i]));
            let mut queries = vec![];
            let mut results = vec![];
            for s in 0..n {
                let mut pos = s;
                let mut x = MInt998244353::zero();
                for k in 0..100 {
                    // Once the walk has left the graph, every longer walk reports (!0, 0).
                    let expected = if pos == !0 {
                        (!0, MInt998244353::zero())
                    } else {
                        (pos, x)
                    };
                    assert_eq!(doubling.kth(s, k), expected);
                    // Record every query, not only the failing ones, so that
                    // `kth_multiple` is also checked on walks that stay in the graph.
                    queries.push((s, k));
                    results.push(expected);
                    if pos != !0 {
                        x += w[pos];
                        pos = to[pos];
                    }
                }
            }
            // A fresh table (level 0 only) exercises the local level extension.
            let doubling = Doubling::<AdditiveOperation<_>>::new(n, |i| (to[i], w[i]));
            assert_eq!(doubling.kth_multiple(queries), results);
        }
    }

    /// `find_last` / `find_first` against a linear scan on total successor functions with
    /// positive weights, so the accumulated sum strictly increases along any walk.
    #[test]
    fn test_find() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.random(1usize..100);
            let to: Vec<_> = rng.random_iter(0..n).take(n).collect();
            let w: Vec<u64> = rng.random_iter(1..100).take(n).collect();
            let mut doubling = Doubling::<AdditiveOperation<_>>::new(n, |i| (to[i], w[i]));
            for _ in 0..10 {
                doubling.double();
            }
            for s in 0..n {
                let mut k = 0usize;
                let mut pos = s;
                let mut acc = 0u64;
                for x in 0u64..200 {
                    while acc + w[pos] <= x {
                        acc += w[pos];
                        pos = to[pos];
                        k += 1;
                    }
                    assert_eq!(doubling.find_last(s, |_, &v| v <= x), (k, (pos, acc)));
                    assert_eq!(
                        doubling.find_first(s, |_, &v| v > x),
                        Some((k + 1, (to[pos], acc + w[pos])))
                    );
                }
                // Beyond the reach of the 11 stored levels, nothing can be found.
                assert_eq!(doubling.find_first(s, |_, &v| v > 1_000_000), None);
            }
        }
    }

    /// `FunctionalGraphDoubling::kth` / `kth_multiple` against a step-by-step walk, using
    /// `LinearOperation` values `(a, b)` with `a != 0`.
    #[test]
    fn test_functional_graph_doubling_kth() {
        let mut rng = Xorshift::default();
        type M = LinearOperation<MInt998244353>;
        for _ in 0..200 {
            let n = rng.random(1usize..50);
            let to: Vec<_> = rng.random_iter(0..n).take(n).collect();
            let w: Vec<_> = rng
                .random_iter((1..MInt998244353::get_mod(), 0..MInt998244353::get_mod()))
                .take(n)
                .map(|(a, b)| (MInt998244353::new(a), MInt998244353::new(b)))
                .collect();
            let doubling = FunctionalGraphDoubling::<M>::new(n, |i| (to[i], w[i]));
            let mut queries = vec![];
            let mut results = vec![];
            for s in 0..n {
                let mut pos = s;
                let mut acc = M::unit();
                for k in 0..200 {
                    assert_eq!(doubling.kth(s, k), (pos, acc));
                    queries.push((s, k));
                    results.push((pos, acc));
                    acc = M::operate(&acc, &w[pos]);
                    pos = to[pos];
                }
            }
            assert_eq!(doubling.kth_multiple(queries), results);
        }
    }
}