competitive/data_structure/
bitset.rs

1use std::{
2    cmp::Ordering,
3    ops::{
4        BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr,
5        ShrAssign,
6    },
7};
8
/// Sequence of `size` bits packed into 64-bit words.
///
/// Bit `i` lives in bit `i & 63` of `bits[i >> 6]`. The operations in this file that could set
/// bits past `size` in the last word (`ones`, `resize`, left shifts, `|`, `^`, `!`) call `trim`
/// afterwards to clear them again.
///
/// The derived comparisons and hash work field by field in declaration order: `size` first,
/// then the words of `bits`.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BitSet {
    /// Number of bits in the set.
    size: usize,
    /// Backing storage, `size.div_ceil(64)` words when built through the constructors here.
    bits: Vec<u64>,
}
14
15impl BitSet {
16    pub fn new(size: usize) -> Self {
17        Self {
18            size,
19            bits: vec![0; size.div_ceil(64)],
20        }
21    }
22
23    pub fn len(&self) -> usize {
24        self.size
25    }
26
27    pub fn is_empty(&self) -> bool {
28        self.size == 0
29    }
30
31    pub fn ones(size: usize) -> Self {
32        let mut self_ = Self {
33            size,
34            bits: vec![u64::MAX; size.div_ceil(64)],
35        };
36        self_.trim();
37        self_
38    }
39
40    pub fn get(&self, i: usize) -> bool {
41        self.bits[i >> 6] & (1 << (i & 63)) != 0
42    }
43
44    pub fn set(&mut self, i: usize, b: bool) {
45        if b {
46            self.bits[i >> 6] |= 1 << (i & 63);
47        } else {
48            self.bits[i >> 6] &= !(1 << (i & 63));
49        }
50    }
51
52    pub fn count_ones(&self) -> u64 {
53        self.bits.iter().map(|x| x.count_ones() as u64).sum()
54    }
55
56    pub fn count_zeros(&self) -> u64 {
57        self.size as u64 - self.count_ones()
58    }
59
60    pub fn push(&mut self, b: bool) {
61        let d = self.size & 63;
62        if d == 0 {
63            self.bits.push(b as u64);
64        } else {
65            *self.bits.last_mut().unwrap() |= (b as u64) << d;
66        }
67        self.size += 1;
68    }
69
70    pub fn resize(&mut self, new_size: usize) {
71        match self.size.cmp(&new_size) {
72            Ordering::Less => self.bits.resize(new_size.div_ceil(64), 0),
73            Ordering::Equal => {}
74            Ordering::Greater => self.bits.truncate(new_size.div_ceil(64)),
75        }
76        self.size = new_size;
77        self.trim();
78    }
79
80    fn trim(&mut self) {
81        if self.size & 63 != 0
82            && let Some(x) = self.bits.last_mut()
83        {
84            *x &= 0xffff_ffff_ffff_ffff >> (64 - (self.size & 63));
85        }
86    }
87
88    pub fn shl_bitor_assign(&mut self, rhs: usize) {
89        let n = self.bits.len();
90        let k = rhs >> 6;
91        let d = rhs & 63;
92        if k < n {
93            if d == 0 {
94                for i in (0..n - k).rev() {
95                    self.bits[i + k] |= self.bits[i];
96                }
97            } else {
98                for i in (1..n - k).rev() {
99                    self.bits[i + k] |= (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
100                }
101                self.bits[k] |= self.bits[0] << d;
102            }
103            self.trim();
104        }
105    }
106
107    pub fn shr_bitor_assign(&mut self, rhs: usize) {
108        let n = self.bits.len();
109        let k = rhs >> 6;
110        let d = rhs & 63;
111        if k < n {
112            if d == 0 {
113                for i in k..n {
114                    self.bits[i - k] |= self.bits[i];
115                }
116            } else {
117                for i in k..n - 1 {
118                    self.bits[i - k] |= (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
119                }
120                self.bits[n - k - 1] |= self.bits[n - 1] >> d;
121            }
122        }
123    }
124}
125
126impl Extend<bool> for BitSet {
127    fn extend<T: IntoIterator<Item = bool>>(&mut self, iter: T) {
128        let d = self.size & 63;
129        let mut iter = iter.into_iter();
130        let Some(first) = iter.next() else {
131            return;
132        };
133        if d == 0 {
134            self.bits.push(0);
135        }
136        let mut e = self.bits.last_mut().unwrap();
137        *e |= (first as u64) << d;
138        self.size += 1;
139        for b in iter {
140            let d = self.size & 63;
141            if d == 0 {
142                self.bits.push(b as u64);
143                e = self.bits.last_mut().unwrap();
144            } else {
145                *e |= (b as u64) << d;
146            }
147            self.size += 1;
148        }
149    }
150}
151
152impl FromIterator<bool> for BitSet {
153    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
154        let mut set = BitSet::new(0);
155        set.extend(iter);
156        set
157    }
158}
159
160impl ShlAssign<usize> for BitSet {
161    fn shl_assign(&mut self, rhs: usize) {
162        let n = self.bits.len();
163        let k = rhs >> 6;
164        let d = rhs & 63;
165        if k >= n {
166            for x in self.bits.iter_mut() {
167                *x = 0;
168            }
169        } else {
170            if d == 0 {
171                for i in (0..n - k).rev() {
172                    self.bits[i + k] = self.bits[i];
173                }
174            } else {
175                for i in (1..n - k).rev() {
176                    self.bits[i + k] = (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
177                }
178                self.bits[k] = self.bits[0] << d;
179            }
180            for x in self.bits[..k].iter_mut() {
181                *x = 0;
182            }
183            self.trim();
184        }
185    }
186}
187
188impl Shl<usize> for BitSet {
189    type Output = Self;
190    fn shl(mut self, rhs: usize) -> Self::Output {
191        self <<= rhs;
192        self
193    }
194}
195
196impl ShrAssign<usize> for BitSet {
197    fn shr_assign(&mut self, rhs: usize) {
198        let n = self.bits.len();
199        let k = rhs >> 6;
200        let d = rhs & 63;
201        if k >= n {
202            for x in self.bits.iter_mut() {
203                *x = 0;
204            }
205        } else {
206            if d == 0 {
207                for i in k..n {
208                    self.bits[i - k] = self.bits[i];
209                }
210            } else {
211                for i in k..n - 1 {
212                    self.bits[i - k] = (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
213                }
214                self.bits[n - k - 1] = self.bits[n - 1] >> d;
215            }
216            for x in self.bits[n - k..].iter_mut() {
217                *x = 0;
218            }
219        }
220    }
221}
222
223impl Shr<usize> for BitSet {
224    type Output = Self;
225    fn shr(mut self, rhs: usize) -> Self::Output {
226        self >>= rhs;
227        self
228    }
229}
230
231impl<'a> BitOrAssign<&'a BitSet> for BitSet {
232    fn bitor_assign(&mut self, rhs: &'a Self) {
233        for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
234            *l |= *r;
235        }
236        self.trim();
237    }
238}
239
240impl<'a> BitOr<&'a BitSet> for BitSet {
241    type Output = Self;
242    fn bitor(mut self, rhs: &'a Self) -> Self::Output {
243        self |= rhs;
244        self
245    }
246}
247
248impl<'b> BitOr<&'b BitSet> for &BitSet {
249    type Output = BitSet;
250    fn bitor(self, rhs: &'b BitSet) -> Self::Output {
251        let mut res = self.clone();
252        res |= rhs;
253        res
254    }
255}
256
257impl<'a> BitAndAssign<&'a BitSet> for BitSet {
258    fn bitand_assign(&mut self, rhs: &'a Self) {
259        for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
260            *l &= *r;
261        }
262    }
263}
264
265impl<'a> BitAnd<&'a BitSet> for BitSet {
266    type Output = Self;
267    fn bitand(mut self, rhs: &'a Self) -> Self::Output {
268        self &= rhs;
269        self
270    }
271}
272
273impl<'b> BitAnd<&'b BitSet> for &BitSet {
274    type Output = BitSet;
275    fn bitand(self, rhs: &'b BitSet) -> Self::Output {
276        let mut res = self.clone();
277        res &= rhs;
278        res
279    }
280}
281
282impl<'a> BitXorAssign<&'a BitSet> for BitSet {
283    fn bitxor_assign(&mut self, rhs: &'a Self) {
284        for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
285            *l ^= *r;
286        }
287        self.trim();
288    }
289}
290
291impl<'a> BitXor<&'a BitSet> for BitSet {
292    type Output = Self;
293    fn bitxor(mut self, rhs: &'a Self) -> Self::Output {
294        self ^= rhs;
295        self
296    }
297}
298
299impl<'b> BitXor<&'b BitSet> for &BitSet {
300    type Output = BitSet;
301    fn bitxor(self, rhs: &'b BitSet) -> Self::Output {
302        let mut res = self.clone();
303        res ^= rhs;
304        res
305    }
306}
307
308impl Not for BitSet {
309    type Output = Self;
310    fn not(mut self) -> Self::Output {
311        for x in self.bits.iter_mut() {
312            *x = !*x;
313        }
314        self.trim();
315        self
316    }
317}
318
319impl Not for &BitSet {
320    type Output = BitSet;
321    fn not(self) -> Self::Output {
322        !self.clone()
323    }
324}
325
/// Randomised tests comparing `BitSet` against a plain array of 0/1 values.
///
/// Each test creates its generator once, outside the iteration loop, so successive iterations
/// draw different parameters regardless of how `Xorshift::default()` seeds itself.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rand, tools::Xorshift};

    /// `set`/`get` round-trip and the population counts match a `Vec<bool>` model.
    #[test]
    fn test_access() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200);
            let mut bitset = BitSet::new(n);
            let mut arr = vec![false; n];
            for _ in 0..200 {
                rand!(rng, i: 0..n, b: 0..=1u32);
                bitset.set(i, b != 0);
                arr[i] = b != 0;
                assert_eq!(bitset.get(i), arr[i]);
            }
            assert_eq!(
                bitset.count_ones(),
                arr.iter().filter(|&&x| x).count() as u64
            );
            assert_eq!(
                bitset.count_zeros(),
                arr.iter().filter(|&&x| !x).count() as u64
            );
        }
    }

    /// Pushing bits one by one reproduces the source array.
    #[test]
    fn test_push() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 0..=200, arr: [0..=1u32; n]);
            let mut bitset = BitSet::new(0);
            for &x in &arr {
                bitset.push(x != 0);
            }
            assert_eq!(bitset.len(), n);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }

    /// After `resize(m)`, surviving bits keep their value and newly added bits are zero.
    #[test]
    fn test_resize() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, m: 0..=400, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.resize(m);
            assert_eq!(bitset.len(), m);
            // Only indices below the new length are valid; those below the old length must
            // be unchanged and the rest must be cleared.
            for i in 0..m {
                if i < n {
                    assert_eq!(bitset.get(i), arr[i] != 0);
                } else {
                    assert!(!bitset.get(i));
                }
            }
        }
    }

    /// `shl_bitor_assign(k)` equals `arr[i] |= arr[i - k]` applied from the high end.
    #[test]
    fn test_shl_bitor_assign() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.shl_bitor_assign(k);
            for i in (k..n).rev() {
                arr[i] |= arr[i - k];
            }
            assert_eq!(bitset, BitSet::from_iter(arr.iter().map(|&x| x != 0)));
        }
    }

    /// `shr_bitor_assign(k)` equals `arr[i - k] |= arr[i]` applied from the low end.
    #[test]
    fn test_shr_bitor_assign() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.shr_bitor_assign(k);
            for i in k..n {
                arr[i - k] |= arr[i];
            }
            assert_eq!(bitset, BitSet::from_iter(arr.iter().map(|&x| x != 0)));
        }
    }

    /// `<<= k` moves bit `i - k` to `i` and zero-fills the low `k` bits.
    #[test]
    fn test_shl() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset <<= k;
            let mut arr2 = vec![0; n];
            for i in (k..n).rev() {
                arr2[i] = arr[i - k];
            }
            assert_eq!(bitset, BitSet::from_iter(arr2.iter().map(|&x| x != 0)));
        }
    }

    /// `>>= k` moves bit `i` to `i - k` and zero-fills the high `k` bits.
    #[test]
    fn test_shr() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset >>= k;
            let mut arr2 = vec![0; n];
            for (i, &a) in arr.iter().enumerate().skip(k) {
                arr2[i - k] = a;
            }
            assert_eq!(bitset, BitSet::from_iter(arr2.iter().map(|&x| x != 0)));
        }
    }

    /// Collecting a prefix and extending with the rest reproduces the whole array.
    #[test]
    fn test_extend() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, arr: [0..=1u32; 200], n1: 0..=200);
            let mut bitset: BitSet = arr[..n1].iter().map(|&x| x != 0).collect();
            bitset.extend(arr[n1..].iter().map(|&x| x != 0));
            assert_eq!(bitset.len(), 200);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }
}