Skip to main content

competitive/algorithm/
ternary_search.rs

1use std::ops::RangeInclusive;
2
/// Fibonacci-search backend used by [`ternary_search`].
///
/// `a.fibonacci_search(b, f)` looks for the minimum of `f` over the inclusive key range
/// `a..=b` and returns `(argmin, f(argmin))`. `f` is expected to be convex over that range;
/// the implementations in this module assert `a <= b`.
pub trait FibonacciSearch: Sized {
    /// Minimizes `f` over `self..=other`, returning the best key and its value.
    fn fibonacci_search<T, F>(self, other: Self, f: F) -> (Self, T)
    where
        F: FnMut(Self) -> T,
        T: PartialOrd;
}
10macro_rules! impl_fibonacci_search_unsigned {
11    ($($t:ty)*) => {
12        $(impl FibonacciSearch for $t {
13            fn fibonacci_search<T, F>(self, other: Self, mut f: F) -> (Self, T)
14            where
15                T: PartialOrd,
16                F: FnMut(Self) -> T,
17            {
18                let l = self;
19                let r = other;
20                assert!(l <= r);
21                const W: usize = [12, 23, 46, 92, 185][<$t>::BITS.ilog2() as usize - 3];
22                const FIB: [$t; W] = {
23                    let mut fib = [0; W];
24                    fib[0] = 1;
25                    fib[1] = 2;
26                    let mut i = 2;
27                    while i < W {
28                        fib[i] = fib[i - 1] + fib[i - 2];
29                        i += 1;
30                    }
31                    fib
32                };
33                let mut s = l;
34                let mut v0 = None;
35                let mut v1 = None;
36                let mut v2 = None;
37                let mut v3 = None;
38                for w in FIB[..FIB.partition_point(|&f| f < r - l)].windows(2).rev() {
39                    let (w0, w1) = (w[0], w[1]);
40                    if w1 > r - s || v1.get_or_insert_with(|| f(s + w0)) <= v2.get_or_insert_with(|| f(s + w1)) {
41                        v3 = v2;
42                        v2 = v1;
43                        v1 = None;
44                    } else {
45                        v0 = v1;
46                        v1 = v2;
47                        v2 = None;
48                        s += w0;
49                    }
50                }
51                let mut kv = (s, v0.unwrap_or_else(|| f(s)));
52                if s < r {
53                    let v = v1.or(v2).unwrap_or_else(|| f(s + 1));
54                    if v < kv.1 {
55                        kv = (s + 1, v);
56                    }
57                    if s + 1 < r {
58                        let v = v3.unwrap_or_else(|| f(s + 2));
59                        if v < kv.1 {
60                            kv = (s + 2, v);
61                        }
62                    }
63                }
64                kv
65            }
66        })*
67    };
68}
69impl_fibonacci_search_unsigned!(u8 u16 u32 u64 u128 usize);
70
71/// ternary search helper
72pub trait Trisect: Clone {
73    type Key: FibonacciSearch;
74    fn trisect_key(self) -> Self::Key;
75    fn trisect_unkey(key: Self::Key) -> Self;
76}
77
78macro_rules! impl_trisect_unsigned {
79    ($($t:ty)*) => {
80        $(impl Trisect for $t {
81            type Key = $t;
82            fn trisect_key(self) -> Self::Key {
83                self
84            }
85            fn trisect_unkey(key: Self::Key) -> Self {
86                key
87            }
88        })*
89    };
90}
91macro_rules! impl_trisect_signed {
92    ($({$i:ident $u:ident})*) => {
93        $(impl Trisect for $i {
94            type Key = $u;
95            fn trisect_key(self) -> Self::Key {
96                (self as $u) ^ (1 << <$u>::BITS - 1)
97            }
98            fn trisect_unkey(key: Self::Key) -> Self {
99                (key ^ (1 << <$u>::BITS - 1)) as $i
100            }
101        })*
102    };
103}
104macro_rules! impl_trisect_float {
105    ($({$t:ident $u:ident $i:ident})*) => {
106        $(impl Trisect for $t {
107            type Key = $u;
108            fn trisect_key(self) -> Self::Key {
109                let a = self.to_bits() as $i;
110                (a ^ (((a >> <$u>::BITS - 1) as $u) >> 1) as $i) as $u ^ (1 << <$u>::BITS - 1)
111            }
112            fn trisect_unkey(key: Self::Key) -> Self {
113                let key = (key  ^ (1 << <$u>::BITS - 1)) as $i;
114                $t::from_bits((key ^ (((key >> <$u>::BITS - 1) as $u) >> 1) as $i) as _)
115            }
116        })*
117    };
118}
119
120impl_trisect_unsigned!(u8 u16 u32 u64 u128 usize);
121impl_trisect_signed!({i8 u8} {i16 u16} {i32 u32} {i64 u64} {i128 u128} {isize usize});
122impl_trisect_float!({f32 u32 i32} {f64 u64 i64});
123
124/// Returns the element that gives the minimum value from the strictly concave up function and the minimum value.
125pub fn ternary_search<K, V, F>(range: RangeInclusive<K>, mut f: F) -> (K, V)
126where
127    K: Trisect,
128    V: PartialOrd,
129    F: FnMut(K) -> V,
130{
131    let (l, r) = range.into_inner();
132    let (k, v) =
133        <K::Key as FibonacciSearch>::fibonacci_search(l.trisect_key(), r.trisect_key(), |x| {
134            f(Trisect::trisect_unkey(x))
135        });
136    (K::trisect_unkey(k), v)
137}
138
139pub fn piecewise_ternary_search<const N: usize, K, V, F>(piece: [K; N], mut f: F) -> (K, V)
140where
141    K: Trisect,
142    V: PartialOrd,
143    F: FnMut(K) -> V,
144{
145    piece
146        .windows(2)
147        .map(|w| ternary_search(w[0].clone()..=w[1].clone(), &mut f))
148        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
149        .unwrap_or_else(|| (piece[0].clone(), f(piece[0].clone())))
150}
151
/// Minimizes `f` over the real interval `range` with golden-section search.
///
/// Runs exactly `count` iterations, each shrinking the bracket `[l, r]` by the factor
/// `1/φ ≈ 0.618`. One interior value is reused from the previous iteration, so `f` is called
/// at most `count + 3` times in total. The better of the two final bracket endpoints is
/// returned together with its value; the left endpoint wins ties.
///
/// `f` should be unimodal (e.g. convex) on `range`, and the endpoints should be finite.
/// Ranges wider than `f64::MAX`, such as `f64::MIN..=f64::MAX`, are supported. While `r - l`
/// overflows, the probes are formed as convex combinations of the endpoints instead of from
/// the overflowing width.
pub fn golden_ternary_search<T>(
    range: RangeInclusive<f64>,
    count: usize,
    mut f: impl FnMut(f64) -> T,
) -> (f64, T)
where
    T: PartialOrd,
{
    let (mut l, mut r) = range.into_inner();
    // FIXME: 1.94.0: std::f64::consts::GOLDEN_RATIO;
    const GOLDEN_RATIO_INV: f64 = 1f64 / 1.618_033_988_749_895_f64;
    // 1 - 1/φ == 1/φ²; only used while `r - l` is not finite.
    const GOLDEN_RATIO_INV_SQ: f64 = 1f64 - GOLDEN_RATIO_INV;
    // Cached objective values at `l`, the left probe, the right probe and `r`.
    let mut v0 = None;
    let mut v1 = None;
    let mut v2 = None;
    let mut v3 = None;
    for _ in 0..count {
        let width = r - l;
        let (left_probe, right_probe) = if width.is_finite() {
            let w = width * GOLDEN_RATIO_INV;
            (r - w, l + w)
        } else {
            // `r - l` overflowed (e.g. `f64::MIN..=f64::MAX`); weighting each endpoint
            // separately keeps both probes finite.
            (
                l * GOLDEN_RATIO_INV + r * GOLDEN_RATIO_INV_SQ,
                l * GOLDEN_RATIO_INV_SQ + r * GOLDEN_RATIO_INV,
            )
        };
        if v1.get_or_insert_with(|| f(left_probe)) <= v2.get_or_insert_with(|| f(right_probe)) {
            // Minimum lies in [l, right_probe]: the left probe becomes the new right probe.
            v3 = v2;
            v2 = v1;
            v1 = None;
            r = right_probe;
        } else {
            // Minimum lies in [left_probe, r]: the right probe becomes the new left probe.
            v0 = v1;
            v1 = v2;
            v2 = None;
            l = left_probe;
        }
    }
    let kv = (l, v0.unwrap_or_else(|| f(l)));
    let kv2 = (r, v3.unwrap_or_else(|| f(r)));
    if kv2.1 < kv.1 { kv2 } else { kv }
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{num::DoubleDouble, tools::Xorshift};

    #[test]
    fn test_trisect_unsigned() {
        for a in 0..=u8::MAX {
            let ka = a.trisect_key();
            assert_eq!(a, u8::trisect_unkey(ka));
            for b in 0..=u8::MAX {
                assert_eq!(a.cmp(&b), ka.cmp(&b.trisect_key()));
            }
        }
    }

    #[test]
    fn test_trisect_signed() {
        for a in i8::MIN..=i8::MAX {
            let ka = a.trisect_key();
            assert_eq!(a, i8::trisect_unkey(ka));
            for b in i8::MIN..=i8::MAX {
                assert_eq!(a.cmp(&b), ka.cmp(&b.trisect_key()));
            }
        }
    }

    #[test]
    fn test_trisect_float() {
        let mut rng = Xorshift::default();
        // Uniform-ish samples in [-100, 100); drawn in the same order as before.
        let mut draw = move || (rng.randf() - 0.5) * 200.;
        for _ in 0..1000 {
            let p = draw();
            assert_eq!(p, f64::trisect_unkey(p.trisect_key()));
            let q = draw();
            let by_value = p.partial_cmp(&q);
            let by_key = p.trisect_key().partial_cmp(&q.trisect_key());
            assert_eq!(by_value, by_key);
        }
    }

    #[test]
    fn test_ternary_search_unsigned() {
        for p in 0..=u8::MAX {
            let f = |x: u8| p.abs_diff(x);
            for l in 0..=u8::MAX {
                for r in l..=u8::MAX {
                    let expected = f(l).min(f(r)).min(f(p.clamp(l, r)));
                    assert_eq!(expected, ternary_search(l..=r, f).1);
                }
            }
        }
    }

    #[test]
    fn test_ternary_search_signed() {
        for p in -20..=20 {
            let (argmin, _) = ternary_search(-10i64..=10, |x| 10 * (x - p).pow(2) + 5);
            assert_eq!(p.clamp(-10, 10), argmin);
        }
    }

    #[test]
    fn test_ternary_search_float() {
        let target = std::f64::consts::PI;
        let (argmin, _) = ternary_search(f64::MIN..=f64::MAX, |x| {
            (DoubleDouble::from(x) - DoubleDouble::from(target)).abs()
        });
        assert_eq!(target, argmin);

        for a in 0..1000 {
            let center = (a as f64) / 1000.0;
            let (_, value) =
                piecewise_ternary_search([0.0, 1e-9, 1.0], |x| (x - center).powi(2));
            assert_eq!(0.0f64, value);
        }
    }

    #[test]
    fn test_golden_ternary_search_float() {
        let target = std::f64::consts::PI;
        let (argmin, _) = golden_ternary_search(-1e100..=1e100, 1000, |x| {
            (DoubleDouble::from(x) - DoubleDouble::from(target)).abs()
        });
        assert_eq!(target, argmin);
    }
}