1use super::Dinic;
2use std::{cmp::Ordering, collections::HashMap};
3
/// Minimum-cut formulation of a multi-valued project selection problem.
///
/// Project `p` takes one value in `0..n_values[p]`. It is encoded by
/// `n_values[p] - 1` graph nodes, where the node for `v` (`1 <= v < n_values[p]`)
/// has index `start[p] + v - 1` and stands for "the value of `p` is at least `v`".
#[derive(Clone, Debug, Default)]
pub struct ProjectSelectionProblem {
    /// Number of selectable values of each project.
    n_values: Vec<usize>,
    /// Prefix offsets into the node numbering: `start[p]` is the first node of
    /// project `p`, and the last entry is the total number of nodes.
    start: Vec<usize>,
    /// `cost1[p][v]` is the accumulated cost of project `p` taking value `v`.
    cost1: Vec<Vec<i64>>,
    /// Accumulated capacity of the edge `(from node, to node)` between value nodes.
    cost2: HashMap<(usize, usize), u64>,
    /// Constant offset that is added to the optimum reported by `solve`.
    totalcost: i64,
}
12impl ProjectSelectionProblem {
13 pub fn new(n_project: usize, n_value: usize) -> Self {
14 Self {
15 n_values: vec![n_value; n_project],
16 start: (0..=n_project * (n_value - 1))
17 .step_by(n_value - 1)
18 .collect(),
19 cost1: vec![vec![0i64; n_value]; n_project],
20 cost2: Default::default(),
21 totalcost: 0i64,
22 }
23 }
24 pub fn with_n_values(n_values: Vec<usize>) -> Self {
25 let mut start = Vec::with_capacity(n_values.len() + 1);
26 start.push(0usize);
27 for nv in n_values.iter() {
28 start.push(start.last().unwrap() + nv - 1);
29 }
30 let cost1 = n_values.iter().map(|&n| vec![0i64; n]).collect();
31 Self {
32 n_values,
33 start,
34 cost1,
35 cost2: Default::default(),
36 totalcost: 0i64,
37 }
38 }
39 pub fn add_cost1(&mut self, p: usize, v: usize, c: i64) {
40 self.cost1[p][v] += c;
41 }
42 pub fn add_cost2_01(&mut self, p1: usize, p2: usize, v1: usize, v2: usize, c: u64) {
44 debug_assert!(0 < v1 && v1 < self.n_values[p1]);
45 debug_assert!(0 < v2 && v2 < self.n_values[p2]);
46 let key = (self.start[p1] + v1 - 1, self.start[p2] + v2 - 1);
47 if c > 0 {
48 *self.cost2.entry(key).or_default() += c;
49 }
50 }
51 pub fn add_cost2_10(&mut self, p1: usize, p2: usize, v1: usize, v2: usize, c: u64) {
53 self.add_cost2_01(p2, p1, v2, v1, c);
54 }
55 pub fn add_cost2<F>(&mut self, p1: usize, p2: usize, mut cost: F)
57 where
58 F: FnMut(usize, usize) -> i64,
59 {
60 debug_assert_ne!(p1, p2);
61 let nv1 = self.n_values[p1];
62 let nv2 = self.n_values[p2];
63 debug_assert_ne!(nv1, 0);
64 debug_assert_ne!(nv2, 0);
65 let c00 = cost(0, 0);
66 self.totalcost += c00;
67 for v1 in 1usize..nv1 {
68 self.add_cost1(p1, v1, cost(v1, 0) - c00);
69 }
70 for v2 in 1usize..nv2 {
71 self.add_cost1(p2, v2, cost(0, v2) - c00);
72 }
73 let mut acc = 0i64;
74 for v1 in 1usize..nv1 {
75 for v2 in 1usize..nv2 {
76 let c = cost(v1 - 1, v2) + cost(v1, v2 - 1) - cost(v1, v2) - cost(v1 - 1, v2 - 1);
77 debug_assert!(c >= 0, "cost is not monge");
78 let key = (self.start[p1] + v1 - 1, self.start[p2] + v2 - 1);
79 if c > 0 {
80 *self.cost2.entry(key).or_default() += c as u64;
81 }
82 acc -= c;
83 }
84 self.add_cost1(p1, v1, acc);
85 }
86 }
87 pub fn solve(&self) -> (i64, Vec<usize>) {
88 let vsize = *self.start.last().unwrap();
89 let esize_expect = vsize * 2 + self.cost2.len();
90 let mut builder = Dinic::builder(vsize + 2, esize_expect);
91 let mut totalcost = self.totalcost;
92 let s = vsize;
93 let t = s + 1;
94 for (p, c) in self.cost1.iter().enumerate() {
95 let nv = self.n_values[p];
96 totalcost += c[nv - 1];
97 for v in 1usize..nv {
98 let u = self.start[p] + v - 1;
99 let d = c[v] - c[v - 1];
100 match d.cmp(&0) {
101 Ordering::Less => {
102 builder.add_edge(s, u, (-d) as u64);
103 }
104 Ordering::Greater => {
105 builder.add_edge(u, t, d as u64);
106 totalcost -= d;
107 }
108 Ordering::Equal => {}
109 }
110 if v >= 2 {
111 builder.add_edge(u, u - 1, u64::MAX);
112 }
113 }
114 }
115 for (&(x, y), &c) in self.cost2.iter() {
116 builder.add_edge(x, y, c);
117 }
118 let dgraph = builder.gen_graph();
119 let mut dinic = builder.build(&dgraph);
120 let res = dinic.maximum_flow(s, t) as i64 + totalcost;
121 let visited = dinic.minimum_cut(s);
122 let mut values = vec![0usize; self.n_values.len()];
123 for (p, &nv) in self.n_values.iter().enumerate() {
124 for v in 1usize..nv {
125 values[p] += visited[self.start[p] + v - 1] as usize;
126 }
127 }
128 (res, values)
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Enumerates every assignment (value `0..n_values[p]` for each project `p`)
    /// in lexicographic order and returns the minimum of `evaluate` together
    /// with all assignments that attain it.
    fn brute_force<F>(n_values: &[usize], mut evaluate: F) -> (i64, Vec<Vec<usize>>)
    where
        F: FnMut(&[usize]) -> i64,
    {
        /// Search state shared by all levels of the recursion.
        struct Search<'a, F> {
            n_values: &'a [usize],
            evaluate: &'a mut F,
            best_cost: Option<i64>,
            best_assignments: Vec<Vec<usize>>,
        }

        impl<F> Search<'_, F>
        where
            F: FnMut(&[usize]) -> i64,
        {
            /// Extends `prefix` by one project at a time; complete assignments
            /// are evaluated and compared with the best cost seen so far.
            fn explore(&mut self, prefix: &mut Vec<usize>) {
                let idx = prefix.len();
                if idx == self.n_values.len() {
                    let cost = (self.evaluate)(prefix.as_slice());
                    match self.best_cost.map(|best| cost.cmp(&best)) {
                        // First assignment, or a strictly better one: restart the list.
                        None | Some(std::cmp::Ordering::Less) => {
                            self.best_cost = Some(cost);
                            self.best_assignments.clear();
                            self.best_assignments.push(prefix.clone());
                        }
                        Some(std::cmp::Ordering::Equal) => {
                            self.best_assignments.push(prefix.clone());
                        }
                        Some(std::cmp::Ordering::Greater) => {}
                    }
                    return;
                }
                for value in 0..self.n_values[idx] {
                    prefix.push(value);
                    self.explore(prefix);
                    prefix.pop();
                }
            }
        }

        let mut search = Search {
            n_values,
            evaluate: &mut evaluate,
            best_cost: None,
            best_assignments: Vec::new(),
        };
        search.explore(&mut Vec::with_capacity(n_values.len()));
        (search.best_cost.unwrap(), search.best_assignments)
    }

    /// Compares `solve` with exhaustive search on random small instances that
    /// mix unary costs, threshold penalties and Monge pairwise tables.
    #[test]
    fn test_project_selection_problem() {
        /// Threshold penalty between two projects; `dir_01` selects
        /// `add_cost2_01` (otherwise `add_cost2_10`).
        #[derive(Clone)]
        struct Penalty {
            p1: usize,
            p2: usize,
            v1: usize,
            v2: usize,
            dir_01: bool,
            cost: u64,
        }

        /// Full pairwise cost table `costs[value of p1][value of p2]`.
        #[derive(Clone)]
        struct MongePair {
            p1: usize,
            p2: usize,
            costs: Vec<Vec<i64>>,
        }

        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n_projects: usize = rng.random(1..=5);
            let n_values: Vec<usize> = (0..n_projects).map(|_| rng.random(2..=5)).collect();
            let unary_costs: Vec<Vec<i64>> = n_values
                .iter()
                .map(|&nv| {
                    (0..nv)
                        .map(|_| rng.random(-50..=50))
                        .collect::<Vec<i64>>()
                })
                .collect();

            let mut penalties: Vec<Penalty> = Vec::new();
            let mut monge_pairs: Vec<MongePair> = Vec::new();
            for p1 in 0..n_projects {
                for p2 in (0..n_projects).filter(|&p2| p2 != p1) {
                    if rng.random(0..=2) == 0 {
                        let v1 = rng.random(1..n_values[p1]);
                        let v2 = rng.random(1..n_values[p2]);
                        let cost = rng.random(0..=50);
                        if cost == 0 {
                            // Also skips the Monge table for this pair.
                            continue;
                        }
                        let dir_01 = rng.random(0..=1) == 0;
                        penalties.push(Penalty {
                            p1,
                            p2,
                            v1,
                            v2,
                            dir_01,
                            cost,
                        });
                    }
                    if p1 < p2 && rng.random(0..=3) == 0 {
                        let (nv1, nv2) = (n_values[p1], n_values[p2]);
                        let mut costs = vec![vec![0i64; nv2]; nv1];
                        // Arbitrary first row and first column ...
                        costs[0][0] = rng.random(-30..=30);
                        for y in 1..nv2 {
                            let delta = rng.random(-20i64..=20);
                            costs[0][y] = costs[0][y - 1] + delta;
                        }
                        for x in 1..nv1 {
                            let delta = rng.random(-20i64..=20);
                            costs[x][0] = costs[x - 1][0] + delta;
                        }
                        // ... and every inner cell at most the bound that keeps
                        // the table Monge.
                        for x in 1..nv1 {
                            for y in 1..nv2 {
                                let upper =
                                    costs[x - 1][y] + costs[x][y - 1] - costs[x - 1][y - 1];
                                let reduction = rng.random(0i64..=40);
                                costs[x][y] = upper - reduction;
                            }
                        }
                        monge_pairs.push(MongePair { p1, p2, costs });
                    }
                }
            }

            let evaluate = |assignment: &[usize]| -> i64 {
                let unary: i64 = assignment
                    .iter()
                    .enumerate()
                    .map(|(p, &value)| unary_costs[p][value])
                    .sum();
                let threshold: i64 = penalties
                    .iter()
                    .filter(|penalty| {
                        let reached1 = assignment[penalty.p1] >= penalty.v1;
                        let reached2 = assignment[penalty.p2] >= penalty.v2;
                        if penalty.dir_01 {
                            reached1 && !reached2
                        } else {
                            !reached1 && reached2
                        }
                    })
                    .map(|penalty| penalty.cost as i64)
                    .sum();
                let pairwise: i64 = monge_pairs
                    .iter()
                    .map(|monge| monge.costs[assignment[monge.p1]][assignment[monge.p2]])
                    .sum();
                unary + threshold + pairwise
            };
            let (expected_cost, expected_assignments) = brute_force(&n_values, evaluate);

            // Exercise both constructors: `new` when all projects have the
            // same number of values, `with_n_values` otherwise.
            let uniform = n_values.iter().all(|&nv| nv == n_values[0]);
            let mut psp = if uniform {
                ProjectSelectionProblem::new(n_values.len(), n_values[0])
            } else {
                ProjectSelectionProblem::with_n_values(n_values.clone())
            };
            for (p, costs) in unary_costs.iter().enumerate() {
                for (value, &cost) in costs.iter().enumerate() {
                    psp.add_cost1(p, value, cost);
                }
            }
            for penalty in &penalties {
                let &Penalty {
                    p1,
                    p2,
                    v1,
                    v2,
                    dir_01,
                    cost,
                } = penalty;
                if dir_01 {
                    psp.add_cost2_01(p1, p2, v1, v2, cost);
                } else {
                    psp.add_cost2_10(p1, p2, v1, v2, cost);
                }
            }
            for monge in &monge_pairs {
                psp.add_cost2(monge.p1, monge.p2, |x, y| monge.costs[x][y]);
            }

            let (cost, values) = psp.solve();
            assert_eq!(cost, expected_cost);
            assert!(expected_assignments.contains(&values));
        }
    }
}