Skip to main content

competitive/tree/
static_top_tree.rs

1use super::{Magma, Monoid, UndirectedSparseGraph, Unital};
2use std::mem::MaybeUninit;
3
4pub trait MonoidCluster {
5    type Vertex;
6    type Edge;
7    type PathMonoid: Monoid;
8    type PointMonoid: Monoid;
9
10    fn add_edge(
11        path: &<<Self as MonoidCluster>::PathMonoid as Magma>::T,
12    ) -> <<Self as MonoidCluster>::PointMonoid as Magma>::T;
13    fn add_vertex(
14        point: &<<Self as MonoidCluster>::PointMonoid as Magma>::T,
15        vertex: &Self::Vertex,
16        parent_edge: Option<&Self::Edge>,
17    ) -> <<Self as MonoidCluster>::PathMonoid as Magma>::T;
18}
19
/// Operations a static top tree needs in order to evaluate a tree DP.
///
/// A *path* cluster summarises a contiguous part of a heavy path, together with everything
/// hanging off it. A *point* cluster summarises a set of subtrees hanging from one vertex.
pub trait Cluster {
    /// Value stored on each vertex.
    type Vertex;
    /// Value stored on each edge.
    type Edge;
    /// Path cluster value.
    type Path: Clone;
    /// Point cluster value.
    type Point: Clone;

    /// Identity element for [`Cluster::compress`].
    fn unit_path() -> Self::Path;
    /// Identity element for [`Cluster::rake`].
    fn unit_point() -> Self::Point;
    /// Concatenates an upper path cluster `left` with the lower path cluster `right` below it.
    fn compress(left: &Self::Path, right: &Self::Path) -> Self::Path;
    /// Merges two point clusters hanging from the same vertex.
    fn rake(left: &Self::Point, right: &Self::Point) -> Self::Point;
    /// Turns the path cluster of a heavy path into a point hanging from its parent vertex.
    fn add_edge(path: &Self::Path) -> Self::Point;
    /// Builds the single-vertex path cluster for `vertex`.
    ///
    /// The cluster is built from `point` (its raked light children) and `parent_edge`, which is
    /// `None` at the root.
    fn add_vertex(
        point: &Self::Point,
        vertex: &Self::Vertex,
        parent_edge: Option<&Self::Edge>,
    ) -> Self::Path;

