competitive/graph/
project_selection_problem.rs

1use super::Dinic;
2use std::{cmp::Ordering, collections::HashMap};
3
4#[derive(Debug, Default, Clone)]
5pub struct ProjectSelectionProblem {
6    n_values: Vec<usize>,
7    start: Vec<usize>,
8    cost1: Vec<Vec<i64>>,
9    cost2: HashMap<(usize, usize), u64>,
10    totalcost: i64,
11}
12impl ProjectSelectionProblem {
13    pub fn new(n_project: usize, n_value: usize) -> Self {
14        Self {
15            n_values: vec![n_value; n_project],
16            start: (0..=n_project * (n_value - 1))
17                .step_by(n_value - 1)
18                .collect(),
19            cost1: vec![vec![0i64; n_value]; n_project],
20            cost2: Default::default(),
21            totalcost: 0i64,
22        }
23    }
24    pub fn with_n_values(n_values: Vec<usize>) -> Self {
25        let mut start = Vec::with_capacity(n_values.len() + 1);
26        start.push(0usize);
27        for nv in n_values.iter() {
28            start.push(start.last().unwrap() + nv - 1);
29        }
30        let cost1 = n_values.iter().map(|&n| vec![0i64; n]).collect();
31        Self {
32            n_values,
33            start,
34            cost1,
35            cost2: Default::default(),
36            totalcost: 0i64,
37        }
38    }
39    pub fn add_cost1(&mut self, p: usize, v: usize, c: i64) {
40        self.cost1[p][v] += c;
41    }
42    /// x1 >= v1 && x2 < v2 (0 < v1 < nv1, 0 < v2 < nv2)
43    pub fn add_cost2_01(&mut self, p1: usize, p2: usize, v1: usize, v2: usize, c: u64) {
44        debug_assert!(0 < v1 && v1 < self.n_values[p1]);
45        debug_assert!(0 < v2 && v2 < self.n_values[p2]);
46        let key = (self.start[p1] + v1 - 1, self.start[p2] + v2 - 1);
47        if c > 0 {
48            *self.cost2.entry(key).or_default() += c;
49        }
50    }
51    /// x1 < v1 && x2 >= v2 (0 < v1 < nv1, 0 < v2 < nv2)
52    pub fn add_cost2_10(&mut self, p1: usize, p2: usize, v1: usize, v2: usize, c: u64) {
53        self.add_cost2_01(p2, p1, v2, v1, c);
54    }
55    /// cost is monge: cost(v1-1, v2) + cost(v1, v2-1) >= cost(v1, v2) + cost(v1-1, v2-1)
56    pub fn add_cost2<F>(&mut self, p1: usize, p2: usize, mut cost: F)
57    where
58        F: FnMut(usize, usize) -> i64,
59    {
60        debug_assert_ne!(p1, p2);
61        let nv1 = self.n_values[p1];
62        let nv2 = self.n_values[p2];
63        debug_assert_ne!(nv1, 0);
64        debug_assert_ne!(nv2, 0);
65        let c00 = cost(0, 0);
66        self.totalcost += c00;
67        for v1 in 1usize..nv1 {
68            self.add_cost1(p1, v1, cost(v1, 0) - c00);
69        }
70        for v2 in 1usize..nv2 {
71            self.add_cost1(p2, v2, cost(0, v2) - c00);
72        }
73        let mut acc = 0i64;
74        for v1 in 1usize..nv1 {
75            for v2 in 1usize..nv2 {
76                let c = cost(v1 - 1, v2) + cost(v1, v2 - 1) - cost(v1, v2) - cost(v1 - 1, v2 - 1);
77                debug_assert!(c >= 0, "cost is not monge");
78                let key = (self.start[p1] + v1 - 1, self.start[p2] + v2 - 1);
79                if c > 0 {
80                    *self.cost2.entry(key).or_default() += c as u64;
81                }
82                acc -= c;
83            }
84            self.add_cost1(p1, v1, acc);
85        }
86    }
87    pub fn solve(&self) -> (i64, Vec<usize>) {
88        let vsize = *self.start.last().unwrap();
89        let esize_expect = vsize * 2 + self.cost2.len();
90        let mut builder = Dinic::builder(vsize + 2, esize_expect);
91        let mut totalcost = self.totalcost;
92        let s = vsize;
93        let t = s + 1;
94        for (p, c) in self.cost1.iter().enumerate() {
95            let nv = self.n_values[p];
96            totalcost += c[nv - 1];
97            for v in 1usize..nv {
98                let u = self.start[p] + v - 1;
99                let d = c[v] - c[v - 1];
100                match d.cmp(&0) {
101                    Ordering::Less => {
102                        builder.add_edge(s, u, (-d) as u64);
103                    }
104                    Ordering::Greater => {
105                        builder.add_edge(u, t, d as u64);
106                        totalcost -= d;
107                    }
108                    Ordering::Equal => {}
109                }
110                if v >= 2 {
111                    builder.add_edge(u, u - 1, u64::MAX);
112                }
113            }
114        }
115        for (&(x, y), &c) in self.cost2.iter() {
116            builder.add_edge(x, y, c);
117        }
118        let dgraph = builder.gen_graph();
119        let mut dinic = builder.build(&dgraph);
120        let res = dinic.maximum_flow(s, t) as i64 + totalcost;
121        let visited = dinic.minimum_cut(s);
122        let mut values = vec![0usize; self.n_values.len()];
123        for (p, &nv) in self.n_values.iter().enumerate() {
124            for v in 1usize..nv {
125                values[p] += visited[self.start[p] + v - 1] as usize;
126            }
127        }
128        (res, values)
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Enumerates every assignment `x` with `x[p] < n_values[p]` and returns the minimum of
    /// `evaluate(x)` together with all assignments attaining it.
    ///
    /// Panics when no assignment exists at all (some `n_values[p] == 0`).
    fn brute_force<F>(n_values: &[usize], mut evaluate: F) -> (i64, Vec<Vec<usize>>)
    where
        F: FnMut(&[usize]) -> i64,
    {
        let mut best_cost = None;
        let mut best_assignments = Vec::new();
        let mut current = vec![0; n_values.len()];

        // Tries every value for `current[idx..]`; at a complete assignment, keeps it if it
        // ties or beats the best cost seen so far.
        fn dfs<F>(
            n_values: &[usize],
            idx: usize,
            current: &mut [usize],
            best_cost: &mut Option<i64>,
            best_assignments: &mut Vec<Vec<usize>>,
            evaluate: &mut F,
        ) where
            F: FnMut(&[usize]) -> i64,
        {
            if idx == current.len() {
                let cost = evaluate(current);
                match best_cost {
                    None => {
                        *best_cost = Some(cost);
                        best_assignments.push(current.to_vec());
                    }
                    Some(bc) if cost < *bc => {
                        *best_cost = Some(cost);
                        best_assignments.clear();
                        best_assignments.push(current.to_vec());
                    }
                    Some(bc) if cost == *bc => {
                        best_assignments.push(current.to_vec());
                    }
                    _ => {}
                }
                return;
            }
            for value in 0..n_values[idx] {
                current[idx] = value;
                dfs(
                    n_values,
                    idx + 1,
                    current,
                    best_cost,
                    best_assignments,
                    evaluate,
                );
            }
        }

        dfs(
            n_values,
            0,
            &mut current,
            &mut best_cost,
            &mut best_assignments,
            &mut evaluate,
        );
        (best_cost.unwrap(), best_assignments)
    }

    /// Randomized cross-check of `ProjectSelectionProblem::solve` against `brute_force` on
    /// small instances mixing unary costs, threshold penalties and Monge pairwise tables.
    #[test]
    fn test_project_selection_problem() {
        #[derive(Clone)]
        struct Penalty {
            p1: usize,
            p2: usize,
            v1: usize,
            v2: usize,
            dir_01: bool,
            cost: u64,
        }

        #[derive(Clone)]
        struct MongePair {
            p1: usize,
            p2: usize,
            costs: Vec<Vec<i64>>,
        }

        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n_projects = rng.random(1..=5);
            let mut n_values = Vec::with_capacity(n_projects);
            for _ in 0..n_projects {
                n_values.push(rng.random(2..=5));
            }

            let mut unary_costs = Vec::with_capacity(n_projects);
            for &nv in &n_values {
                let mut costs = Vec::with_capacity(nv);
                for _ in 0..nv {
                    costs.push(rng.random(-50..=50));
                }
                unary_costs.push(costs);
            }

            let mut penalties = Vec::new();
            let mut monge_pairs = Vec::new();
            for p1 in 0..n_projects {
                for p2 in 0..n_projects {
                    if p1 == p2 {
                        continue;
                    }
                    if rng.random(0..=2) == 0 {
                        let v1 = rng.random(1..n_values[p1]);
                        let v2 = rng.random(1..n_values[p2]);
                        // Zero-cost penalties are kept on purpose: they exercise the `c > 0`
                        // guard in `add_cost2_01`. A `continue` here would also skip the
                        // Monge-pair generation below for this `(p1, p2)`.
                        let cost = rng.random(0..=50);
                        penalties.push(Penalty {
                            p1,
                            p2,
                            v1,
                            v2,
                            dir_01: rng.random(0..=1) == 0,
                            cost,
                        });
                    }
                    if p1 < p2 && rng.random(0..=3) == 0 {
                        let nv1 = n_values[p1];
                        let nv2 = n_values[p2];
                        let mut costs = vec![vec![0i64; nv2]; nv1];
                        // Free first row/column; every inner cell is pushed at or below the
                        // Monge bound so the table is Monge by construction.
                        costs[0][0] = rng.random(-30..=30);
                        for y in 1..nv2 {
                            let delta = rng.random(-20i64..=20);
                            costs[0][y] = costs[0][y - 1] + delta;
                        }
                        for x in 1..nv1 {
                            let delta = rng.random(-20i64..=20);
                            costs[x][0] = costs[x - 1][0] + delta;
                        }
                        for x in 1..nv1 {
                            for y in 1..nv2 {
                                let upper = costs[x - 1][y] + costs[x][y - 1] - costs[x - 1][y - 1];
                                let reduction = rng.random(0i64..=40);
                                costs[x][y] = upper - reduction;
                            }
                        }
                        monge_pairs.push(MongePair { p1, p2, costs });
                    }
                }
            }

            let (expected_cost, expected_assignments) = brute_force(&n_values, |assignment| {
                let mut total = 0i64;
                for (p, &value) in assignment.iter().enumerate() {
                    total += unary_costs[p][value];
                }
                for penalty in &penalties {
                    let applies = if penalty.dir_01 {
                        assignment[penalty.p1] >= penalty.v1 && assignment[penalty.p2] < penalty.v2
                    } else {
                        assignment[penalty.p1] < penalty.v1 && assignment[penalty.p2] >= penalty.v2
                    };
                    if applies {
                        total += penalty.cost as i64;
                    }
                }
                for monge in &monge_pairs {
                    let x = assignment[monge.p1];
                    let y = assignment[monge.p2];
                    total += monge.costs[x][y];
                }
                total
            });

            // Exercise both constructors.
            let mut psp = if n_values.iter().all(|&nv| nv == n_values[0]) {
                ProjectSelectionProblem::new(n_values.len(), n_values[0])
            } else {
                ProjectSelectionProblem::with_n_values(n_values.clone())
            };
            for (p, costs) in unary_costs.iter().enumerate() {
                for (value, &cost) in costs.iter().enumerate() {
                    psp.add_cost1(p, value, cost);
                }
            }
            for penalty in &penalties {
                if penalty.dir_01 {
                    psp.add_cost2_01(penalty.p1, penalty.p2, penalty.v1, penalty.v2, penalty.cost);
                } else {
                    psp.add_cost2_10(penalty.p1, penalty.p2, penalty.v1, penalty.v2, penalty.cost);
                }
            }
            for monge in &monge_pairs {
                let costs = monge.costs.clone();
                psp.add_cost2(monge.p1, monge.p2, move |x, y| costs[x][y]);
            }

            let (cost, values) = psp.solve();
            assert_eq!(cost, expected_cost);
            assert!(expected_assignments.contains(&values));
        }
    }
}