competitive/data_structure/
bitset.rs

1#![allow(clippy::suspicious_op_assign_impl)]
2
3use std::ops::{
4    BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr,
5    ShrAssign,
6};
7
/// A sequence of `len()` bits packed into 64-bit words.
///
/// Bit `i` is stored in word `i / 64` at bit position `i % 64`. The methods of
/// this type keep every storage bit at or above `len()` cleared; `push`,
/// `count_ones` and the derived comparisons rely on that.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BitSet {
    /// Number of logical bits.
    size: usize,
    /// Backing words, lowest bit index first.
    bits: Vec<u64>,
}

impl BitSet {
    /// Number of 64-bit words needed to store `size` bits.
    ///
    /// Written without `size + 63` so it cannot overflow for very large sizes.
    fn word_count(size: usize) -> usize {
        size / 64 + usize::from(size % 64 != 0)
    }

    /// Creates a set of `size` bits, all cleared.
    pub fn new(size: usize) -> Self {
        Self {
            size,
            bits: vec![0; Self::word_count(size)],
        }
    }

    /// Returns the number of bits in the set.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` when the set holds no bits at all.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Creates a set of `size` bits, all set.
    pub fn ones(size: usize) -> Self {
        let mut self_ = Self {
            size,
            bits: vec![u64::MAX; Self::word_count(size)],
        };
        // The last word was filled completely; clear the part above `size`.
        self_.trim();
        self_
    }

    /// Returns the value of bit `i`.
    ///
    /// # Panics
    ///
    /// Panics if word `i / 64` lies outside the backing storage. An index in
    /// `len()..64 * words` is not rejected; it reads the padding of the last word.
    pub fn get(&self, i: usize) -> bool {
        (self.bits[i / 64] >> (i % 64)) & 1 == 1
    }

    /// Sets bit `i` to `b`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= len()`. Writing into the padding of the last word would
    /// silently break `count_ones`, equality and later `push` calls.
    pub fn set(&mut self, i: usize, b: bool) {
        assert!(
            i < self.size,
            "bit index {} out of range for BitSet of length {}",
            i,
            self.size
        );
        let mask = 1u64 << (i % 64);
        if b {
            self.bits[i / 64] |= mask;
        } else {
            self.bits[i / 64] &= !mask;
        }
    }

    /// Returns how many bits are set.
    pub fn count_ones(&self) -> u64 {
        self.bits.iter().map(|x| x.count_ones() as u64).sum()
    }

    /// Returns how many bits are cleared.
    pub fn count_zeros(&self) -> u64 {
        self.size as u64 - self.count_ones()
    }

    /// Appends one bit with value `b`, growing the set by one.
    pub fn push(&mut self, b: bool) {
        let d = self.size & 63;
        if d == 0 {
            // The current words are full (or there are none): start a new word.
            self.bits.push(b as u64);
        } else {
            // The target position is padding, hence still zero; OR is enough.
            *self.bits.last_mut().unwrap() |= (b as u64) << d;
        }
        self.size += 1;
    }

    /// Clears the bits of the last word that lie at or above `size`.
    fn trim(&mut self) {
        let tail = self.size & 63;
        if tail != 0 {
            if let Some(x) = self.bits.last_mut() {
                *x &= u64::MAX >> (64 - tail);
            }
        }
    }