    /// Path cluster of a vertex that has nothing hanging off it.
    ///
    /// This is `add_vertex` applied to the unit point.
    fn vertex(vertex: &Self::Vertex, parent_edge: Option<&Self::Edge>) -> Self::Path {
        let nothing_hanging = Self::unit_point();
        Self::add_vertex(&nothing_hanging, vertex, parent_edge)
    }
}
41
42impl<C> Cluster for C
43where
44    C: MonoidCluster,
45{
46    type Vertex = C::Vertex;
47    type Edge = C::Edge;
48    type Path = <<C as MonoidCluster>::PathMonoid as Magma>::T;
49    type Point = <<C as MonoidCluster>::PointMonoid as Magma>::T;
50
51    fn unit_path() -> Self::Path {
52        <C::PathMonoid as Unital>::unit()
53    }
54
55    fn unit_point() -> Self::Point {
56        <C::PointMonoid as Unital>::unit()
57    }
58
59    fn compress(left: &Self::Path, right: &Self::Path) -> Self::Path {
60        <C::PathMonoid as Magma>::operate(left, right)
61    }
62
63    fn rake(left: &Self::Point, right: &Self::Point) -> Self::Point {
64        <C::PointMonoid as Magma>::operate(left, right)
65    }
66
67    fn add_edge(path: &Self::Path) -> Self::Point {
68        <C as MonoidCluster>::add_edge(path)
69    }
70
71    fn add_vertex(
72        point: &Self::Point,
73        vertex: &Self::Vertex,
74        parent_edge: Option<&Self::Edge>,
75    ) -> Self::Path {
76        <C as MonoidCluster>::add_vertex(point, vertex, parent_edge)
77    }
78}
79
80#[derive(Clone)]
81pub struct StaticTopTree {
82    root: usize,
83    n: usize,
84    edge_child: Vec<usize>,
85    parent_edge: Vec<usize>,
86    compressed: Vec<InnerNode>,
87    raked: Vec<InnerNode>,
88    vertex_links: Vec<VertexLinks>,
89    compress_roots: Vec<Option<Slot>>,
90    rake_roots: Vec<Option<Slot>>,
91}
92
93#[derive(Clone)]
94struct InnerNode {
95    left: Slot,
96    right: Slot,
97    parent: usize,
98}
99
100#[derive(Clone)]
101struct InnerValue<T> {
102    left: T,
103    right: T,
104}
105
106pub struct StaticTopTreeDp<'a, C>
107where
108    C: Cluster,
109{
110    tree: &'a StaticTopTree,
111    vertices: Vec<<C as Cluster>::Vertex>,
112    edges: Vec<<C as Cluster>::Edge>,
113    compressed: Vec<InnerValue<<C as Cluster>::Path>>,
114    raked: Vec<InnerValue<<C as Cluster>::Point>>,
115    light_points: Vec<<C as Cluster>::Point>,
116    all_point: <C as Cluster>::Point,
117}
118
119#[derive(Debug, Clone, Copy)]
120struct VertexLinks {
121    heavy_parent: usize,
122    compress_parent: usize,
123    rake_parent: usize,
124}
125
126#[derive(Debug)]
127struct Node {
128    depth: usize,
129    slot: Slot,
130}
131
132#[derive(Debug, Clone, Copy)]
133enum Slot {
134    CompressLeaf(usize),
135    CompressInner(usize),
136    RakeLeaf(usize),
137    RakeInner(usize),
138}
139
140struct RootedInfo {
141    order: Vec<usize>,
142    children_start: Vec<usize>,
143    children: Vec<usize>,
144    edge_child: Vec<usize>,
145    parent_edge: Vec<usize>,
146}
147
148impl UndirectedSparseGraph {
149    pub fn static_top_tree(&self, root: usize) -> StaticTopTree {
150        StaticTopTree::new(root, self)
151    }
152}
153
impl StaticTopTree {
    /// Builds the static top tree of `graph` rooted at `root`.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the graph has no vertices;
    /// - `root` is out of range;
    /// - the edge count is not `vertices - 1`.
    pub fn new(root: usize, graph: &UndirectedSparseGraph) -> Self {
        let n = graph.vertices_size();
        assert!(n > 0);
        assert!(root < n);
        assert_eq!(graph.edges_size() + 1, n);

        let RootedInfo {
            order,
            children_start,
            children,
            edge_child,
            parent_edge,
        } = rooted_children(graph, root);
        let mut this = Self {
            root,
            n,
            edge_child,
            parent_edge,
            compressed: Vec::with_capacity(n.saturating_sub(1)),
            raked: Vec::with_capacity(n.saturating_sub(1)),
            vertex_links: vec![
                VertexLinks {
                    heavy_parent: usize::MAX,
                    compress_parent: usize::MAX,
                    rake_parent: usize::MAX,
                };
                n
            ],
            compress_roots: vec![None; n],
            rake_roots: vec![None; n],
        };

        // heavy_child[u]: the child continuing u's heavy path (usize::MAX if none).
        let mut heavy_child = vec![usize::MAX; n];
        // mask[u]: per-vertex weight (1 for leaves). Its bit_ceil sets the depth key of u's
        // compress leaf and drives the heavy-child choice below.
        let mut mask = vec![1u64; n];
        // Bucket queue keyed by `Node::depth`. It is reused (and left empty) for every vertex's
        // rake tree.
        let mut buckets: [Vec<Node>; 64] = std::array::from_fn(|_| Vec::new());

        // Reverse BFS order: all children of u are finished before u.
        for &u in order.iter().rev() {
            let children = &children[children_start[u]..children_start[u + 1]];
            let mut sum_rake = 0u64;
            for &v in children {
                sum_rake += bit_ceil(mask[v]) << 1;
            }
            mask[u] = bit_ceil(sum_rake);
            // Choose the heavy child. For each child, `cand` rounds mask[v] up to a multiple of
            // `step` and adds one more `step`. Every child whose candidate fits within the
            // current mask[u] replaces the previous choice and tightens mask[u].
            for &v in children {
                let child = bit_ceil(mask[v]) << 1;
                let depth = bit_ceil(sum_rake - child).trailing_zeros() as usize;
                let step = 1u64 << depth;
                let cand = ((mask[v] + step - 1) >> depth << depth) + step;
                if cand <= mask[u] {
                    mask[u] = cand;
                    heavy_child[u] = v;
                }
            }

            // Every non-heavy child starts its own heavy path. Build that path's compress tree
            // now and queue it as a rake leaf.
            // NOTE(review): the rake key is derived from the weight of the *siblings*
            // (`sum_rake - child`), not from the child itself. Confirm this is the intended
            // balancing heuristic.
            let mut has = 0u64;
            let mut num_light = 0usize;
            for &v in children {
                if v == heavy_child[u] {
                    continue;
                }
                num_light += 1;
                let child = bit_ceil(mask[v]) << 1;
                let depth = bit_ceil(sum_rake - child).trailing_zeros() as usize;
                this.build_compress(v, &heavy_child, &mask);
                buckets[depth].push(Node {
                    depth,
                    slot: Slot::RakeLeaf(v),
                });
                has |= 1u64 << depth;
            }
            if num_light == 0 {
                continue;
            }

            // Repeatedly merge the two entries with the smallest keys until one rake root
            // remains.
            while num_light > 1 {
                let left = pop_bucket(&mut buckets, &mut has);
                let right = pop_bucket(&mut buckets, &mut has);
                let node = this.merge_rake(left, right);
                let depth = node.depth;
                buckets[depth].push(node);
                has |= 1u64 << depth;
                num_light -= 1;
            }

            let root = pop_bucket(&mut buckets, &mut has);
            this.rake_roots[u] = Some(root.slot);
            // Every vertex on a light child's heavy path records `u` as its heavy_parent. It
            // also shares the rake-tree parent link of that path's rake leaf (set by
            // merge_rake; it stays usize::MAX when the leaf is the rake root).
            for &v0 in children {
                if v0 == heavy_child[u] {
                    continue;
                }
                let rake_parent = this.vertex_links[v0].rake_parent;
                let mut v = v0;
                while v != usize::MAX {
                    this.vertex_links[v].heavy_parent = u;
                    this.vertex_links[v].rake_parent = rake_parent;
                    v = heavy_child[v];
                }
            }
        }

        // The root's heavy path is not anybody's light child, so build it explicitly.
        this.build_compress(root, &heavy_child, &mask);
        this
    }

