competitive/data_structure/
bit_vector.rs

1use std::iter::FromIterator;
2
/// Rank/select dictionary over a finite sequence of bits.
///
/// Positions are 0-indexed. Whenever `select_i(k)` returns `Some(p)`
/// (`i` being 0 or 1), the following identities hold:
///
/// - `rank_i(p) = k`
/// - `rank_i(p + 1) = k + 1`
///
/// Only `bit_length` and `access` are required; every other method has a
/// generic default built on top of them, which implementors may override
/// with a faster version.
pub trait RankSelectDictionaries {
    /// Number of bit positions covered by this dictionary.
    fn bit_length(&self) -> usize;
    /// Returns the bit at position `k` (`true` means 1).
    fn access(&self, k: usize) -> bool;
    /// Counts the 1 bits in `[0, k)`.
    ///
    /// The default calls `access` once for every position below `k`.
    fn rank1(&self, k: usize) -> usize {
        let mut ones = 0;
        for i in 0..k {
            if self.access(i) {
                ones += 1;
            }
        }
        ones
    }
    /// Counts the 0 bits in `[0, k)`, computed as `k - rank1(k)`.
    fn rank0(&self, k: usize) -> usize {
        let ones = self.rank1(k);
        k - ones
    }
    /// Position of the `k`-th (0-indexed) 1 bit, or `None` when
    /// `rank1(bit_length()) <= k`.
    ///
    /// The default binary-searches for the largest `p` with `rank1(p) <= k`.
    fn select1(&self, k: usize) -> Option<usize> {
        let len = self.bit_length();
        if k >= self.rank1(len) {
            return None;
        }
        // Invariant: rank1(lo) <= k < rank1(hi).
        let (mut lo, mut hi) = (0, len);
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if k < self.rank1(mid) {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        Some(lo)
    }
    /// Position of the `k`-th (0-indexed) 0 bit, or `None` when
    /// `rank0(bit_length()) <= k`.
    ///
    /// The default binary-searches for the largest `p` with `rank0(p) <= k`.
    fn select0(&self, k: usize) -> Option<usize> {
        let len = self.bit_length();
        if k >= self.rank0(len) {
            return None;
        }
        // Invariant: rank0(lo) <= k < rank0(hi).
        let (mut lo, mut hi) = (0, len);
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if k < self.rank0(mid) {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        Some(lo)
    }
}
52macro_rules! impl_rank_select_for_bits {
53    ($($t:ty)*) => {$(
54        impl RankSelectDictionaries for $t {
55            fn bit_length(&self) -> usize {
56                const WORD_SIZE: usize = (0 as $t).count_zeros() as usize;
57                WORD_SIZE
58            }
59            fn access(&self, k: usize) -> bool {
60                const WORD_SIZE: usize = (0 as $t).count_zeros() as usize;
61                if k < WORD_SIZE {
62                    self & (1 as $t) << k != 0
63                } else {
64                    false
65                }
66            }
67            fn rank1(&self, k: usize) -> usize {
68                const WORD_SIZE: usize = (0 as $t).count_zeros() as usize;
69                if k < WORD_SIZE {
70                    (self & !(!(0 as $t) << k)).count_ones() as usize
71                } else {
72                    self.count_ones() as usize
73                }
74            }
75        })*
76    };
77}
78impl_rank_select_for_bits!(u8 u16 u32 u64 usize i8 i16 i32 i64 isize u128 i128);
79
/// Bit vector packed into `usize` words, with a per-word prefix count of 1 bits.
///
/// Bit `j` lives at bit position `j % WORD_SIZE` of word `j / WORD_SIZE`.
#[derive(Debug, Clone)]
pub struct BitVector {
    /// One entry per word: `(bits, number of 1 bits in all preceding words)`.
    data: Vec<(usize, usize)>,
    /// Total number of 1 bits across all words.
    sum: usize,
}
impl BitVector {
    /// Number of bits per storage word (the width of `usize`).
    const WORD_SIZE: usize = usize::MAX.count_ones() as usize;
}
89impl RankSelectDictionaries for BitVector {
90    fn bit_length(&self) -> usize {
91        self.data.len() * Self::WORD_SIZE
92    }
93    fn access(&self, k: usize) -> bool {
94        self.data[k / Self::WORD_SIZE].0 & (1 << (k % Self::WORD_SIZE)) != 0
95    }
96    fn rank1(&self, k: usize) -> usize {
97        let (bit, sum) = self.data[k / Self::WORD_SIZE];
98        sum + (bit & !(usize::MAX << (k % Self::WORD_SIZE))).count_ones() as usize
99    }
100    fn select1(&self, mut k: usize) -> Option<usize> {
101        let (mut l, mut r) = (0, self.data.len());
102        if self.sum <= k {
103            return None;
104        }
105        while r - l > 1 {
106            let m = (l + r) / 2;
107            if self.data[m].1 <= k {
108                l = m;
109            } else {
110                r = m;
111            }
112        }
113        let (bit, sum) = self.data[l];
114        k -= sum;
115        Some(l * Self::WORD_SIZE + bit.select1(k).unwrap())
116    }
117    fn select0(&self, mut k: usize) -> Option<usize> {
118        let (mut l, mut r) = (0, self.data.len());
119        if r * Self::WORD_SIZE - self.sum <= k {
120            return None;
121        }
122        while r - l > 1 {
123            let m = (l + r) / 2;
124            if m * Self::WORD_SIZE - self.data[m].1 <= k {
125                l = m;
126            } else {
127                r = m;
128            }
129        }
130        let (bit, sum) = self.data[l];
131        k -= l * Self::WORD_SIZE - sum;
132        Some(l * Self::WORD_SIZE + bit.select0(k).unwrap())
133    }
134}
135impl FromIterator<bool> for BitVector {
136    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
137        let mut iter = iter.into_iter();
138        let mut data = Vec::new();
139        let mut sum = 0;
140        'outer: loop {
141            let mut bit = 0;
142            let mut nsum = sum;
143            for i in 0..Self::WORD_SIZE {
144                if let Some(b) = iter.next() {
145                    if b {
146                        bit |= 1 << i;
147                        nsum += 1;
148                    }
149                } else {
150                    data.push((bit, sum));
151                    sum = nsum;
152                    break 'outer;
153                }
154            }
155            data.push((bit, sum));
156            sum = nsum;
157        }
158        Self { data, sum }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Number of random cases per test.
    const Q: usize = 5_000;

    /// Checks the primitive-integer implementation on random `u64` words
    /// against naive bit-by-bit counting, for every `k` in `0..=64`.
    #[test]
    fn test_rank_select_usize() {
        let mut rng = Xorshift::new();
        for x in rng.random_iter(0u64..).take(Q) {
            // Use the width of the type under test (`u64`), not `usize`, so the
            // whole word is exercised on 32-bit targets as well.
            let word_size = x.bit_length();
            for k in 0..=word_size {
                assert_eq!(x.rank1(k), (0..k).filter(|&i| x.access(i)).count());
                assert_eq!(x.rank0(k), (0..k).filter(|&i| !x.access(i)).count());
                if let Some(i) = x.select1(k) {
                    assert_eq!((0..i).filter(|&j| x.access(j)).count(), k);
                    assert!(x.access(i));
                } else {
                    assert!(x.rank1(word_size) <= k);
                }
                if let Some(i) = x.select0(k) {
                    assert_eq!((0..i).filter(|&j| !x.access(j)).count(), k);
                    assert!(!x.access(i));
                } else {
                    assert!(x.rank0(word_size) <= k);
                }
            }
        }
    }

    /// Checks `BitVector` rank/select on a random 1000-bit vector against
    /// naive counting over `access`, for random `k` in `0..=N`.
    #[test]
    fn test_rank_select_bit_vector() {
        const N: usize = 1_000;
        let mut rng = Xorshift::new();
        let x: BitVector = (0..N).map(|_| rng.rand(2) != 0).collect();
        for k in rng.random_iter(..=N).take(Q) {
            assert_eq!(x.rank1(k), (0..k).filter(|&i| x.access(i)).count());
            assert_eq!(x.rank0(k), (0..k).filter(|&i| !x.access(i)).count());

            if let Some(i) = x.select1(k) {
                assert_eq!((0..i).filter(|&j| x.access(j)).count(), k);
                assert!(x.access(i));
            } else {
                assert!(x.rank1(N) <= k);
            }

            // `select0` may return a position in the zero padding past `N`.
            if let Some(i) = x.select0(k) {
                assert_eq!((0..i).filter(|&j| !x.access(j)).count(), k);
                assert!(!x.access(i));
            } else {
                assert!(x.rank0(N) <= k);
            }
        }
        assert_eq!(x.rank1(0), 0);
        assert_eq!(x.rank0(0), 0);
    }
}