competitive/combinatorial_optimization/
subset_sum_problem.rs

1use super::BitSet;
2use std::{cmp::Reverse, collections::BinaryHeap, mem::take};
3
4#[derive(Debug, Clone)]
5pub struct SubsetSumProblem {
6    size: usize,
7    dp: BitSet,
8    pending_weights: Vec<Reverse<usize>>,
9}
10
11impl SubsetSumProblem {
12    pub fn new(size: usize) -> Self {
13        let mut dp = BitSet::new(if size == !0 { 0 } else { size } + 1);
14        dp.set(0, true);
15        Self {
16            size,
17            dp,
18            pending_weights: vec![],
19        }
20    }
21
22    pub fn insert(&mut self, weight: usize) {
23        if weight == 0 || self.size < weight {
24            return;
25        }
26        self.pending_weights.push(Reverse(weight));
27    }
28
29    pub fn extend<I>(&mut self, iter: I)
30    where
31        I: IntoIterator<Item = usize>,
32    {
33        for weight in iter {
34            self.insert(weight);
35        }
36    }
37
38    pub fn contains(&mut self, sum: usize) -> bool {
39        self.rebuild();
40        if sum < self.dp.len() {
41            self.dp.get(sum)
42        } else {
43            false
44        }
45    }
46
47    fn rebuild(&mut self) {
48        if self.pending_weights.is_empty() {
49            return;
50        }
51        let mut heap = BinaryHeap::from(take(&mut self.pending_weights));
52        let (mut current_weight, mut count) = match heap.pop() {
53            Some(Reverse(w)) => (w, 1),
54            None => return,
55        };
56        while let Some(Reverse(weight)) = heap.pop() {
57            if weight == current_weight {
58                count += 1;
59                if count >= 3 {
60                    if let Some(w) = current_weight.checked_mul(2) {
61                        heap.push(Reverse(w));
62                    }
63                    count -= 2;
64                }
65                continue;
66            }
67            for _ in 0..count {
68                if self.size == !0 {
69                    self.dp.resize(self.dp.len() + current_weight);
70                }
71                self.dp.shl_bitor_assign(current_weight);
72            }
73            (current_weight, count) = (weight, 1);
74        }
75        for _ in 0..count {
76            if self.size == !0 {
77                self.dp.resize(self.dp.len() + current_weight);
78            }
79            self.dp.shl_bitor_assign(current_weight);
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rand, tools::Xorshift};

    /// Reference 0/1-knapsack reachability table. `dp[s]` is true iff some sub-multiset of
    /// `weights` sums to exactly `s`, for `s` in `0..=limit`.
    fn naive(weights: &[usize], limit: usize) -> Vec<bool> {
        let mut dp = vec![false; limit + 1];
        dp[0] = true;
        for &w in weights {
            // Descending order so each weight is used at most once.
            for s in (w..=limit).rev() {
                if dp[s - w] {
                    dp[s] = true;
                }
            }
        }
        dp
    }

    /// Inserts everything up front, then checks bounded and unbounded instances against the
    /// reference table.
    #[test]
    fn test_subset_sum_problem() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            rand!(rng, n: 0..=10usize, limit: 0..=400usize, max_weight: 1..=100usize);
            let mut ssp = SubsetSumProblem::new(limit);
            let mut weights = vec![];
            for _ in 0..n {
                rand!(rng, w: 0..=max_weight, c: 0..=10);
                for _ in 0..c {
                    weights.push(w);
                    ssp.insert(w);
                }
            }
            let sum: usize = weights.iter().sum();
            let expected = naive(&weights, sum + 2);
            for (s, &expected) in expected.iter().enumerate() {
                let expected = if s <= limit { expected } else { false };
                assert_eq!(ssp.contains(s), expected);
            }
            let mut ssp = SubsetSumProblem::new(!0);
            ssp.extend(weights.iter().cloned());
            for (s, &expected) in expected.iter().enumerate() {
                assert_eq!(ssp.contains(s), expected);
            }
        }
    }

    /// Interleaves insertions with queries, so the lazy rebuild runs over several batches.
    /// This includes copies of the same weight that are split across batches.
    #[test]
    fn test_subset_sum_problem_incremental() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 0..=8usize, limit: 0..=300usize, max_weight: 1..=50usize);
            let mut bounded = SubsetSumProblem::new(limit);
            let mut unbounded = SubsetSumProblem::new(!0);
            let mut weights = vec![];
            for _ in 0..n {
                rand!(rng, w: 0..=max_weight, c: 1..=3);
                for _ in 0..c {
                    weights.push(w);
                    bounded.insert(w);
                    unbounded.insert(w);
                }
                let sum: usize = weights.iter().sum();
                let expected = naive(&weights, sum + 2);
                for (s, &expected) in expected.iter().enumerate() {
                    assert_eq!(bounded.contains(s), s <= limit && expected);
                    assert_eq!(unbounded.contains(s), expected);
                }
            }
        }
    }
}
127        }
128    }
129}