competitive/combinatorial_optimization/
subset_sum_problem.rs1use super::BitSet;
2use std::{cmp::Reverse, collections::BinaryHeap, mem::take};
3
/// Incremental subset-sum solver over a multiset of positive weights.
///
/// Weights are queued with `insert` / `extend` and folded into a reachability
/// bitset lazily, on the next `contains` query.
#[derive(Debug, Clone)]
pub struct SubsetSumProblem {
    /// Largest queryable sum. The sentinel `!0` (`usize::MAX`) selects unbounded
    /// mode, in which `dp` grows as weights are folded in.
    size: usize,
    /// Reachability table: `contains(s)` answers with bit `s` of this set.
    dp: BitSet,
    /// Weights inserted since the last rebuild. They are wrapped in `Reverse` so
    /// that a `BinaryHeap` built from them pops the smallest weight first.
    pending_weights: Vec<Reverse<usize>>,
}
10
11impl SubsetSumProblem {
12 pub fn new(size: usize) -> Self {
13 let mut dp = BitSet::new(if size == !0 { 0 } else { size } + 1);
14 dp.set(0, true);
15 Self {
16 size,
17 dp,
18 pending_weights: vec![],
19 }
20 }
21
22 pub fn insert(&mut self, weight: usize) {
23 if weight == 0 || self.size < weight {
24 return;
25 }
26 self.pending_weights.push(Reverse(weight));
27 }
28
29 pub fn extend<I>(&mut self, iter: I)
30 where
31 I: IntoIterator<Item = usize>,
32 {
33 for weight in iter {
34 self.insert(weight);
35 }
36 }
37
38 pub fn contains(&mut self, sum: usize) -> bool {
39 self.rebuild();
40 if sum < self.dp.len() {
41 self.dp.get(sum)
42 } else {
43 false
44 }
45 }
46
47 fn rebuild(&mut self) {
48 if self.pending_weights.is_empty() {
49 return;
50 }
51 let mut heap = BinaryHeap::from(take(&mut self.pending_weights));
52 let (mut current_weight, mut count) = match heap.pop() {
53 Some(Reverse(w)) => (w, 1),
54 None => return,
55 };
56 while let Some(Reverse(weight)) = heap.pop() {
57 if weight == current_weight {
58 count += 1;
59 if count >= 3 {
60 if let Some(w) = current_weight.checked_mul(2) {
61 heap.push(Reverse(w));
62 }
63 count -= 2;
64 }
65 continue;
66 }
67 for _ in 0..count {
68 if self.size == !0 {
69 self.dp.resize(self.dp.len() + current_weight);
70 }
71 self.dp.shl_bitor_assign(current_weight);
72 }
73 (current_weight, count) = (weight, 1);
74 }
75 for _ in 0..count {
76 if self.size == !0 {
77 self.dp.resize(self.dp.len() + current_weight);
78 }
79 self.dp.shl_bitor_assign(current_weight);
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rand, tools::Xorshift};

    /// Reference O(n * limit) 0/1-knapsack reachability for sums in `0..=limit`.
    fn naive(weights: &[usize], limit: usize) -> Vec<bool> {
        let mut dp = vec![false; limit + 1];
        dp[0] = true;
        for &w in weights {
            for s in (w..=limit).rev() {
                if dp[s - w] {
                    dp[s] = true;
                }
            }
        }
        dp
    }

    /// Asserts that `ssp` agrees with `naive(weights)` for every sum up to
    /// `total + 2`. Sums above `limit` must report `false`; pass `!0` for an
    /// unbounded solver.
    fn check(ssp: &mut SubsetSumProblem, weights: &[usize], limit: usize) {
        let sum: usize = weights.iter().sum();
        let expected = naive(weights, sum + 2);
        for (s, &expected) in expected.iter().enumerate() {
            assert_eq!(ssp.contains(s), s <= limit && expected);
        }
    }

    /// Draws a random multiset of weights, including zeros and duplicates.
    fn random_weights(rng: &mut Xorshift, n: usize, max_weight: usize) -> Vec<usize> {
        let mut weights = vec![];
        for _ in 0..n {
            rand!(rng, w: 0..=max_weight, c: 0..=10);
            for _ in 0..c {
                weights.push(w);
            }
        }
        weights
    }

    #[test]
    fn test_subset_sum_problem() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            rand!(rng, n: 0..=10usize, limit: 0..=400usize, max_weight: 1..=100usize);
            let weights = random_weights(&mut rng, n, max_weight);

            let mut ssp = SubsetSumProblem::new(limit);
            for &w in &weights {
                ssp.insert(w);
            }
            check(&mut ssp, &weights, limit);

            let mut ssp = SubsetSumProblem::new(!0);
            ssp.extend(weights.iter().copied());
            check(&mut ssp, &weights, !0);
        }
    }

    /// Exercises the lazy path: weights inserted after a query must be folded
    /// into the already-built table.
    #[test]
    fn test_subset_sum_problem_incremental() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            rand!(rng, n: 0..=10usize, limit: 0..=400usize, max_weight: 1..=100usize);
            let weights = random_weights(&mut rng, n, max_weight);
            let (front, _) = weights.split_at(weights.len() / 2);
            for &size in &[limit, !0] {
                let mut ssp = SubsetSumProblem::new(size);
                ssp.extend(front.iter().copied());
                check(&mut ssp, front, size);
                ssp.extend(weights[front.len()..].iter().copied());
                check(&mut ssp, &weights, size);
            }
        }
    }
}