    /// In-place `self |= self << rhs`, without a temporary copy.
    ///
    /// Bits shifted to positions at or above `len()` are dropped. A shift that
    /// moves everything out of the storage leaves the set unchanged.
    pub fn shl_bitor_assign(&mut self, rhs: usize) {
        let n = self.bits.len();
        let k = rhs >> 6;
        let d = rhs & 63;
        if k >= n {
            return;
        }
        // Walk from the top down so every source word is read before it is
        // itself used as a destination.
        if d == 0 {
            for i in (0..n - k).rev() {
                self.bits[i + k] |= self.bits[i];
            }
        } else {
            for i in (1..n - k).rev() {
                self.bits[i + k] |= (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
            }
            // The lowest destination word has no lower neighbour to carry from.
            self.bits[k] |= self.bits[0] << d;
        }
        self.trim();
    }

    /// In-place `self |= self >> rhs`, without a temporary copy.
    ///
    /// A shift that moves everything out of the storage leaves the set unchanged.
    pub fn shr_bitor_assign(&mut self, rhs: usize) {
        let n = self.bits.len();
        let k = rhs >> 6;
        let d = rhs & 63;
        if k >= n {
            return;
        }
        // Walk from the bottom up so every source word is read before it is
        // itself used as a destination.
        if d == 0 {
            for i in k..n {
                self.bits[i - k] |= self.bits[i];
            }
        } else {
            for i in k..n - 1 {
                self.bits[i - k] |= (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
            }
            // The highest source word has no upper neighbour to carry from.
            self.bits[n - k - 1] |= self.bits[n - 1] >> d;
        }
    }
}
114
115impl Extend<bool> for BitSet {
116    fn extend<T: IntoIterator<Item = bool>>(&mut self, iter: T) {
117        let d = self.size & 63;
118        let mut iter = iter.into_iter();
119        let Some(first) = iter.next() else {
120            return;
121        };
122        if d == 0 {
123            self.bits.push(0);
124        }
125        let mut e = self.bits.last_mut().unwrap();
126        *e |= (first as u64) << d;
127        self.size += 1;
128        for b in iter {
129            let d = self.size & 63;
130            if d == 0 {
131                self.bits.push(b as u64);
132                e = self.bits.last_mut().unwrap();
133            } else {
134                *e |= (b as u64) << d;
135            }
136            self.size += 1;
137        }
138    }
139}
140
141impl FromIterator<bool> for BitSet {
142    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
143        let mut set = BitSet::new(0);
144        set.extend(iter);
145        set
146    }
147}
148
149impl ShlAssign<usize> for BitSet {
150    fn shl_assign(&mut self, rhs: usize) {
151        let n = self.bits.len();
152        let k = rhs >> 6;
153        let d = rhs & 63;
154        if k >= n {
155            for x in self.bits.iter_mut() {
156                *x = 0;
157            }
158        } else {
159            if d == 0 {
160                for i in (0..n - k).rev() {
161                    self.bits[i + k] = self.bits[i];
162                }
163            } else {
164                for i in (1..n - k).rev() {
165                    self.bits[i + k] = (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
166                }
167                self.bits[k] = self.bits[0] << d;
168            }
169            for x in self.bits[..k].iter_mut() {
170                *x = 0;
171            }
172            self.trim();
173        }
174    }
175}
176
177impl Shl<usize> for BitSet {
178    type Output = Self;
179    fn shl(mut self, rhs: usize) -> Self::Output {
180        self <<= rhs;
181        self
182    }
183}
184
185impl ShrAssign<usize> for BitSet {
186    fn shr_assign(&mut self, rhs: usize) {
187        let n = self.bits.len();
188        let k = rhs >> 6;
189        let d = rhs & 63;
190        if k >= n {
191            for x in self.bits.iter_mut() {
192                *x = 0;
193            }
194        } else {
195            if d == 0 {
196                for i in k..n {
197                    self.bits[i - k] = self.bits[i];
198                }
199            } else {
200                for i in k..n - 1 {
201                    self.bits[i - k] = (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
202                }
203                self.bits[n - k - 1] = self.bits[n - 1] >> d;
204            }
205            for x in self.bits[n - k..].iter_mut() {
206                *x = 0;
207            }
208        }
209    }
210}
211
212impl Shr<usize> for BitSet {
213    type Output = Self;
214    fn shr(mut self, rhs: usize) -> Self::Output {
215        self >>= rhs;
216        self
217    }
218}
219
220impl<'a> BitOrAssign<&'a BitSet> for BitSet {
221    fn bitor_assign(&mut self, rhs: &'a Self) {
222        for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
223            *l |= *r;
224        }
225        self.trim();
226    }
227}
228
229impl<'a> BitOr<&'a BitSet> for BitSet {
230    type Output = Self;
231    fn bitor(mut self, rhs: &'a Self) -> Self::Output {
232        self |= rhs;
233        self
234    }
235}
236
237impl<'b> BitOr<&'b BitSet> for &BitSet {
238    type Output = BitSet;
239    fn bitor(self, rhs: &'b BitSet) -> Self::Output {
240        let mut res = self.clone();
241        res |= rhs;
242        res
243    }
244}
245
246impl<'a> BitAndAssign<&'a BitSet> for BitSet {
247    fn bitand_assign(&mut self, rhs: &'a Self) {
248        for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
249            *l &= *r;
250        }
251    }
252}
253
254impl<'a> BitAnd<&'a BitSet> for BitSet {
255    type Output = Self;
256    fn bitand(mut self, rhs: &'a Self) -> Self::Output {
257        self &= rhs;
258        self
259    }
260}
261
262impl<'b> BitAnd<&'b BitSet> for &BitSet {
263    type Output = BitSet;
264    fn bitand(self, rhs: &'b BitSet) -> Self::Output {
265        let mut res = self.clone();
266        res &= rhs;
267        res
268    }
269}
270
271impl<'a> BitXorAssign<&'a BitSet> for BitSet {
272    fn bitxor_assign(&mut self, rhs: &'a Self) {
273        for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
274            *l ^= *r;
275        }
276        self.trim();
277    }
278}
279
280impl<'a> BitXor<&'a BitSet> for BitSet {
281    type Output = Self;
282    fn bitxor(mut self, rhs: &'a Self) -> Self::Output {
283        self ^= rhs;
284        self
285    }
286}
287
288impl<'b> BitXor<&'b BitSet> for &BitSet {
289    type Output = BitSet;
290    fn bitxor(self, rhs: &'b BitSet) -> Self::Output {
291        let mut res = self.clone();
292        res ^= rhs;
293        res
294    }
295}
296
297impl Not for BitSet {
298    type Output = Self;
299    fn not(mut self) -> Self::Output {
300        for x in self.bits.iter_mut() {
301            *x = !*x;
302        }
303        self.trim();
304        self
305    }
306}
307
308impl Not for &BitSet {
309    type Output = BitSet;
310    fn not(self) -> Self::Output {
311        !self.clone()
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rand, tools::Xorshift};

    /// Builds the `BitSet` whose bit `i` is set exactly when `values[i]` is non-zero.
    fn bitset_from(values: &[u32]) -> BitSet {
        values.iter().map(|&v| v != 0).collect()
    }

    /// Random `set` calls must agree with a plain `Vec<bool>` model.
    #[test]
    fn test_access() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, n: 1..=200);
            let mut bitset = BitSet::new(n);
            let mut model = vec![false; n];
            for _ in 0..200 {
                rand!(rng, i: 0..n, b: 0..=1u32);
                let flag = b != 0;
                bitset.set(i, flag);
                model[i] = flag;
                assert_eq!(bitset.get(i), model[i]);
            }
            let ones = model.iter().filter(|&&x| x).count() as u64;
            let zeros = model.iter().filter(|&&x| !x).count() as u64;
            assert_eq!(bitset.count_ones(), ones);
            assert_eq!(bitset.count_zeros(), zeros);
        }
    }

    /// Pushing bits one by one reproduces the source array.
    #[test]
    fn test_push() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, n: 0..=200, arr: [0..=1u32; n]);
            let mut bitset = BitSet::new(0);
            arr.iter().for_each(|&x| bitset.push(x != 0));
            assert_eq!(bitset.len(), n);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }

    /// `shl_bitor_assign(k)` matches OR-ing each element with the one `k` below it.
    #[test]
    fn test_shl_bitor_assign() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset = bitset_from(&arr);
            bitset.shl_bitor_assign(k);
            // Top-down, so each source element is still the original value.
            for i in (k..n).rev() {
                arr[i] |= arr[i - k];
            }
            assert_eq!(bitset, bitset_from(&arr));
        }
    }

    /// `shr_bitor_assign(k)` matches OR-ing each element with the one `k` above it.
    #[test]
    fn test_shr_bitor_assign() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset = bitset_from(&arr);
            bitset.shr_bitor_assign(k);
            // Bottom-up, so each source element is still the original value.
            for i in k..n {
                arr[i - k] |= arr[i];
            }
            assert_eq!(bitset, bitset_from(&arr));
        }
    }

    /// `<<= k` matches moving the array `k` slots up and zero-filling the bottom.
    #[test]
    fn test_shl() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset = bitset_from(&arr);
            bitset <<= k;
            let mut expected = vec![0; n];
            expected[k..].copy_from_slice(&arr[..n - k]);
            assert_eq!(bitset, bitset_from(&expected));
        }
    }

    /// `>>= k` matches moving the array `k` slots down and zero-filling the top.
    #[test]
    fn test_shr() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset = bitset_from(&arr);
            bitset >>= k;
            let mut expected = vec![0; n];
            expected[..n - k].copy_from_slice(&arr[k..]);
            assert_eq!(bitset, bitset_from(&expected));
        }
    }

    /// Collecting a prefix and extending with the rest reproduces the whole array.
    #[test]
    fn test_extend() {
        for _ in 0..100 {
            let mut rng = Xorshift::new();
            rand!(rng, arr: [0..=1u32; 200], n1: 0..=200);
            let (head, tail) = arr.split_at(n1);
            let mut bitset = bitset_from(head);
            bitset.extend(tail.iter().map(|&x| x != 0));
            assert_eq!(bitset.len(), 200);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }
}