    /// Number of vertices of the underlying tree.
    pub fn vertices_size(&self) -> usize {
        self.n
    }

    /// Number of edges of the underlying tree.
    pub fn edges_size(&self) -> usize {
        self.edge_child.len()
    }

    /// Creates a DP over this tree with the given initial vertex and edge values.
    ///
    /// See [`StaticTopTreeDp::new`] for the panics.
    pub fn dp<C>(
        &self,
        vertices: Vec<<C as Cluster>::Vertex>,
        edges: Vec<<C as Cluster>::Edge>,
    ) -> StaticTopTreeDp<'_, C>
    where
        C: Cluster,
    {
        StaticTopTreeDp::new(self, vertices, edges)
    }

    /// Builds the compress tree for the heavy path starting at `vertex`.
    ///
    /// The path is followed via `heavy_child`. The root is recorded in
    /// `compress_roots[vertex]` and also returned.
    fn build_compress(&mut self, mut vertex: usize, heavy_child: &[usize], mask: &[u64]) -> Node {
        let start = vertex;
        let mut stack = Vec::new();
        while vertex != usize::MAX {
            stack.push(Node {
                depth: bit_ceil(mask[vertex]).trailing_zeros() as usize,
                slot: Slot::CompressLeaf(vertex),
            });
            // Merge adjacent stack entries while these depth rules fire. This presumably keeps
            // the compress tree depth-balanced.
            loop {
                let len = stack.len();
                if len >= 3
                    && (stack[len - 3].depth == stack[len - 2].depth
                        || stack[len - 3].depth <= stack[len - 1].depth)
                {
                    // Merge the two entries just below the top and keep the top in place.
                    let tail = stack.pop().unwrap();
                    let right = stack.pop().unwrap();
                    let left = stack.pop().unwrap();
                    let node = self.merge_compress(left, right);
                    stack.push(node);
                    stack.push(tail);
                } else if len >= 2 && stack[len - 2].depth <= stack[len - 1].depth {
                    let right = stack.pop().unwrap();
                    let left = stack.pop().unwrap();
                    stack.push(self.merge_compress(left, right));
                } else {
                    break;
                }
            }
            vertex = heavy_child[vertex];
        }
        // Fold whatever remains, right to left, into a single root.
        while stack.len() > 1 {
            let right = stack.pop().unwrap();
            let left = stack.pop().unwrap();
            stack.push(self.merge_compress(left, right));
        }
        let root = stack.pop().unwrap();
        self.compress_roots[start] = Some(root.slot);
        root
    }

    /// Creates a compress inner node over `left` (upper) and `right` (lower).
    ///
    /// Sets both children's packed parent links to point at the new node.
    fn merge_compress(&mut self, left: Node, right: Node) -> Node {
        let id = self.compressed.len();
        self.set_parent(left.slot, id << 1);
        self.set_parent(right.slot, id << 1 | 1);
        self.compressed.push(InnerNode {
            left: left.slot,
            right: right.slot,
            parent: usize::MAX,
        });
        Node {
            depth: left.depth.max(right.depth) + 1,
            slot: Slot::CompressInner(id),
        }
    }

    /// Creates a rake inner node over `left` and `right`.
    ///
    /// Sets both children's packed parent links to point at the new node.
    fn merge_rake(&mut self, left: Node, right: Node) -> Node {
        let id = self.raked.len();
        self.set_parent(left.slot, id << 1);
        self.set_parent(right.slot, id << 1 | 1);
        self.raked.push(InnerNode {
            left: left.slot,
            right: right.slot,
            parent: usize::MAX,
        });
        Node {
            depth: left.depth.max(right.depth) + 1,
            slot: Slot::RakeInner(id),
        }
    }

