competitive/algorithm/
stern_brocot_tree.rs1use super::{URational, Unsigned};
2use std::mem::swap;
3
4pub trait SternBrocotTree: From<URational<Self::T>> + FromIterator<Self::T> {
5 type T: Unsigned;
6
7 fn root() -> Self;
8
9 fn is_root(&self) -> bool;
10
11 fn eval(&self) -> URational<Self::T>;
12
13 fn down_left(&mut self, count: Self::T);
14
15 fn down_right(&mut self, count: Self::T);
16
17 fn up(&mut self, count: Self::T) -> Self::T;
19}
20
21#[derive(Clone, Copy, Debug, Eq, PartialEq)]
22pub struct SbtNode<T>
23where
24 T: Unsigned,
25{
26 pub l: URational<T>,
27 pub r: URational<T>,
28}
29
30#[derive(Clone, Debug, Eq, PartialEq, Default)]
31pub struct SbtPath<T>
32where
33 T: Unsigned,
34{
35 pub path: Vec<T>,
36}
37
38impl<T> From<URational<T>> for SbtNode<T>
39where
40 T: Unsigned,
41{
42 fn from(r: URational<T>) -> Self {
43 SbtPath::from(r).to_node()
44 }
45}
46
47impl<T> FromIterator<T> for SbtNode<T>
48where
49 T: Unsigned,
50{
51 fn from_iter<I>(iter: I) -> Self
52 where
53 I: IntoIterator<Item = T>,
54 {
55 let mut node = SbtNode::root();
56 for (i, count) in iter.into_iter().enumerate() {
57 if i % 2 == 0 {
58 node.down_right(count);
59 } else {
60 node.down_left(count);
61 }
62 }
63 node
64 }
65}
66
67impl<T> From<URational<T>> for SbtPath<T>
68where
69 T: Unsigned,
70{
71 fn from(r: URational<T>) -> Self {
72 assert!(!r.num.is_zero(), "rational must be positive");
73 assert!(!r.den.is_zero(), "rational must be positive");
74
75 let (mut a, mut b) = (r.num, r.den);
76 let mut path = vec![];
77 loop {
78 let x = a / b;
79 a %= b;
80 if a.is_zero() {
81 if !x.is_one() {
82 path.push(x - T::one());
83 }
84 break;
85 }
86 path.push(x);
87 swap(&mut a, &mut b);
88 }
89 Self { path }
90 }
91}
92
93impl<T> FromIterator<T> for SbtPath<T>
94where
95 T: Unsigned,
96{
97 fn from_iter<I>(iter: I) -> Self
98 where
99 I: IntoIterator<Item = T>,
100 {
101 let mut path = SbtPath::root();
102 for (i, count) in iter.into_iter().enumerate() {
103 if i % 2 == 0 {
104 path.down_right(count);
105 } else {
106 path.down_left(count);
107 }
108 }
109 path
110 }
111}
112
113impl<T> IntoIterator for SbtPath<T>
114where
115 T: Unsigned,
116{
117 type Item = T;
118 type IntoIter = std::vec::IntoIter<T>;
119
120 fn into_iter(self) -> Self::IntoIter {
121 self.path.into_iter()
122 }
123}
124
125impl<'a, T> IntoIterator for &'a SbtPath<T>
126where
127 T: Unsigned,
128{
129 type Item = T;
130 type IntoIter = std::iter::Cloned<std::slice::Iter<'a, T>>;
131
132 fn into_iter(self) -> Self::IntoIter {
133 self.path.iter().cloned()
134 }
135}
136
137impl<T> SternBrocotTree for SbtNode<T>
138where
139 T: Unsigned,
140{
141 type T = T;
142
143 fn root() -> Self {
144 Self {
145 l: URational::new(T::zero(), T::one()),
146 r: URational::new(T::one(), T::zero()),
147 }
148 }
149
150 fn is_root(&self) -> bool {
151 self.l.num.is_zero() && self.r.den.is_zero()
152 }
153
154 fn eval(&self) -> URational<Self::T> {
155 URational::new_unchecked(self.l.num + self.r.num, self.l.den + self.r.den)
156 }
157
158 fn down_left(&mut self, count: Self::T) {
159 self.r.num += self.l.num * count;
160 self.r.den += self.l.den * count;
161 }
162
163 fn down_right(&mut self, count: Self::T) {
164 self.l.num += self.r.num * count;
165 self.l.den += self.r.den * count;
166 }
167
168 fn up(&mut self, mut count: Self::T) -> Self::T {
169 while count > T::zero() && !self.is_root() {
170 if self.l.den > self.r.den {
171 let x = count.min(self.l.num / self.r.num);
172 count -= x;
173 self.l.num -= self.r.num * x;
174 self.l.den -= self.r.den * x;
175 } else {
176 let x = count.min(self.r.den / self.l.den);
177 count -= x;
178 self.r.num -= self.l.num * x;
179 self.r.den -= self.l.den * x;
180 }
181 }
182 count
183 }
184}
185
186impl<T> SternBrocotTree for SbtPath<T>
187where
188 T: Unsigned,
189{
190 type T = T;
191
192 fn root() -> Self {
193 Self::default()
194 }
195
196 fn is_root(&self) -> bool {
197 self.path.is_empty()
198 }
199
200 fn eval(&self) -> URational<Self::T> {
201 self.to_node().eval()
202 }
203
204 fn down_left(&mut self, count: Self::T) {
205 if count.is_zero() {
206 return;
207 }
208 if self.path.len().is_multiple_of(2) {
209 if let Some(last) = self.path.last_mut() {
210 *last += count;
211 } else {
212 self.path.push(T::zero());
213 self.path.push(count);
214 }
215 } else {
216 self.path.push(count);
217 }
218 }
219
220 fn down_right(&mut self, count: Self::T) {
221 if count.is_zero() {
222 return;
223 }
224 if self.path.len().is_multiple_of(2) {
225 self.path.push(count);
226 } else {
227 *self.path.last_mut().unwrap() += count;
228 }
229 }
230
231 fn up(&mut self, mut count: Self::T) -> Self::T {
232 while let Some(last) = self.path.last_mut() {
233 let x = count.min(*last);
234 *last -= x;
235 count -= x;
236 if !last.is_zero() {
237 break;
238 }
239 self.path.pop();
240 }
241 count
242 }
243}
244
245impl<T> SbtNode<T>
246where
247 T: Unsigned,
248{
249 pub fn to_path(&self) -> SbtPath<T> {
250 self.eval().into()
251 }
252 pub fn lca<I, J>(path1: I, path2: J) -> Self
253 where
254 I: IntoIterator<Item = T>,
255 J: IntoIterator<Item = T>,
256 {
257 let mut node = SbtNode::root();
258 for (i, (count1, count2)) in path1.into_iter().zip(path2).enumerate() {
259 let count = count1.min(count2);
260 if i % 2 == 0 {
261 node.down_right(count);
262 } else {
263 node.down_left(count);
264 }
265 if count1 != count2 {
266 break;
267 }
268 }
269 node
270 }
271}
272
273impl<T> SbtPath<T>
274where
275 T: Unsigned,
276{
277 pub fn to_node(&self) -> SbtNode<T> {
278 self.path.iter().cloned().collect()
279 }
280 pub fn depth(&self) -> T {
281 self.path.iter().cloned().sum()
282 }
283}
284
285pub fn rational_binary_search<T>(mut f: impl FnMut(&URational<T>) -> bool, n: T) -> SbtNode<T>
286where
287 T: Unsigned,
288{
289 let mut node = SbtNode::root();
290 let lb = f(&node.l);
291 let rb = f(&node.r);
292 assert_ne!(lb, rb, "f(0/1) and f(1/0) must be different");
293 let two = T::one() + T::one();
294 while node.l.num + node.r.num <= n && node.l.den + node.r.den <= n {
295 {
296 let mut k = T::one();
297 loop {
298 let old = node.l;
299 node.down_right(k);
300 if node.l.num > n || node.l.den > n || f(&node.l) != lb {
301 node.l = old;
302 break;
303 }
304 k *= two;
305 }
306 while k > T::zero() {
307 let old = node.l;
308 node.down_right(k);
309 if node.l.num > n || node.l.den > n || f(&node.l) != lb {
310 node.l = old;
311 }
312 k /= two;
313 }
314 }
315 {
316 let mut k = T::one();
317 loop {
318 let old = node.r;
319 node.down_left(k);
320 if node.r.num > n || node.r.den > n || f(&node.r) != rb {
321 node.r = old;
322 break;
323 }
324 k *= two;
325 }
326 while k > T::zero() {
327 let old = node.r;
328 node.down_left(k);
329 if node.r.num > n || node.r.den > n || f(&node.r) != rb {
330 node.r = old;
331 }
332 k /= two;
333 }
334 }
335 }
336 node
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    #[test]
    fn test_sbt_path_encode_decode() {
        for a in 1u32..50 {
            for b in 1u32..50 {
                let r = URational::new(a, b);
                let path = SbtPath::from(r);
                let node = path.to_node();
                assert_eq!(node.eval(), r);
            }
        }
    }

    #[test]
    fn test_sbt_explore() {
        let mut rng = Xorshift::default();
        for _ in 0..10000 {
            let mut node = SbtNode::<u128>::root();
            let mut path = SbtPath::<u128>::root();
            for _ in 0..30 {
                match rng.random(0..3) {
                    0 => {
                        let count = rng.random(0..=100);
                        node.down_left(count);
                        path.down_left(count);
                    }
                    1 => {
                        let count = rng.random(0..=100);
                        node.down_right(count);
                        path.down_right(count);
                    }
                    _ => {
                        let count = rng.random(0..=100);
                        let r1 = path.up(count);
                        let r2 = node.up(count);
                        assert_eq!(r1, r2);
                    }
                }
                assert_eq!(node, path.to_node());
                assert_eq!(node.eval(), path.eval());
                assert_eq!(node.is_root(), path.is_root());
                assert_eq!(node.to_path(), path);
            }
        }
    }

    #[test]
    fn test_sbt_depth() {
        for a in 1u32..50 {
            // a/1 is reached by a-1 right moves and 1/a by a-1 left moves.
            assert_eq!(SbtPath::from(URational::new(a, 1)).depth(), a - 1);
            assert_eq!(SbtPath::from(URational::new(1, a)).depth(), a - 1);
        }
        for a in 1u32..30 {
            for b in 1u32..30 {
                // Depth must equal the number of single steps needed to reach the root.
                let path = SbtPath::from(URational::new(a, b));
                let mut node = path.to_node();
                let mut steps = 0u32;
                while !node.is_root() {
                    node.up(1);
                    steps += 1;
                }
                assert_eq!(path.depth(), steps);
            }
        }
    }

    #[test]
    fn test_sbt_lca() {
        for a in 1u32..13 {
            for b in 1u32..13 {
                let x = URational::new(a, b);
                let px = SbtPath::from(x);
                // A node is its own lowest common ancestor.
                assert_eq!(SbtNode::lca(&px, &px), px.to_node());
                for c in 1u32..13 {
                    for d in 1u32..13 {
                        let y = URational::new(c, d);
                        let py = SbtPath::from(y);
                        let lca = SbtNode::lca(&px, &py);
                        let m = lca.eval();
                        // Both values lie in the subtree of the LCA...
                        assert!(lca.l < x && x < lca.r);
                        assert!(lca.l < y && y < lca.r);
                        // ...but not both inside the same child subtree.
                        let (lo, hi) = if x <= y { (x, y) } else { (y, x) };
                        assert!(lo <= m && m <= hi);
                    }
                }
            }
        }
    }

    #[test]
    fn test_rational_binary_search() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.rand(100) + 1;
            let target = URational::new(rng.rand(1_000_000_000), rng.rand(1_000_000_000) + 1);
            let node = rational_binary_search(|candidate| &target < candidate, n);

            assert!(target >= node.l);
            assert!(target < node.r);
            assert!(node.l.num <= n && node.l.den <= n);
            assert!(node.r.num <= n && node.r.den <= n);

            let candidates: Vec<_> = (0..=n)
                .flat_map(|a| (1..=n).map(move |b| URational::new(a, b)))
                .collect();

            let expected_left = candidates
                .iter()
                .copied()
                .filter(|q| q <= &target)
                .max()
                .unwrap_or_else(|| URational::new_unchecked(0, 1));
            assert_eq!(node.l, expected_left);

            let expected_right = candidates
                .iter()
                .copied()
                .filter(|q| &target < q)
                .min()
                .unwrap_or_else(|| URational::new_unchecked(1, 0));
            assert_eq!(node.r, expected_right);
        }
    }
}