competitive/algorithm/
stern_brocot_tree.rs

1use super::{URational, Unsigned};
2use std::mem::swap;
3
4pub trait SternBrocotTree: From<URational<Self::T>> + FromIterator<Self::T> {
5    type T: Unsigned;
6
7    fn root() -> Self;
8
9    fn is_root(&self) -> bool;
10
11    fn eval(&self) -> URational<Self::T>;
12
13    fn down_left(&mut self, count: Self::T);
14
15    fn down_right(&mut self, count: Self::T);
16
17    /// Returns the remaining count after moving up.
18    fn up(&mut self, count: Self::T) -> Self::T;
19}
20
21#[derive(Clone, Copy, Debug, Eq, PartialEq)]
22pub struct SbtNode<T>
23where
24    T: Unsigned,
25{
26    pub l: URational<T>,
27    pub r: URational<T>,
28}
29
30#[derive(Clone, Debug, Eq, PartialEq, Default)]
31pub struct SbtPath<T>
32where
33    T: Unsigned,
34{
35    pub path: Vec<T>,
36}
37
38impl<T> From<URational<T>> for SbtNode<T>
39where
40    T: Unsigned,
41{
42    fn from(r: URational<T>) -> Self {
43        SbtPath::from(r).to_node()
44    }
45}
46
47impl<T> FromIterator<T> for SbtNode<T>
48where
49    T: Unsigned,
50{
51    fn from_iter<I>(iter: I) -> Self
52    where
53        I: IntoIterator<Item = T>,
54    {
55        let mut node = SbtNode::root();
56        for (i, count) in iter.into_iter().enumerate() {
57            if i % 2 == 0 {
58                node.down_right(count);
59            } else {
60                node.down_left(count);
61            }
62        }
63        node
64    }
65}
66
67impl<T> From<URational<T>> for SbtPath<T>
68where
69    T: Unsigned,
70{
71    fn from(r: URational<T>) -> Self {
72        assert!(!r.num.is_zero(), "rational must be positive");
73        assert!(!r.den.is_zero(), "rational must be positive");
74
75        let (mut a, mut b) = (r.num, r.den);
76        let mut path = vec![];
77        loop {
78            let x = a / b;
79            a %= b;
80            if a.is_zero() {
81                if !x.is_one() {
82                    path.push(x - T::one());
83                }
84                break;
85            }
86            path.push(x);
87            swap(&mut a, &mut b);
88        }
89        Self { path }
90    }
91}
92
93impl<T> FromIterator<T> for SbtPath<T>
94where
95    T: Unsigned,
96{
97    fn from_iter<I>(iter: I) -> Self
98    where
99        I: IntoIterator<Item = T>,
100    {
101        let mut path = SbtPath::root();
102        for (i, count) in iter.into_iter().enumerate() {
103            if i % 2 == 0 {
104                path.down_right(count);
105            } else {
106                path.down_left(count);
107            }
108        }
109        path
110    }
111}
112
113impl<T> IntoIterator for SbtPath<T>
114where
115    T: Unsigned,
116{
117    type Item = T;
118    type IntoIter = std::vec::IntoIter<T>;
119
120    fn into_iter(self) -> Self::IntoIter {
121        self.path.into_iter()
122    }
123}
124
125impl<'a, T> IntoIterator for &'a SbtPath<T>
126where
127    T: Unsigned,
128{
129    type Item = T;
130    type IntoIter = std::iter::Cloned<std::slice::Iter<'a, T>>;
131
132    fn into_iter(self) -> Self::IntoIter {
133        self.path.iter().cloned()
134    }
135}
136
137impl<T> SternBrocotTree for SbtNode<T>
138where
139    T: Unsigned,
140{
141    type T = T;
142
143    fn root() -> Self {
144        Self {
145            l: URational::new(T::zero(), T::one()),
146            r: URational::new(T::one(), T::zero()),
147        }
148    }
149
150    fn is_root(&self) -> bool {
151        self.l.num.is_zero() && self.r.den.is_zero()
152    }
153
154    fn eval(&self) -> URational<Self::T> {
155        URational::new_unchecked(self.l.num + self.r.num, self.l.den + self.r.den)
156    }
157
158    fn down_left(&mut self, count: Self::T) {
159        self.r.num += self.l.num * count;
160        self.r.den += self.l.den * count;
161    }
162
163    fn down_right(&mut self, count: Self::T) {
164        self.l.num += self.r.num * count;
165        self.l.den += self.r.den * count;
166    }
167
168    fn up(&mut self, mut count: Self::T) -> Self::T {
169        while count > T::zero() && !self.is_root() {
170            if self.l.den > self.r.den {
171                let x = count.min(self.l.num / self.r.num);
172                count -= x;
173                self.l.num -= self.r.num * x;
174                self.l.den -= self.r.den * x;
175            } else {
176                let x = count.min(self.r.den / self.l.den);
177                count -= x;
178                self.r.num -= self.l.num * x;
179                self.r.den -= self.l.den * x;
180            }
181        }
182        count
183    }
184}
185
186impl<T> SternBrocotTree for SbtPath<T>
187where
188    T: Unsigned,
189{
190    type T = T;
191
192    fn root() -> Self {
193        Self::default()
194    }
195
196    fn is_root(&self) -> bool {
197        self.path.is_empty()
198    }
199
200    fn eval(&self) -> URational<Self::T> {
201        self.to_node().eval()
202    }
203
204    fn down_left(&mut self, count: Self::T) {
205        if count.is_zero() {
206            return;
207        }
208        if self.path.len() % 2 == 0 {
209            if let Some(last) = self.path.last_mut() {
210                *last += count;
211            } else {
212                self.path.push(T::zero());
213                self.path.push(count);
214            }
215        } else {
216            self.path.push(count);
217        }
218    }
219
220    fn down_right(&mut self, count: Self::T) {
221        if count.is_zero() {
222            return;
223        }
224        if self.path.len() % 2 == 0 {
225            self.path.push(count);
226        } else {
227            *self.path.last_mut().unwrap() += count;
228        }
229    }
230
231    fn up(&mut self, mut count: Self::T) -> Self::T {
232        while let Some(last) = self.path.last_mut() {
233            let x = count.min(*last);
234            *last -= x;
235            count -= x;
236            if !last.is_zero() {
237                break;
238            }
239            self.path.pop();
240        }
241        count
242    }
243}
244
245impl<T> SbtNode<T>
246where
247    T: Unsigned,
248{
249    pub fn to_path(&self) -> SbtPath<T> {
250        self.eval().into()
251    }
252    pub fn lca<I, J>(path1: I, path2: J) -> Self
253    where
254        I: IntoIterator<Item = T>,
255        J: IntoIterator<Item = T>,
256    {
257        let mut node = SbtNode::root();
258        for (i, (count1, count2)) in path1.into_iter().zip(path2).enumerate() {
259            let count = count1.min(count2);
260            if i % 2 == 0 {
261                node.down_right(count);
262            } else {
263                node.down_left(count);
264            }
265            if count1 != count2 {
266                break;
267            }
268        }
269        node
270    }
271}
272
273impl<T> SbtPath<T>
274where
275    T: Unsigned,
276{
277    pub fn to_node(&self) -> SbtNode<T> {
278        self.path.iter().cloned().collect()
279    }
280    pub fn depth(&self) -> T {
281        self.path.iter().cloned().sum()
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Strict `x < y` by cross-multiplication, so that `1/0` behaves as +infinity.
    fn less(x: &URational<u32>, y: &URational<u32>) -> bool {
        x.num * y.den < y.num * x.den
    }

    #[test]
    fn test_sbt_path_encode_decode() {
        for a in 1u32..50 {
            for b in 1u32..50 {
                let r = URational::new(a, b);
                let path = SbtPath::from(r);
                let node = path.to_node();
                assert_eq!(node.eval(), r);
            }
        }
    }

    #[test]
    fn test_sbt_explore() {
        let mut rng = Xorshift::default();
        for _ in 0..10000 {
            let mut node = SbtNode::<u128>::root();
            let mut path = SbtPath::<u128>::root();
            for _ in 0..30 {
                match rng.random(0..3) {
                    0 => {
                        let count = rng.random(0..=100);
                        node.down_left(count);
                        path.down_left(count);
                    }
                    1 => {
                        let count = rng.random(0..=100);
                        node.down_right(count);
                        path.down_right(count);
                    }
                    _ => {
                        let count = rng.random(0..=100);
                        let r1 = path.up(count);
                        let r2 = node.up(count);
                        assert_eq!(r1, r2);
                    }
                }
                assert_eq!(node, path.to_node());
                assert_eq!(node.eval(), path.eval());
                assert_eq!(node.is_root(), path.is_root());
                assert_eq!(node.to_path(), path);
                // Previously a verbatim duplicate of the first assertion; check
                // the `From<URational>` conversion for nodes instead.
                assert_eq!(SbtNode::from(node.eval()), node);
            }
        }
    }

    #[test]
    fn test_sbt_lca() {
        let values: Vec<URational<u32>> = (1u32..16)
            .flat_map(|a| (1u32..16).map(move |b| URational::new(a, b)))
            .collect();
        let paths: Vec<SbtPath<u32>> = values.iter().map(|&v| SbtPath::from(v)).collect();
        for (x, px) in values.iter().zip(&paths) {
            for (y, py) in values.iter().zip(&paths) {
                let lca = SbtNode::lca(px, py);
                // Both values lie strictly between the LCA's bounds (its subtree)...
                assert!(less(&lca.l, x) && less(x, &lca.r));
                assert!(less(&lca.l, y) && less(y, &lca.r));
                // ...and its own value lies between them, so no deeper common
                // ancestor exists.
                let m = lca.eval();
                let (lo, hi) = if less(x, y) { (x, y) } else { (y, x) };
                assert!(!less(&m, lo) && !less(hi, &m));
            }
        }
    }

    #[test]
    fn test_sbt_depth() {
        assert_eq!(SbtPath::from(URational::new(1u32, 1)).depth(), 0);
        assert_eq!(SbtPath::from(URational::new(5u32, 1)).depth(), 4);
        assert_eq!(SbtPath::from(URational::new(2u32, 3)).depth(), 2);
        for a in 1u32..30 {
            for b in 1u32..30 {
                let path = SbtPath::from(URational::new(a, b));
                let depth = path.depth();
                // Climbing exactly `depth` steps reaches the root with nothing left...
                let mut node = path.to_node();
                assert_eq!(node.up(depth), 0);
                assert!(node.is_root());
                // ...while one step fewer stops just short of it.
                if depth > 0 {
                    let mut node = path.to_node();
                    assert_eq!(node.up(depth - 1), 0);
                    assert!(!node.is_root());
                }
            }
        }
    }
}