competitive/algorithm/
stern_brocot_tree.rs

1use super::{URational, Unsigned};
2use std::mem::swap;
3
4pub trait SternBrocotTree: From<URational<Self::T>> + FromIterator<Self::T> {
5    type T: Unsigned;
6
7    fn root() -> Self;
8
9    fn is_root(&self) -> bool;
10
11    fn eval(&self) -> URational<Self::T>;
12
13    fn down_left(&mut self, count: Self::T);
14
15    fn down_right(&mut self, count: Self::T);
16
17    /// Returns the remaining count after moving up.
18    fn up(&mut self, count: Self::T) -> Self::T;
19}
20
21#[derive(Clone, Copy, Debug, Eq, PartialEq)]
22pub struct SbtNode<T>
23where
24    T: Unsigned,
25{
26    pub l: URational<T>,
27    pub r: URational<T>,
28}
29
30#[derive(Clone, Debug, Eq, PartialEq, Default)]
31pub struct SbtPath<T>
32where
33    T: Unsigned,
34{
35    pub path: Vec<T>,
36}
37
38impl<T> From<URational<T>> for SbtNode<T>
39where
40    T: Unsigned,
41{
42    fn from(r: URational<T>) -> Self {
43        SbtPath::from(r).to_node()
44    }
45}
46
47impl<T> FromIterator<T> for SbtNode<T>
48where
49    T: Unsigned,
50{
51    fn from_iter<I>(iter: I) -> Self
52    where
53        I: IntoIterator<Item = T>,
54    {
55        let mut node = SbtNode::root();
56        for (i, count) in iter.into_iter().enumerate() {
57            if i % 2 == 0 {
58                node.down_right(count);
59            } else {
60                node.down_left(count);
61            }
62        }
63        node
64    }
65}
66
67impl<T> From<URational<T>> for SbtPath<T>
68where
69    T: Unsigned,
70{
71    fn from(r: URational<T>) -> Self {
72        assert!(!r.num.is_zero(), "rational must be positive");
73        assert!(!r.den.is_zero(), "rational must be positive");
74
75        let (mut a, mut b) = (r.num, r.den);
76        let mut path = vec![];
77        loop {
78            let x = a / b;
79            a %= b;
80            if a.is_zero() {
81                if !x.is_one() {
82                    path.push(x - T::one());
83                }
84                break;
85            }
86            path.push(x);
87            swap(&mut a, &mut b);
88        }
89        Self { path }
90    }
91}
92
93impl<T> FromIterator<T> for SbtPath<T>
94where
95    T: Unsigned,
96{
97    fn from_iter<I>(iter: I) -> Self
98    where
99        I: IntoIterator<Item = T>,
100    {
101        let mut path = SbtPath::root();
102        for (i, count) in iter.into_iter().enumerate() {
103            if i % 2 == 0 {
104                path.down_right(count);
105            } else {
106                path.down_left(count);
107            }
108        }
109        path
110    }
111}
112
113impl<T> IntoIterator for SbtPath<T>
114where
115    T: Unsigned,
116{
117    type Item = T;
118    type IntoIter = std::vec::IntoIter<T>;
119
120    fn into_iter(self) -> Self::IntoIter {
121        self.path.into_iter()
122    }
123}
124
125impl<'a, T> IntoIterator for &'a SbtPath<T>
126where
127    T: Unsigned,
128{
129    type Item = T;
130    type IntoIter = std::iter::Cloned<std::slice::Iter<'a, T>>;
131
132    fn into_iter(self) -> Self::IntoIter {
133        self.path.iter().cloned()
134    }
135}
136
137impl<T> SternBrocotTree for SbtNode<T>
138where
139    T: Unsigned,
140{
141    type T = T;
142
143    fn root() -> Self {
144        Self {
145            l: URational::new(T::zero(), T::one()),
146            r: URational::new(T::one(), T::zero()),
147        }
148    }
149
150    fn is_root(&self) -> bool {
151        self.l.num.is_zero() && self.r.den.is_zero()
152    }
153
154    fn eval(&self) -> URational<Self::T> {
155        URational::new_unchecked(self.l.num + self.r.num, self.l.den + self.r.den)
156    }
157
158    fn down_left(&mut self, count: Self::T) {
159        self.r.num += self.l.num * count;
160        self.r.den += self.l.den * count;
161    }
162
163    fn down_right(&mut self, count: Self::T) {
164        self.l.num += self.r.num * count;
165        self.l.den += self.r.den * count;
166    }
167
168    fn up(&mut self, mut count: Self::T) -> Self::T {
169        while count > T::zero() && !self.is_root() {
170            if self.l.den > self.r.den {
171                let x = count.min(self.l.num / self.r.num);
172                count -= x;
173                self.l.num -= self.r.num * x;
174                self.l.den -= self.r.den * x;
175            } else {
176                let x = count.min(self.r.den / self.l.den);
177                count -= x;
178                self.r.num -= self.l.num * x;
179                self.r.den -= self.l.den * x;
180            }
181        }
182        count
183    }
184}
185
186impl<T> SternBrocotTree for SbtPath<T>
187where
188    T: Unsigned,
189{
190    type T = T;
191
192    fn root() -> Self {
193        Self::default()
194    }
195
196    fn is_root(&self) -> bool {
197        self.path.is_empty()
198    }
199
200    fn eval(&self) -> URational<Self::T> {
201        self.to_node().eval()
202    }
203
204    fn down_left(&mut self, count: Self::T) {
205        if count.is_zero() {
206            return;
207        }
208        if self.path.len().is_multiple_of(2) {
209            if let Some(last) = self.path.last_mut() {
210                *last += count;
211            } else {
212                self.path.push(T::zero());
213                self.path.push(count);
214            }
215        } else {
216            self.path.push(count);
217        }
218    }
219
220    fn down_right(&mut self, count: Self::T) {
221        if count.is_zero() {
222            return;
223        }
224        if self.path.len().is_multiple_of(2) {
225            self.path.push(count);
226        } else {
227            *self.path.last_mut().unwrap() += count;
228        }
229    }
230
231    fn up(&mut self, mut count: Self::T) -> Self::T {
232        while let Some(last) = self.path.last_mut() {
233            let x = count.min(*last);
234            *last -= x;
235            count -= x;
236            if !last.is_zero() {
237                break;
238            }
239            self.path.pop();
240        }
241        count
242    }
243}
244
245impl<T> SbtNode<T>
246where
247    T: Unsigned,
248{
249    pub fn to_path(&self) -> SbtPath<T> {
250        self.eval().into()
251    }
252    pub fn lca<I, J>(path1: I, path2: J) -> Self
253    where
254        I: IntoIterator<Item = T>,
255        J: IntoIterator<Item = T>,
256    {
257        let mut node = SbtNode::root();
258        for (i, (count1, count2)) in path1.into_iter().zip(path2).enumerate() {
259            let count = count1.min(count2);
260            if i % 2 == 0 {
261                node.down_right(count);
262            } else {
263                node.down_left(count);
264            }
265            if count1 != count2 {
266                break;
267            }
268        }
269        node
270    }
271}
272
273impl<T> SbtPath<T>
274where
275    T: Unsigned,
276{
277    pub fn to_node(&self) -> SbtNode<T> {
278        self.path.iter().cloned().collect()
279    }
280    pub fn depth(&self) -> T {
281        self.path.iter().cloned().sum()
282    }
283}
284
285pub fn rational_binary_search<T>(mut f: impl FnMut(&URational<T>) -> bool, n: T) -> SbtNode<T>
286where
287    T: Unsigned,
288{
289    let mut node = SbtNode::root();
290    let lb = f(&node.l);
291    let rb = f(&node.r);
292    assert_ne!(lb, rb, "f(0/1) and f(1/0) must be different");
293    let two = T::one() + T::one();
294    while node.l.num + node.r.num <= n && node.l.den + node.r.den <= n {
295        {
296            let mut k = T::one();
297            loop {
298                let old = node.l;
299                node.down_right(k);
300                if node.l.num > n || node.l.den > n || f(&node.l) != lb {
301                    node.l = old;
302                    break;
303                }
304                k *= two;
305            }
306            while k > T::zero() {
307                let old = node.l;
308                node.down_right(k);
309                if node.l.num > n || node.l.den > n || f(&node.l) != lb {
310                    node.l = old;
311                }
312                k /= two;
313            }
314        }
315        {
316            let mut k = T::one();
317            loop {
318                let old = node.r;
319                node.down_left(k);
320                if node.r.num > n || node.r.den > n || f(&node.r) != rb {
321                    node.r = old;
322                    break;
323                }
324                k *= two;
325            }
326            while k > T::zero() {
327                let old = node.r;
328                node.down_left(k);
329                if node.r.num > n || node.r.den > n || f(&node.r) != rb {
330                    node.r = old;
331                }
332                k /= two;
333            }
334        }
335    }
336    node
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    #[test]
    fn test_sbt_path_encode_decode() {
        for a in 1u32..50 {
            for b in 1u32..50 {
                let r = URational::new(a, b);
                let path = SbtPath::from(r);
                let node = path.to_node();
                assert_eq!(node.eval(), r);
            }
        }
    }

    #[test]
    fn test_sbt_explore() {
        let mut rng = Xorshift::default();
        for _ in 0..10000 {
            let mut node = SbtNode::<u128>::root();
            let mut path = SbtPath::<u128>::root();
            for _ in 0..30 {
                match rng.random(0..3) {
                    0 => {
                        let count = rng.random(0..=100);
                        node.down_left(count);
                        path.down_left(count);
                    }
                    1 => {
                        let count = rng.random(0..=100);
                        node.down_right(count);
                        path.down_right(count);
                    }
                    _ => {
                        let count = rng.random(0..=100);
                        let r1 = path.up(count);
                        let r2 = node.up(count);
                        assert_eq!(r1, r2);
                    }
                }
                assert_eq!(node, path.to_node());
                assert_eq!(node.eval(), path.eval());
                assert_eq!(node.is_root(), path.is_root());
                assert_eq!(node.to_path(), path);
            }
        }
    }

    #[test]
    fn test_sbt_lca_and_depth() {
        // Expands a run-length path into individual moves (`true` = right).
        fn moves(path: &SbtPath<u32>) -> Vec<bool> {
            path.path
                .iter()
                .enumerate()
                .flat_map(|(i, &run)| std::iter::repeat(i % 2 == 0).take(run as usize))
                .collect()
        }

        for a in 1u32..12 {
            for b in 1u32..12 {
                let p = SbtPath::from(URational::new(a, b));
                let p_moves = moves(&p);
                assert_eq!(p.depth() as usize, p_moves.len());
                for c in 1u32..12 {
                    for d in 1u32..12 {
                        let q = SbtPath::from(URational::new(c, d));
                        // Brute force: follow the longest common prefix of single moves.
                        let mut expected = SbtNode::<u32>::root();
                        for (&x, &y) in p_moves.iter().zip(&moves(&q)) {
                            if x != y {
                                break;
                            }
                            if x {
                                expected.down_right(1);
                            } else {
                                expected.down_left(1);
                            }
                        }
                        assert_eq!(SbtNode::lca(&p, &q), expected);
                    }
                }
            }
        }
    }

    #[test]
    fn test_rational_binary_search() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.rand(100) + 1;
            let target = URational::new(rng.rand(1_000_000_000), rng.rand(1_000_000_000) + 1);
            let node = rational_binary_search(|candidate| &target < candidate, n);

            assert!(target >= node.l);
            assert!(target < node.r);
            assert!(node.l.num <= n && node.l.den <= n);
            assert!(node.r.num <= n && node.r.den <= n);

            let candidates: Vec<_> = (0..=n)
                .flat_map(|a| (1..=n).map(move |b| URational::new(a, b)))
                .collect();

            let expected_left = candidates
                .iter()
                .copied()
                .filter(|q| q <= &target)
                .max()
                .unwrap_or_else(|| URational::new_unchecked(0, 1));
            assert_eq!(node.l, expected_left);

            let expected_right = candidates
                .iter()
                .copied()
                .filter(|q| &target < q)
                .min()
                .unwrap_or_else(|| URational::new_unchecked(1, 0));
            assert_eq!(node.r, expected_right);
        }
    }
}
424}