    /// Stores the packed `parent` link on whichever leaf or inner node `slot` refers to.
    fn set_parent(&mut self, slot: Slot, parent: usize) {
        match slot {
            Slot::CompressLeaf(v) => self.vertex_links[v].compress_parent = parent,
            Slot::CompressInner(i) => self.compressed[i].parent = parent,
            Slot::RakeLeaf(v) => self.vertex_links[v].rake_parent = parent,
            Slot::RakeInner(i) => self.raked[i].parent = parent,
        }
    }

    /// Recursively evaluates the compress subtree at `slot` and returns its path value.
    ///
    /// Along the way it writes every inner node's child values into `data`. Panics (via
    /// `unreachable!`) if given a rake slot.
    fn init_compress<C>(
        &self,
        data: &mut StaticTopTreeDataBuilder<C>,
        vertices: &[<C as Cluster>::Vertex],
        edges: &[<C as Cluster>::Edge],
        slot: Slot,
    ) -> <C as Cluster>::Path
    where
        C: Cluster,
    {
        match slot {
            Slot::CompressLeaf(vertex) => {
                let point = self.init_point(data, vertices, edges, vertex);
                C::add_vertex(
                    &point,
                    &vertices[vertex],
                    self.parent_edge_ref(edges, vertex),
                )
            }
            Slot::CompressInner(id) => {
                let node = &self.compressed[id];
                let left = self.init_compress(data, vertices, edges, node.left);
                let right = self.init_compress(data, vertices, edges, node.right);
                data.compressed[id].write(InnerValue {
                    left: left.clone(),
                    right: right.clone(),
                });
                C::compress(&left, &right)
            }
            Slot::RakeLeaf(_) | Slot::RakeInner(_) => unreachable!(),
        }
    }

    /// Computes the rake of `vertex`'s light children (unit point if it has none).
    ///
    /// The result is cached in `data.light_points[vertex]` and returned.
    fn init_point<C>(
        &self,
        data: &mut StaticTopTreeDataBuilder<C>,
        vertices: &[<C as Cluster>::Vertex],
        edges: &[<C as Cluster>::Edge],
        vertex: usize,
    ) -> <C as Cluster>::Point
    where
        C: Cluster,
    {
        let point = if let Some(slot) = self.rake_roots[vertex] {
            self.init_rake(data, vertices, edges, slot)
        } else {
            C::unit_point()
        };
        data.light_points[vertex] = point.clone();
        point
    }

    /// Recursively evaluates the rake subtree at `slot` and returns its point value.
    ///
    /// Each rake leaf evaluates the whole compress tree of that light child's heavy path.
    /// Panics (via `unreachable!`) if given a compress slot.
    fn init_rake<C>(
        &self,
        data: &mut StaticTopTreeDataBuilder<C>,
        vertices: &[<C as Cluster>::Vertex],
        edges: &[<C as Cluster>::Edge],
        slot: Slot,
    ) -> <C as Cluster>::Point
    where
        C: Cluster,
    {
        match slot {
            Slot::RakeLeaf(vertex) => {
                let path = self.init_compress(
                    data,
                    vertices,
                    edges,
                    self.compress_roots[vertex].expect("light child path must exist"),
                );
                C::add_edge(&path)
            }
            Slot::RakeInner(id) => {
                let node = &self.raked[id];
                let left = self.init_rake(data, vertices, edges, node.left);
                let right = self.init_rake(data, vertices, edges, node.right);
                data.raked[id].write(InnerValue {
                    left: left.clone(),
                    right: right.clone(),
                });
                C::rake(&left, &right)
            }
            Slot::CompressLeaf(_) | Slot::CompressInner(_) => unreachable!(),
        }
    }

