competitive/graph/
steiner_tree.rs1use super::{
2 AdjacencyIndex, AdjacencyIndexWithValue, AdjacencyView, BitDpExt, PartialIgnoredOrd,
3 ShortestPathSemiRing, VertexMap, Vertices,
4};
5use std::{cmp::Reverse, collections::BinaryHeap, iter::repeat_with};
6
7pub trait SteinerTreeExt: Vertices {
8 fn steiner_tree<'a, S, M, I>(
9 &self,
10 terminals: I,
11 weight: &'a M,
12 ) -> SteinerTreeOutput<'_, S, Self>
13 where
14 Self: VertexMap<S::T> + AdjacencyView<'a, M, S::T>,
15 S: ShortestPathSemiRing,
16 I: IntoIterator<Item = Self::VIndex> + ExactSizeIterator,
17 {
18 let tsize = terminals.len();
19 if tsize == 0 {
20 return SteinerTreeOutput {
21 g: self,
22 dp: vec![],
23 };
24 }
25 let mut dp: Vec<_> = repeat_with(|| self.construct_vmap(S::inf))
26 .take(1 << tsize)
27 .collect();
28 for (i, t) in terminals.into_iter().enumerate() {
29 *self.vmap_get_mut(&mut dp[1 << i], t) = S::source();
30 }
31 for bit in 1..1 << tsize {
32 for u in self.vertices() {
33 for sub in bit.subsets() {
34 if sub != 0 {
35 let cost =
36 S::mul(self.vmap_get(&dp[sub], u), self.vmap_get(&dp[bit ^ sub], u));
37 S::add_assign(self.vmap_get_mut(&mut dp[bit], u), &cost);
38 }
39 }
40 }
41 let dp = &mut dp[bit];
42 let mut heap: BinaryHeap<_> = self
43 .vertices()
44 .map(|u| PartialIgnoredOrd(Reverse(self.vmap_get(dp, u).clone()), u))
45 .collect();
46 while let Some(PartialIgnoredOrd(Reverse(d), u)) = heap.pop() {
47 if self.vmap_get(dp, u) != &d {
48 continue;
49 }
50 for a in self.aviews(weight, u) {
51 let v = a.vindex();
52 let nd = S::mul(&d, &a.avalue());
53 if S::add_assign(self.vmap_get_mut(dp, v), &nd) {
54 heap.push(PartialIgnoredOrd(Reverse(nd), v));
55 }
56 }
57 }
58 }
59 SteinerTreeOutput { g: self, dp }
60 }
61}
62impl<G> SteinerTreeExt for G where G: Vertices {}
63pub struct SteinerTreeOutput<'g, S, G>
64where
65 G: VertexMap<S::T> + ?Sized,
66 S: ShortestPathSemiRing,
67{
68 g: &'g G,
69 dp: Vec<G::Vmap>,
70}
71impl<S, G> SteinerTreeOutput<'_, S, G>
72where
73 G: VertexMap<S::T> + ?Sized,
74 S: ShortestPathSemiRing,
75{
76 pub fn minimum_from_source(&self, source: G::VIndex) -> S::T {
77 match self.dp.last() {
78 Some(dp) => self.g.vmap_get(dp, source).clone(),
79 None => S::source(),
80 }
81 }
82}