Skip to main content

competitive/data_structure/
bit_vector.rs

1use std::iter::FromIterator;
2
/// Rank/select queries over a sequence of bits.
///
/// Positions and `k` are 0-based. Whenever `select_i(k)` is `Some(p)`:
///
/// - `rank_i(p) = k`
/// - `rank_i(p + 1) = k + 1`
pub trait RankSelectDictionaries {
    /// Number of bits in the sequence.
    fn bit_length(&self) -> usize;
    /// Returns the bit at position `k`.
    fn access(&self, k: usize) -> bool;
    /// Number of 1 bits among positions `[0, k)`.
    ///
    /// The default implementation scans every position below `k` with
    /// [`access`](Self::access).
    fn rank1(&self, k: usize) -> usize {
        let mut ones = 0;
        for i in 0..k {
            if self.access(i) {
                ones += 1;
            }
        }
        ones
    }
    /// Number of 0 bits among positions `[0, k)`, derived from
    /// [`rank1`](Self::rank1).
    fn rank0(&self, k: usize) -> usize {
        k - self.rank1(k)
    }
    /// Position of the `k`-th 1 bit, or `None` when the sequence holds at
    /// most `k` ones.
    ///
    /// The default implementation binary-searches over
    /// [`rank1`](Self::rank1).
    fn select1(&self, k: usize) -> Option<usize> {
        let (mut lo, mut hi) = (0, self.bit_length());
        if k >= self.rank1(hi) {
            return None;
        }
        // Invariant: rank1(lo) <= k < rank1(hi); it ends with hi == lo + 1,
        // so position `lo` holds the k-th one.
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if self.rank1(mid) > k {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        Some(lo)
    }
    /// Position of the `k`-th 0 bit, or `None` when the sequence holds at
    /// most `k` zeros.
    ///
    /// The default implementation binary-searches over
    /// [`rank0`](Self::rank0).
    fn select0(&self, k: usize) -> Option<usize> {
        let (mut lo, mut hi) = (0, self.bit_length());
        if k >= self.rank0(hi) {
            return None;
        }
        // Invariant: rank0(lo) <= k < rank0(hi).
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if self.rank0(mid) > k {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        Some(lo)
    }
}
52macro_rules! impl_rank_select_for_bits {
53    ($($t:ty)*) => {$(
54        impl RankSelectDictionaries for $t {
55            fn bit_length(&self) -> usize {
56                const WORD_SIZE: usize = (0 as $t).count_zeros() as usize;
57                WORD_SIZE
58            }
59            fn access(&self, k: usize) -> bool {
60                const WORD_SIZE: usize = (0 as $t).count_zeros() as usize;
61                if k < WORD_SIZE {
62                    self & (1 as $t) << k != 0
63                } else {
64                    false
65                }
66            }
67            fn rank1(&self, k: usize) -> usize {
68                const WORD_SIZE: usize = (0 as $t).count_zeros() as usize;
69                if k < WORD_SIZE {
70                    (self & !(!(0 as $t) << k)).count_ones() as usize
71                } else {
72                    self.count_ones() as usize
73                }
74            }
75        })*
76    };
77}
78impl_rank_select_for_bits!(u8 u16 u32 u64 usize i8 i16 i32 i64 isize u128 i128);
79
/// Append-only bit sequence stored in `usize` words, with a running count of
/// 1 bits kept alongside each word.
#[derive(Debug, Clone)]
pub struct BitVector {
    /// `[(bit, sum)]`: `bit` is one word of stored bits (lowest bit first) and
    /// `sum` is the number of 1 bits in all earlier words. `push` keeps one
    /// extra trailing entry whose `sum` is the total number of 1 bits.
    data: Vec<(usize, usize)>,
    /// Total number of 1 bits pushed so far.
    sum: usize,
    /// Total number of bits pushed so far.
    len: usize,
}
87impl BitVector {
88    const WORD_SIZE: usize = 0usize.count_zeros() as usize;
89
90    pub fn with_capacity(bits: usize) -> Self {
91        let words = bits.div_ceil(Self::WORD_SIZE) + 1;
92        let mut data = Vec::with_capacity(words);
93        data.push((0, 0));
94        Self {
95            data,
96            sum: 0,
97            len: 0,
98        }
99    }
100
101    pub fn push(&mut self, bit: bool) {
102        let word = self.len / Self::WORD_SIZE;
103        let offset = self.len % Self::WORD_SIZE;
104        if word == self.data.len() - 1 {
105            self.data.push((0, self.sum));
106        }
107        if bit {
108            self.data[word].0 |= 1 << offset;
109            self.sum += 1;
110        }
111        self.len += 1;
112        self.data.last_mut().unwrap().1 = self.sum;
113    }
114}
115impl RankSelectDictionaries for BitVector {
116    fn bit_length(&self) -> usize {
117        self.len
118    }
119    fn access(&self, k: usize) -> bool {
120        debug_assert!(k < self.len);
121        self.data[k / Self::WORD_SIZE].0 & (1 << (k % Self::WORD_SIZE)) != 0
122    }
123    fn rank1(&self, k: usize) -> usize {
124        debug_assert!(k <= self.len);
125        let (bit, sum) = self.data[k / Self::WORD_SIZE];
126        sum + (bit & !(usize::MAX << (k % Self::WORD_SIZE))).count_ones() as usize
127    }
128    fn select1(&self, mut k: usize) -> Option<usize> {
129        let (mut l, mut r) = (0, self.data.len());
130        if self.sum <= k {
131            return None;
132        }
133        while r - l > 1 {
134            let m = l.midpoint(r);
135            if self.data[m].1 <= k {
136                l = m;
137            } else {
138                r = m;
139            }
140        }
141        let (bit, sum) = self.data[l];
142        k -= sum;
143        Some(l * Self::WORD_SIZE + bit.select1(k).unwrap())
144    }
145    fn select0(&self, mut k: usize) -> Option<usize> {
146        let (mut l, mut r) = (0, self.data.len());
147        if self.len - self.sum <= k {
148            return None;
149        }
150        while r - l > 1 {
151            let m = l.midpoint(r);
152            if m * Self::WORD_SIZE - self.data[m].1 <= k {
153                l = m;
154            } else {
155                r = m;
156            }
157        }
158        let (bit, sum) = self.data[l];
159        k -= l * Self::WORD_SIZE - sum;
160        Some(l * Self::WORD_SIZE + bit.select0(k).unwrap())
161    }
162}
163impl FromIterator<bool> for BitVector {
164    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
165        let iter = iter.into_iter();
166        let (lower, upper) = iter.size_hint();
167        let mut bit_vector = match upper {
168            Some(upper) => Self::with_capacity(upper),
169            None => Self::with_capacity(lower),
170        };
171        for b in iter {
172            bit_vector.push(b);
173        }
174        bit_vector
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Number of random samples drawn per test.
    const Q: usize = 5_000;

    /// Checks the rank/select answers of `dict` for argument `k` against a
    /// naive scan built on `access`. `n` is the bit length of `dict`, used to
    /// confirm that a `None` from select really means "fewer than `k + 1`
    /// such bits exist".
    fn assert_rank_select<D: RankSelectDictionaries>(dict: &D, k: usize, n: usize) {
        let ones_before = |end: usize| (0..end).filter(|&i| dict.access(i)).count();
        let zeros_before = |end: usize| (0..end).filter(|&i| !dict.access(i)).count();

        assert_eq!(dict.rank1(k), ones_before(k));
        assert_eq!(dict.rank0(k), zeros_before(k));

        if let Some(i) = dict.select1(k) {
            assert_eq!(ones_before(i), k);
            assert!(dict.access(i));
        } else {
            assert!(dict.rank1(n) <= k);
        }

        if let Some(i) = dict.select0(k) {
            assert_eq!(zeros_before(i), k);
            assert!(!dict.access(i));
        } else {
            assert!(dict.rank0(n) <= k);
        }
    }

    #[test]
    fn test_rank_select_usize() {
        // The samples are `u64`, so the width must come from `u64` as well;
        // using `usize`'s width would cover only half of each word on
        // 32-bit targets.
        const WORD_SIZE: usize = u64::BITS as usize;
        let mut rng = Xorshift::default();
        for x in rng.random_iter(0u64..).take(Q) {
            for k in 0..=WORD_SIZE {
                assert_rank_select(&x, k, WORD_SIZE);
            }
        }
    }

    #[test]
    fn test_rank_select_bit_vector() {
        const N: usize = 1_000;
        let mut rng = Xorshift::default();
        let x: BitVector = (0..N).map(|_| rng.rand(2) != 0).collect();
        for k in rng.random_iter(..=N).take(Q) {
            assert_rank_select(&x, k, N);
        }
        assert_eq!(x.rank1(0), 0);
        assert_eq!(x.rank0(0), 0);
    }
}