    /// Returns the value of the edge from `vertex` to its parent, or `None` for the root.
    fn parent_edge_ref<'a, T>(&self, edges: &'a [T], vertex: usize) -> Option<&'a T> {
        let edge = self.parent_edge[vertex];
        if edge == usize::MAX {
            None
        } else {
            Some(&edges[edge])
        }
    }
}
452
453impl<'a, C> StaticTopTreeDp<'a, C>
454where
455    C: Cluster,
456{
457    pub fn new(
458        tree: &'a StaticTopTree,
459        vertices: Vec<<C as Cluster>::Vertex>,
460        edges: Vec<<C as Cluster>::Edge>,
461    ) -> Self {
462        assert_eq!(vertices.len(), tree.vertices_size());
463        assert_eq!(edges.len(), tree.edges_size());
464
465        let mut data: StaticTopTreeDataBuilder<C> = StaticTopTreeDataBuilder::new(tree);
466        let path = tree.init_compress(
467            &mut data,
468            &vertices,
469            &edges,
470            tree.compress_roots[tree.root].expect("root compress tree must exist"),
471        );
472        let all_point = C::add_edge(&path);
473        Self {
474            tree,
475            vertices,
476            edges,
477            compressed: unsafe { assume_init_vec(data.compressed) },
478            raked: unsafe { assume_init_vec(data.raked) },
479            light_points: data.light_points,
480            all_point,
481        }
482    }
483
484    pub fn set_vertex(&mut self, vertex: usize, value: <C as Cluster>::Vertex) {
485        assert!(vertex < self.vertices.len());
486        self.vertices[vertex] = value;
487        self.update_from_vertex(vertex);
488    }
489
490    pub fn set_edge(&mut self, edge: usize, value: <C as Cluster>::Edge) {
491        assert!(edge < self.edges.len());
492        self.edges[edge] = value;
493        self.update_from_vertex(self.tree.edge_child[edge]);
494    }
495
496    pub fn fold_all(&self) -> &<C as Cluster>::Point {
497        &self.all_point
498    }
499
500    pub fn fold_path(&self, mut vertex: usize) -> <C as Cluster>::Path {
501        assert!(vertex < self.tree.n);
502        let mut path = C::unit_path();
503        let mut point = self.light_points[vertex].clone();
504        loop {
505            let links = self.tree.vertex_links[vertex];
506            let mut left = C::unit_path();
507            let mut right = C::unit_path();
508            let mut compress_parent = links.compress_parent;
509            while compress_parent != usize::MAX {
510                let inner = &self.compressed[compress_parent / 2];
511                if compress_parent & 1 == 0 {
512                    right = C::compress(&right, &inner.right);
513                } else {
514                    left = C::compress(&inner.left, &left);
515                }
516                compress_parent = self.tree.compressed[compress_parent / 2].parent;
517            }
518            let right_point = C::add_edge(&right);
519            point = C::rake(&point, &right_point);
520            let mid = C::add_vertex(
521                &point,
522                &self.vertices[vertex],
523                self.tree.parent_edge_ref(&self.edges, vertex),
524            );
525            let mid = C::compress(&mid, &path);
526            path = C::compress(&left, &mid);
527            if links.heavy_parent == usize::MAX {
528                return path;
529            }
530
531            point = C::unit_point();
532            let mut rake_parent = links.rake_parent;
533            while rake_parent != usize::MAX {
534                let inner = &self.raked[rake_parent / 2];
535                if rake_parent & 1 == 0 {
536                    point = C::rake(&point, &inner.right);
537                } else {
538                    point = C::rake(&inner.left, &point);
539                }
540                rake_parent = self.tree.raked[rake_parent / 2].parent;
541            }
542            vertex = links.heavy_parent;
543        }
544    }
545
546    fn update_from_vertex(&mut self, mut vertex: usize) {
547        assert!(vertex < self.tree.n);
548        while vertex != usize::MAX {
549            let links = self.tree.vertex_links[vertex];
550            let base = C::add_vertex(
551                &self.light_points[vertex],
552                &self.vertices[vertex],
553                self.tree.parent_edge_ref(&self.edges, vertex),
554            );
555            let path = self.update_compress(links.compress_parent, base);
556            let point = C::add_edge(&path);
557            let point = self.update_rake(links.rake_parent, point);
558            if links.heavy_parent == usize::MAX {
559                self.all_point = point;
560            } else {
561                self.light_points[links.heavy_parent] = point;
562            }
563            vertex = links.heavy_parent;
564        }
565    }
566
567    fn update_compress(
568        &mut self,
569        mut id: usize,
570        mut path: <C as Cluster>::Path,
571    ) -> <C as Cluster>::Path {
572        while id != usize::MAX {
573            let inner = &mut self.compressed[id / 2];
574            if id & 1 == 0 {
575                inner.left = path;
576            } else {
577                inner.right = path;
578            }
579            path = C::compress(&inner.left, &inner.right);
580            id = self.tree.compressed[id / 2].parent;
581        }
582        path
583    }
584
585    fn update_rake(
586        &mut self,
587        mut id: usize,
588        mut point: <C as Cluster>::Point,
589    ) -> <C as Cluster>::Point {
590        while id != usize::MAX {
591            let inner = &mut self.raked[id / 2];
592            if id & 1 == 0 {
593                inner.left = point;
594            } else {
595                inner.right = point;
596            }
597            point = C::rake(&inner.left, &inner.right);
598            id = self.tree.raked[id / 2].parent;
599        }
600        point
601    }
602}
603
604struct StaticTopTreeDataBuilder<C>
605where
606    C: Cluster,
607{
608    compressed: Vec<MaybeUninit<InnerValue<<C as Cluster>::Path>>>,
609    raked: Vec<MaybeUninit<InnerValue<<C as Cluster>::Point>>>,
610    light_points: Vec<<C as Cluster>::Point>,
611}
612
613impl<C> StaticTopTreeDataBuilder<C>
614where
615    C: Cluster,
616{
617    fn new(tree: &StaticTopTree) -> Self {
618        let mut compressed = Vec::with_capacity(tree.compressed.len());
619        compressed.resize_with(tree.compressed.len(), MaybeUninit::uninit);
620        let mut raked = Vec::with_capacity(tree.raked.len());
621        raked.resize_with(tree.raked.len(), MaybeUninit::uninit);
622        Self {
623            compressed,
624            raked,
625            light_points: vec![C::unit_point(); tree.n],
626        }
627    }
628}
629
/// Reinterprets a vector of initialised `MaybeUninit<T>` as a `Vec<T>` without copying.
///
/// The length and capacity of the allocation are preserved.
///
/// # Safety
///
/// Every one of the `vec.len()` elements of `vec` must have been initialised.
unsafe fn assume_init_vec<T>(vec: Vec<MaybeUninit<T>>) -> Vec<T> {
    // Wrapping in `ManuallyDrop` before extracting the raw parts guarantees the original
    // vector can never free the buffer. This is the pattern recommended by the
    // `Vec::from_raw_parts` docs.
    let mut vec = std::mem::ManuallyDrop::new(vec);
    let ptr = vec.as_mut_ptr().cast::<T>();
    let len = vec.len();
    let cap = vec.capacity();
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so the buffer is a
    // valid `Vec<T>` allocation with capacity `cap`. The caller guarantees the first `len`
    // elements are initialised. `vec` is never dropped, so the buffer has exactly one owner.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
637
/// Smallest power of two that is `>= x`, where both `0` and `1` map to `1`.
fn bit_ceil(x: u64) -> u64 {
    // `next_power_of_two` already returns 1 for both 0 and 1.
    x.next_power_of_two()
}
641
642fn rooted_children(graph: &UndirectedSparseGraph, root: usize) -> RootedInfo {
643    let n = graph.vertices_size();
644    let mut order = Vec::with_capacity(n);
645    let mut parent = vec![usize::MAX; n];
646    let mut parent_edge = vec![usize::MAX; n];
647    let mut edge_child = vec![0; graph.edges_size()];
648    order.push(root);
649    parent[root] = usize::MAX;
650    for i in 0..n {
651        let u = order[i];
652        for a in graph.adjacencies(u) {
653            if a.to == parent[u] {
654                continue;
655            }
656            parent[a.to] = u;
657            parent_edge[a.to] = a.id;
658            edge_child[a.id] = a.to;
659            order.push(a.to);
660        }
661    }
662    let mut children_start = vec![0usize; n + 1];
663    for &v in order.iter().skip(1) {
664        children_start[parent[v] + 1] += 1;
665    }
666    for i in 1..=n {
667        children_start[i] += children_start[i - 1];
668    }
669    let mut children = vec![0; n.saturating_sub(1)];
670    let mut child_pos = children_start.clone();
671    for &v in order.iter().skip(1) {
672        let pos = child_pos[parent[v]];
673        children[pos] = v;
674        child_pos[parent[v]] += 1;
675    }
676    RootedInfo {
677        order,
678        children_start,
679        children,
680        edge_child,
681        parent_edge,
682    }
683}
684
685fn pop_bucket(buckets: &mut [Vec<Node>; 64], has: &mut u64) -> Node {
686    let depth = has.trailing_zeros() as usize;
687    let node = buckets[depth].pop().unwrap();
688    if buckets[depth].is_empty() {
689        *has &= !(1u64 << depth);
690    }
691    node
692}
693
/// Randomised tests comparing the static top tree DP against a brute-force DFS.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{Associative, Magma, Unital},
        graph::UndirectedSparseGraph,
        num::{One, Zero, mint_basic::MInt998244353},
        tools::Xorshift,
        tree::{PathTree, PruferSequence, StarTree},
    };

    type MInt = MInt998244353;

    /// Point cluster: `sum` of transformed vertex values and vertex count `cnt` over a set of
    /// subtrees.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct Point {
        sum: MInt,
        cnt: MInt,
    }

    /// Path cluster.
    ///
    /// - `(a, b)` is the composed affine map `x -> a * x + b` of the parent edges along the
    ///   path.
    /// - `sum` and `cnt` cover everything in the cluster.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct Path {
        a: MInt,
        b: MInt,
        sum: MInt,
        cnt: MInt,
    }

    /// A path cluster read in both directions.
    ///
    /// `reverse` is combined in the opposite order (see `PathPairMonoid`).
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct PathPair {
        forward: Path,
        reverse: Path,
    }

    /// Component-wise sum of points.
    struct PointMonoid;
    impl Magma for PointMonoid {
        type T = Point;
        fn operate(x: &Self::T, y: &Self::T) -> Self::T {
            Point {
                sum: x.sum + y.sum,
                cnt: x.cnt + y.cnt,
            }
        }
    }
    impl Unital for PointMonoid {
        fn unit() -> Self::T {
            Point {
                sum: MInt::zero(),
                cnt: MInt::zero(),
            }
        }
    }
    impl Associative for PointMonoid {}

    /// Concatenation of an upper path `x` with a lower path `y`.
    ///
    /// The affine maps compose, and `x`'s map is applied to `y`'s sum.
    struct PathMonoid;
    impl Magma for PathMonoid {
        type T = Path;
        fn operate(x: &Self::T, y: &Self::T) -> Self::T {
            Path {
                a: x.a * y.a,
                b: x.b + x.a * y.b,
                sum: x.sum + x.a * y.sum + x.b * y.cnt,
                cnt: x.cnt + y.cnt,
            }
        }
    }
    impl Unital for PathMonoid {
        fn unit() -> Self::T {
            Path {
                a: MInt::one(),
                b: MInt::zero(),
                sum: MInt::zero(),
                cnt: MInt::zero(),
            }
        }
    }
    impl Associative for PathMonoid {}

    /// `forward` concatenates in the given order; `reverse` concatenates `y` before `x`.
    struct PathPairMonoid;
    impl Magma for PathPairMonoid {
        type T = PathPair;
        fn operate(x: &Self::T, y: &Self::T) -> Self::T {
            PathPair {
                forward: PathMonoid::operate(&x.forward, &y.forward),
                reverse: PathMonoid::operate(&y.reverse, &x.reverse),
            }
        }
    }
    impl Unital for PathPairMonoid {
        fn unit() -> Self::T {
            PathPair {
                forward: PathMonoid::unit(),
                reverse: PathMonoid::unit(),
            }
        }
    }
    impl Associative for PathPairMonoid {}

    /// Rooted DP: each edge `(a, b)` maps every value in the subtree below it through
    /// `x -> a * x + b`.
    struct FixedCluster;
    impl MonoidCluster for FixedCluster {
        type Vertex = MInt;
        type Edge = (MInt, MInt);
        type PointMonoid = PointMonoid;
        type PathMonoid = PathMonoid;

        fn add_vertex(point: &Point, vertex: &MInt, parent_edge: Option<&(MInt, MInt)>) -> Path {
            let cnt = point.cnt + MInt::one();
            let subtotal = point.sum + *vertex;
            // The root has no parent edge; use the identity map.
            let (a, b) = parent_edge.copied().unwrap_or((MInt::one(), MInt::zero()));
            Path {
                a,
                b,
                sum: a * subtotal + b * cnt,
                cnt,
            }
        }

        fn add_edge(path: &Path) -> Point {
            Point {
                sum: path.sum,
                cnt: path.cnt,
            }
        }
    }

    /// Same DP as `FixedCluster`, but also tracks the reversed orientation.
    ///
    /// This lets `fold_path(v).reverse` evaluate the tree rerooted at `v`, as checked in
    /// `run_reroot_case`.
    struct RerootCluster;
    impl MonoidCluster for RerootCluster {
        type Vertex = MInt;
        type Edge = (MInt, MInt);
        type PointMonoid = PointMonoid;
        type PathMonoid = PathPairMonoid;

        fn add_vertex(
            point: &Point,
            vertex: &MInt,
            parent_edge: Option<&(MInt, MInt)>,
        ) -> PathPair {
            let cnt = point.cnt + MInt::one();
            let subtotal = point.sum + *vertex;
            let (a, b) = parent_edge.copied().unwrap_or((MInt::one(), MInt::zero()));
            PathPair {
                forward: Path {
                    a,
                    b,
                    sum: a * subtotal + b * cnt,
                    cnt,
                },
                // In the reversed orientation the parent edge is not applied to this vertex's
                // own subtotal.
                reverse: Path {
                    a,
                    b,
                    sum: subtotal,
                    cnt,
                },
            }
        }

        fn add_edge(path: &PathPair) -> Point {
            Point {
                sum: path.forward.sum,
                cnt: path.forward.cnt,
            }
        }
    }

    /// Brute-force oracle: recomputes the `FixedCluster` DP rooted at `root` with a DFS.
    fn naive_rooted(
        graph: &UndirectedSparseGraph,
        vertices: &[MInt],
        edges: &[(MInt, MInt)],
        root: usize,
    ) -> Point {
        fn dfs(
            graph: &UndirectedSparseGraph,
            vertices: &[MInt],
            edges: &[(MInt, MInt)],
            u: usize,
            p: usize,
            in_edge: Option<usize>,
        ) -> Point {
            let mut point = PointMonoid::unit();
            for a in graph.adjacencies(u) {
                if a.to != p {
                    point = PointMonoid::operate(
                        &point,
                        &dfs(graph, vertices, edges, a.to, u, Some(a.id)),
                    );
                }
            }
            let cnt = point.cnt + MInt::one();
            let subtotal = point.sum + vertices[u];
            let (a, b) = in_edge
                .map(|eid| edges[eid])
                .unwrap_or((MInt::one(), MInt::zero()));
            Point {
                sum: a * subtotal + b * cnt,
                cnt,
            }
        }
        dfs(graph, vertices, edges, root, usize::MAX, None)
    }

    /// Complete binary tree on `n` vertices (vertex `v` hangs from `(v - 1) / 2`).
    fn balanced_tree(n: usize) -> UndirectedSparseGraph {
        let edges = (1..n).map(|v| ((v - 1) / 2, v)).collect::<Vec<_>>();
        UndirectedSparseGraph::from_edges(n, edges)
    }

    /// Random vertex values and edge `(a, b)` pairs, each in `0..10`.
    fn gen_weights(rng: &mut Xorshift, n: usize, m: usize) -> (Vec<MInt>, Vec<(MInt, MInt)>) {
        let vertices = (0..n)
            .map(|_| MInt::from(rng.random(0u32..10)))
            .collect::<Vec<_>>();
        let edges = (0..m)
            .map(|_| {
                (
                    MInt::from(rng.random(0u32..10)),
                    MInt::from(rng.random(0u32..10)),
                )
            })
            .collect::<Vec<_>>();
        (vertices, edges)
    }

    /// Applies `rounds` random vertex/edge updates, checking `fold_all` against the oracle
    /// after each one.
    fn run_fixed_case(graph: &UndirectedSparseGraph, rounds: usize, rng: &mut Xorshift) {
        let n = graph.vertices_size();
        let m = graph.edges_size();
        let (mut vertices, mut edges) = gen_weights(rng, n, m);
        let tree = graph.static_top_tree(0);
        let mut dp = tree.dp::<FixedCluster>(vertices.clone(), edges.clone());
        assert_eq!(*dp.fold_all(), naive_rooted(graph, &vertices, &edges, 0));

        for _ in 0..rounds {
            if rng.random(0u32..2) == 0 {
                let v = rng.random(0..n);
                let x = MInt::from(rng.random(0u32..20));
                vertices[v] = x;
                dp.set_vertex(v, x);
            } else if m > 0 {
                let eid = rng.random(0..m);
                let edge = (
                    MInt::from(rng.random(0u32..20)),
                    MInt::from(rng.random(0u32..20)),
                );
                edges[eid] = edge;
                dp.set_edge(eid, edge);
            }
            assert_eq!(*dp.fold_all(), naive_rooted(graph, &vertices, &edges, 0));
        }
    }

    /// Applies random updates. After each one, checks every rerooted value
    /// (`fold_path(root).reverse.sum`) against the oracle rooted at that vertex.
    fn run_reroot_case(graph: &UndirectedSparseGraph, rounds: usize, rng: &mut Xorshift) {
        let n = graph.vertices_size();
        let m = graph.edges_size();
        let (mut vertices, mut edges) = gen_weights(rng, n, m);
        let tree = graph.static_top_tree(0);
        let mut dp = tree.dp::<RerootCluster>(vertices.clone(), edges.clone());
        assert_eq!(
            dp.fold_all().sum,
            naive_rooted(graph, &vertices, &edges, 0).sum
        );

        for _ in 0..rounds {
            if rng.random(0u32..2) == 0 {
                let v = rng.random(0..n);
                let x = MInt::from(rng.random(0u32..20));
                vertices[v] = x;
                dp.set_vertex(v, x);
            } else if m > 0 {
                let eid = rng.random(0..m);
                let edge = (
                    MInt::from(rng.random(0u32..20)),
                    MInt::from(rng.random(0u32..20)),
                );
                edges[eid] = edge;
                dp.set_edge(eid, edge);
            } else {
                // No edges to update (single vertex): update a vertex instead.
                let v = rng.random(0..n);
                let x = MInt::from(rng.random(0u32..20));
                vertices[v] = x;
                dp.set_vertex(v, x);
            }
            for root in 0..n {
                let got = dp.fold_path(root).reverse.sum;
                let want = naive_rooted(graph, &vertices, &edges, root).sum;
                assert_eq!(got, want, "root={root}");
            }
        }
    }

    #[test]
    fn static_top_tree_fixed_random() {
        let mut rng = Xorshift::default();
        for _ in 0..30 {
            let graph = rng.random(PruferSequence(2..=14usize));
            run_fixed_case(&graph, 40, &mut rng);
        }
    }

    #[test]
    fn static_top_tree_reroot_random() {
        let mut rng = Xorshift::default();
        for _ in 0..20 {
            let graph = rng.random(PruferSequence(2..=12usize));
            run_reroot_case(&graph, 30, &mut rng);
        }
    }

    /// Degenerate and extreme shapes: single vertex, path, star, complete binary tree.
    #[test]
    fn static_top_tree_shapes() {
        let mut rng = Xorshift::default();
        for graph in [
            UndirectedSparseGraph::from_edges(1, vec![]),
            rng.random(PathTree(2..=16usize)),
            rng.random(StarTree(2..=16usize)),
            balanced_tree(15),
        ] {
            run_fixed_case(&graph, 30, &mut rng);
            run_reroot_case(&graph, 20, &mut rng);
        }
    }
}