competitive/math/
linear_diophantine.rs

1use super::{ExtendedGcd, RangeBoundsExt, Signed, Unsigned};
2use std::ops::RangeInclusive;
3
4/// ax + b
5#[derive(Clone, Copy, Debug)]
6struct Linear<T>
7where
8    T: Signed,
9{
10    a: T,
11    b: T,
12}
13
14impl<T> Linear<T>
15where
16    T: Signed,
17{
18    fn new(a: T, b: T) -> Self {
19        Self { a, b }
20    }
21    fn eval(&self, x: T) -> T {
22        self.a * x + self.b
23    }
24}
25
26/// Solution of ax + by = c
27#[derive(Clone, Copy, Debug)]
28pub struct LinearDiophantineSolution<T>
29where
30    T: Signed,
31{
32    x: Linear<T>,
33    y: Linear<T>,
34    k_range: (T, T),
35}
36
/// Narrows `$this.k_range` to the parameters `k` with `l <= a * k + b <= r`, where
/// `a`/`b` are the slope/constant of `$this.$x` and `l..=r` is
/// `$range.to_range_inclusive()`.
///
/// The new bounds are intersected with the existing ones, so repeated uses accumulate
/// constraints. A lower bound equal to `T::minimum()` or an upper bound equal to
/// `T::maximum()` is treated as absent. The slope must be non-zero (a zero slope would
/// divide by zero in the negative-slope branch).
///
/// Every bound is a single floor division by the positive divisor `|a|`. Ceilings use
/// `ceil(p / q) = -floor(-p / q)` instead of `floor((p + q - 1) / q)`, so no extra
/// `q - 1` is added to a numerator that may already be close to `T::maximum()`.
macro_rules! with_range {
    ($this:ident, $x:ident, $range:expr) => {
        let range = $range.to_range_inclusive();
        let l = *range.start();
        let r = *range.end();
        let a = $this.$x.a;
        let b = $this.$x.b;
        if a.is_positive() {
            // a > 0:  ceil((l - b) / a) <= k <= floor((r - b) / a)
            if !l.is_minimum() {
                // ceil((l - b) / a) == -floor((b - l) / a)
                let t = -((b - l).div_euclid(a));
                $this.k_range.0 = $this.k_range.0.max(t);
            }
            if !r.is_maximum() {
                let t = (r - b).div_euclid(a);
                $this.k_range.1 = $this.k_range.1.min(t);
            }
        } else {
            // a < 0: dividing by `a` flips both inequalities, so divide by `-a > 0`:
            //   ceil((b - r) / -a) <= k <= floor((b - l) / -a)
            if !l.is_minimum() {
                let t = (b - l).div_euclid(-a);
                $this.k_range.1 = $this.k_range.1.min(t);
            }
            if !r.is_maximum() {
                // ceil((b - r) / -a) == -floor((r - b) / -a)
                let t = -((r - b).div_euclid(-a));
                $this.k_range.0 = $this.k_range.0.max(t);
            }
        }
    };
}
67
68impl<T> LinearDiophantineSolution<T>
69where
70    T: Signed,
71{
72    pub fn eval(&self, k: T) -> (T, T) {
73        (self.x.eval(k), self.y.eval(k))
74    }
75    pub fn with_x_range<R>(mut self, range: R) -> Self
76    where
77        R: RangeBoundsExt<T>,
78    {
79        with_range!(self, x, range);
80        self
81    }
82    pub fn with_y_range<R>(mut self, range: R) -> Self
83    where
84        R: RangeBoundsExt<T>,
85    {
86        with_range!(self, y, range);
87        self
88    }
89    pub fn with_x_order(mut self) -> Self {
90        if self.x.a.is_negative() {
91            self.x.a = -self.x.a;
92            self.y.a = -self.y.a;
93            self.k_range = (
94                if self.k_range.1 == T::maximum() {
95                    T::minimum()
96                } else {
97                    -self.k_range.1
98                },
99                if self.k_range.0 == T::minimum() {
100                    T::maximum()
101                } else {
102                    -self.k_range.0
103                },
104            );
105        }
106        self
107    }
108    pub fn with_y_order(mut self) -> Self {
109        if self.y.a.is_negative() {
110            self.x.a = -self.x.a;
111            self.y.a = -self.y.a;
112            self.k_range = (
113                if self.k_range.1 == T::maximum() {
114                    T::minimum()
115                } else {
116                    -self.k_range.1
117                },
118                if self.k_range.0 == T::minimum() {
119                    T::maximum()
120                } else {
121                    -self.k_range.0
122                },
123            );
124        }
125        self
126    }
127    pub fn k_range(&self) -> RangeInclusive<T> {
128        self.k_range.0..=self.k_range.1
129    }
130}
131
132/// Solve ax + by = c
133pub fn solve_linear_diophantine<T>(a: T, b: T, c: T) -> Option<LinearDiophantineSolution<T>>
134where
135    T: Signed,
136{
137    assert!(!a.is_zero(), "a must be non-zero");
138    assert!(!b.is_zero(), "b must be non-zero");
139    let ExtendedGcd { g, x: x0, y: y0 } = a.extgcd(b);
140    let g = g.signed();
141    let a = a / g;
142    let b = b / g;
143    if c.is_zero() {
144        return Some(LinearDiophantineSolution {
145            x: Linear::new(b, T::zero()),
146            y: Linear::new(-a, T::zero()),
147            k_range: (T::minimum(), T::maximum()),
148        });
149    }
150    if !(c % g).is_zero() {
151        return None;
152    }
153    let c = c / g;
154    let x = Linear::new(b, x0 * c);
155    let y = Linear::new(-a, y0 * c);
156    Some(LinearDiophantineSolution {
157        x,
158        y,
159        k_range: (T::minimum(), T::maximum()),
160    })
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{num::Zero, rand, tools::Xorshift};

    #[test]
    fn test_solve_linear_diophantine() {
        let mut rng = Xorshift::default();
        for bound in [2, 10, 100, 1_000_000i64] {
            rand!(rng, abc: [(-bound..=bound, -bound..=bound, -bound..=bound); 100]);
            for (a, b, c) in abc {
                if a.is_zero() || b.is_zero() {
                    continue;
                }
                let Some(sol) = solve_linear_diophantine(a, b, c) else {
                    // Unsolvable: gcd(a, b) must not divide c, and no nearby x admits an
                    // integer y.
                    let ExtendedGcd { g, .. } = a.extgcd(b);
                    assert!(!(c % g.signed()).is_zero());
                    for x in -100..=100 {
                        let y = (c - a * x).div_euclid(b);
                        assert_ne!(a * x + b * y, c);
                    }
                    continue;
                };

                // Every parameter value must yield a solution.
                for k in -10..=10 {
                    let (x, y) = sol.eval(k);
                    assert_eq!(a * x + b * y, c);
                }

                // First pass restricts x, second pass restricts y.
                for restrict_y in [false, true] {
                    rand!(rng, lr: [(-100i64..=100, -100i64..=100); 100]);
                    for (l, r) in lr {
                        let limited = if restrict_y {
                            sol.with_y_range(l..=r)
                        } else {
                            sol.with_x_range(l..=r)
                        };
                        let sorted = rng.gen_bool(0.5);
                        let limited = match (sorted, restrict_y) {
                            (false, _) => limited,
                            (true, false) => limited.with_x_order(),
                            (true, true) => limited.with_y_order(),
                        };
                        // The coordinate that was restricted in this pass.
                        let coord = |k: i64| {
                            let (x, y) = limited.eval(k);
                            if restrict_y {
                                y
                            } else {
                                x
                            }
                        };

                        for k in limited.k_range() {
                            let (x, y) = limited.eval(k);
                            assert_eq!(a * x + b * y, c);
                            assert!((l..=r).contains(&coord(k)));
                        }
                        for k in -100..=100 {
                            let (x, y) = limited.eval(k);
                            assert_eq!(a * x + b * y, c);
                            assert_eq!(
                                (l..=r).contains(&coord(k)),
                                limited.k_range().contains(&k)
                            );
                        }
                        if sorted {
                            assert!(limited.k_range().map(coord).is_sorted());
                        }
                    }
                }
            }
        }
    }
}