competitive/algorithm/
binary_search.rs

1use std::cmp::Ordering;
2
/// A value type that a bisection (binary search) can be driven over.
pub trait Bisect
where
    Self: Clone,
{
    /// Returns a point between `self` and `other`, or `None` once the two
    /// endpoints cannot be split any further and the search has to stop.
    fn bisect_middle_point(&self, other: &Self) -> Option<Self>;
}
8
/// Implements `Bisect` for each listed unsigned integer type.
///
/// The generated method yields the integer midpoint of the two endpoints
/// while they are more than one apart, and `None` once they are equal or
/// adjacent.
macro_rules! impl_bisect_unsigned {
    ($($t:ty)*) => {
        $(
            impl Bisect for $t {
                fn bisect_middle_point(&self, other: &Self) -> Option<Self> {
                    let (a, b) = (*self, *other);
                    // Equal or adjacent endpoints leave nothing in between.
                    (a.abs_diff(b) > 1).then(|| a.midpoint(b))
                }
            }
        )*
    };
}
/// Implements `Bisect` for each listed signed integer type.
///
/// The generated method yields the integer midpoint of the two endpoints
/// while at least one value lies strictly between them, and `None` once they
/// are equal or adjacent.
macro_rules! impl_bisect_signed {
    ($($t:ty)*) => {
        $(
            impl Bisect for $t {
                fn bisect_middle_point(&self, other: &Self) -> Option<Self> {
                    // Order the endpoints so the gap check is the same in both directions.
                    let (lo, hi) = match self.cmp(other) {
                        Ordering::Less => (*self, *other),
                        Ordering::Equal => return None,
                        Ordering::Greater => (*other, *self),
                    };
                    // `lo < hi` holds here, so `lo + 1` cannot overflow.
                    if lo + 1 < hi {
                        Some(self.midpoint(*other))
                    } else {
                        None
                    }
                }
            }
        )*
    };
}
/// Implements `Bisect` for floating-point types by bisecting their bit patterns.
///
/// Each entry is `{float unsigned signed shift}`: the float type `$t`, the
/// unsigned and signed integer types `$u` / `$i` used to hold its bits, and the
/// shift amount `$e` applied to the signed value.
///
/// Both endpoints are mapped to `$i` keys, the `Bisect` impl of `$i` picks the
/// middle key, and that key is mapped back to a float with the same transform.
macro_rules! impl_bisect_float {
    ($({$t:ident $u:ident $i:ident $e:expr})*) => {
        $(
            impl Bisect for $t {
                fn bisect_middle_point(&self, other: &Self) -> Option<Self> {
                    // XOR mask shared by both directions of the mapping: the
                    // arithmetic shift by `$e` is reinterpreted as unsigned and
                    // shifted right once more, so (when `$e` is the sign-bit
                    // index of `$i`) it is zero for non-negative keys and covers
                    // every bit except the sign bit for negative ones.
                    fn flip_mask(a: $i) -> $i {
                        let sign_fill = (a >> $e) as $u;
                        (sign_fill >> 1) as $i
                    }
                    // Float -> integer key.
                    fn encode(x: $t) -> $i {
                        let raw = x.to_bits() as $i;
                        raw ^ flip_mask(raw)
                    }
                    // Integer key -> float; applying the same mask undoes `encode`.
                    fn decode(key: $i) -> $t {
                        $t::from_bits((key ^ flip_mask(key)) as _)
                    }

                    let lhs = encode(*self);
                    let rhs = encode(*other);
                    let mid = <$i as Bisect>::bisect_middle_point(&lhs, &rhs)?;
                    Some(decode(mid))
                }
            }
        )*
    };
}
55impl_bisect_unsigned!(u8 u16 u32 u64 u128 usize);
56impl_bisect_signed!(i8 i16 i32 i64 i128 isize);
57impl_bisect_float!({f32 u32 i32 31} {f64 u64 i64 63});
58
59/// binary search for monotone segment
60///
61/// if `ok < err` then search [ok, err) where t(`ok`), t, t, .... t, t(`ret`), f,  ... f, f, f, `err`
62///
63/// if `err < ok` then search (err, ok] where `err`, f, f, f, ... f, t(`ret`), ... t, t, t(`ok`)
64pub fn binary_search<T, F>(mut f: F, mut ok: T, mut err: T) -> T
65where
66    T: Bisect,
67    F: FnMut(&T) -> bool,
68{
69    while let Some(m) = ok.bisect_middle_point(&err) {
70        if f(&m) {
71            ok = m;
72        } else {
73            err = m;
74        }
75    }
76    ok
77}
78
/// Bisection helpers for slices whose elements are partitioned by a predicate.
pub trait SliceBisectExt<T> {
    /// Returns the index of the first element that satisfies the predicate,
    /// or `len()` when there is none.
    fn position_bisect(&self, f: impl FnMut(&T) -> bool) -> usize;
    /// Returns one past the index of the last element that satisfies the
    /// predicate, or `0` when there is none.
    fn rposition_bisect(&self, f: impl FnMut(&T) -> bool) -> usize;
    /// Returns a reference to the first element that satisfies the predicate.
    fn find_bisect(&self, f: impl FnMut(&T) -> bool) -> Option<&T>;
    /// Returns a reference to the last element that satisfies the predicate.
    fn rfind_bisect(&self, f: impl FnMut(&T) -> bool) -> Option<&T>;
}
92impl<T> SliceBisectExt<T> for [T] {
93    fn find_bisect(&self, f: impl FnMut(&T) -> bool) -> Option<&T> {
94        self.get(self.position_bisect(f))
95    }
96    fn rfind_bisect(&self, f: impl FnMut(&T) -> bool) -> Option<&T> {
97        let pos = self.rposition_bisect(f);
98        if pos == 0 { None } else { self.get(pos - 1) }
99    }
100    fn position_bisect(&self, mut f: impl FnMut(&T) -> bool) -> usize {
101        binary_search(|i| f(&self[*i as usize]), self.len() as i64, -1) as usize
102    }
103    fn rposition_bisect(&self, mut f: impl FnMut(&T) -> bool) -> usize {
104        binary_search(|i| f(&self[i - 1]), 0, self.len() + 1)
105    }
106}
107
108pub fn parallel_binary_search<T, F, G>(mut f: F, q: usize, ok: T, err: T) -> Vec<T>
109where
110    T: Bisect,
111    F: FnMut(&[Option<T>]) -> G,
112    G: Fn(usize) -> bool,
113{
114    let mut ok = vec![ok; q];
115    let mut err = vec![err; q];
116    loop {
117        let m: Vec<_> = ok
118            .iter()
119            .zip(&err)
120            .map(|(ok, err)| ok.bisect_middle_point(err))
121            .collect();
122        if m.iter().all(|m| m.is_none()) {
123            break;
124        }
125        let g = f(&m);
126        for (i, m) in m.into_iter().enumerate() {
127            if let Some(m) = m {
128                if g(i) {
129                    ok[i] = m;
130                } else {
131                    err[i] = m;
132                }
133            }
134        }
135    }
136    ok
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Sorted fixture shared by all tests in this module.
    const V: [i64; 10] = [0, 1, 1, 1, 2, 2, 3, 4, 7, 8];

    #[test]
    fn test_binary_search() {
        // Searching downwards from `len`: first index whose value reaches the bound.
        for (bound, first) in [(1i64, 1usize), (2, 4), (3, 6)] {
            assert_eq!(binary_search(|&x| V[x] >= bound, V.len(), 0), first);
        }
        // Searching upwards from `0`: last index whose value stays within the bound.
        for (bound, last) in [(1i64, 3usize), (2, 5), (3, 6)] {
            assert_eq!(binary_search(|&x| V[x] <= bound, 0, V.len()), last);
        }

        // Nothing matches: the initial `ok` sentinel comes back untouched.
        let nothing = binary_search(&|&x: &i64| V[x as usize] <= -1, -1, V.len() as i64);
        assert_eq!(nothing, -1);

        // Floating-point search for sqrt(2).
        let sq2 = binary_search(|&x| x * x <= 2., 1., 4.);
        let expect = 1.414_213_562_73;
        assert!(expect - 1e-8 <= sq2 && sq2 <= expect + 1e-8);

        // Full-range searches must not overflow in either direction.
        let up_to_max = binary_search(|&x| x < i64::MAX, i64::MIN, i64::MAX);
        assert_eq!(up_to_max, i64::MAX - 1);
        let stays_min = binary_search(|&x| x == i64::MIN, i64::MIN, i64::MAX);
        assert_eq!(stays_min, i64::MIN);
        let stays_max = binary_search(|&x| x == i64::MAX, i64::MAX, i64::MIN);
        assert_eq!(stays_max, i64::MAX);
        let down_to_min = binary_search(|&x| x > i64::MIN, i64::MAX, i64::MIN);
        assert_eq!(down_to_min, i64::MIN + 1);
    }

    #[test]
    fn test_position() {
        let cases: [(i64, usize); 7] = [(-1, 0), (0, 0), (1, 1), (2, 4), (3, 6), (5, 8), (10, 10)];
        for (bound, index) in cases {
            assert_eq!(V.position_bisect(|&x| x >= bound), index);
        }
    }

    #[test]
    fn test_find() {
        let cases: [(i64, Option<i64>); 7] = [
            (-1, Some(0)),
            (0, Some(0)),
            (1, Some(1)),
            (2, Some(2)),
            (3, Some(3)),
            (5, Some(7)),
            (10, None),
        ];
        for (bound, found) in cases {
            assert_eq!(V.find_bisect(|&x| x >= bound), found.as_ref());
        }
    }

    #[test]
    fn test_rposition() {
        let cases: [(i64, usize); 7] = [(-1, 0), (0, 1), (1, 4), (2, 6), (3, 7), (5, 8), (10, 10)];
        for (bound, end) in cases {
            assert_eq!(V.rposition_bisect(|&x| x <= bound), end);
        }
    }

    #[test]
    fn test_rfind() {
        let cases: [(i64, Option<i64>); 7] = [
            (-1, None),
            (0, Some(0)),
            (1, Some(1)),
            (2, Some(2)),
            (3, Some(3)),
            (5, Some(4)),
            (10, Some(8)),
        ];
        for (bound, found) in cases {
            assert_eq!(V.rfind_bisect(|&x| x <= bound), found.as_ref());
        }
    }
}