1use super::{ExtendedGcd, RangeBoundsExt, Signed, Unsigned};
2use std::ops::RangeInclusive;
3
4#[derive(Clone, Copy, Debug)]
6struct Linear<T>
7where
8 T: Signed,
9{
10 a: T,
11 b: T,
12}
13
14impl<T> Linear<T>
15where
16 T: Signed,
17{
18 fn new(a: T, b: T) -> Self {
19 Self { a, b }
20 }
21 fn eval(&self, x: T) -> T {
22 self.a * x + self.b
23 }
24}
25
26#[derive(Clone, Copy, Debug)]
28pub struct LinearDiophantineSolution<T>
29where
30 T: Signed,
31{
32 x: Linear<T>,
33 y: Linear<T>,
34 k_range: (T, T),
35}
36
// Narrows `$this.k_range` so that the affine function `$this.$x` (`a * k + b`) stays
// inside `$range` for every admissible `k`. An endpoint for which `is_minimum()` /
// `is_maximum()` holds is treated as unbounded and imposes no constraint.
// Expects a type parameter named `T` to be in scope at the call site.
macro_rules! with_range {
    ($this:ident, $x:ident, $range:expr) => {
        let bounds = $range.to_range_inclusive();
        let lower = *bounds.start();
        let upper = *bounds.end();
        let slope = $this.$x.a;
        let offset = $this.$x.b;
        let ascending = slope.is_positive();
        if !lower.is_minimum() {
            // Enforce `slope * k + offset >= lower`.
            if ascending {
                // k >= ceil((lower - offset) / slope)
                let bound = (lower - offset + slope - T::one()).div_euclid(slope);
                $this.k_range.0 = $this.k_range.0.max(bound);
            } else {
                // k <= floor((offset - lower) / -slope)
                let bound = (offset - lower).div_euclid(-slope);
                $this.k_range.1 = $this.k_range.1.min(bound);
            }
        }
        if !upper.is_maximum() {
            // Enforce `slope * k + offset <= upper`.
            if ascending {
                // k <= floor((upper - offset) / slope)
                let bound = (upper - offset).div_euclid(slope);
                $this.k_range.1 = $this.k_range.1.min(bound);
            } else {
                // k >= ceil((offset - upper) / -slope)
                let bound = (offset - upper - slope - T::one()).div_euclid(-slope);
                $this.k_range.0 = $this.k_range.0.max(bound);
            }
        }
    };
}
67
68impl<T> LinearDiophantineSolution<T>
69where
70 T: Signed,
71{
72 pub fn eval(&self, k: T) -> (T, T) {
73 (self.x.eval(k), self.y.eval(k))
74 }
75 pub fn with_x_range<R>(mut self, range: R) -> Self
76 where
77 R: RangeBoundsExt<T>,
78 {
79 with_range!(self, x, range);
80 self
81 }
82 pub fn with_y_range<R>(mut self, range: R) -> Self
83 where
84 R: RangeBoundsExt<T>,
85 {
86 with_range!(self, y, range);
87 self
88 }
89 pub fn with_x_order(mut self) -> Self {
90 if self.x.a.is_negative() {
91 self.x.a = -self.x.a;
92 self.y.a = -self.y.a;
93 self.k_range = (
94 if self.k_range.1 == T::maximum() {
95 T::minimum()
96 } else {
97 -self.k_range.1
98 },
99 if self.k_range.0 == T::minimum() {
100 T::maximum()
101 } else {
102 -self.k_range.0
103 },
104 );
105 }
106 self
107 }
108 pub fn with_y_order(mut self) -> Self {
109 if self.y.a.is_negative() {
110 self.x.a = -self.x.a;
111 self.y.a = -self.y.a;
112 self.k_range = (
113 if self.k_range.1 == T::maximum() {
114 T::minimum()
115 } else {
116 -self.k_range.1
117 },
118 if self.k_range.0 == T::minimum() {
119 T::maximum()
120 } else {
121 -self.k_range.0
122 },
123 );
124 }
125 self
126 }
127 pub fn k_range(&self) -> RangeInclusive<T> {
128 self.k_range.0..=self.k_range.1
129 }
130}
131
132pub fn solve_linear_diophantine<T>(a: T, b: T, c: T) -> Option<LinearDiophantineSolution<T>>
134where
135 T: Signed,
136{
137 assert!(!a.is_zero(), "a must be non-zero");
138 assert!(!b.is_zero(), "b must be non-zero");
139 let ExtendedGcd { g, x: x0, y: y0 } = a.extgcd(b);
140 let g = g.signed();
141 let a = a / g;
142 let b = b / g;
143 if c.is_zero() {
144 return Some(LinearDiophantineSolution {
145 x: Linear::new(b, T::zero()),
146 y: Linear::new(-a, T::zero()),
147 k_range: (T::minimum(), T::maximum()),
148 });
149 }
150 if !(c % g).is_zero() {
151 return None;
152 }
153 let c = c / g;
154 let x = Linear::new(b, x0 * c);
155 let y = Linear::new(-a, y0 * c);
156 Some(LinearDiophantineSolution {
157 x,
158 y,
159 k_range: (T::minimum(), T::maximum()),
160 })
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{num::Zero, rand, tools::Xorshift};

    #[test]
    fn test_solve_linear_diophantine() {
        let mut rng = Xorshift::default();
        for t in [2, 10, 100, 1_000_000i64] {
            rand!(rng, abc: [(-t..=t, -t..=t, -t..=t); 100]);
            for (a, b, c) in abc {
                if a.is_zero() || b.is_zero() {
                    continue;
                }
                if let Some(sol) = solve_linear_diophantine(a, b, c) {
                    // Every parameter value yields a solution.
                    for k in -10..=10 {
                        let (x, y) = sol.eval(k);
                        assert_eq!(a * x + b * y, c);
                    }
                    // Restricting x to [l, r] keeps exactly the k whose x lies in [l, r].
                    rand!(rng, lr: [(-100i64..=100, -100i64..=100); 100]);
                    for (l, r) in lr {
                        let mut sol = sol.with_x_range(l..=r);
                        let sorted = rng.gen_bool(0.5);
                        if sorted {
                            sol = sol.with_x_order();
                        }
                        for k in sol.k_range() {
                            let (x, y) = sol.eval(k);
                            assert_eq!(a * x + b * y, c);
                            assert!((l..=r).contains(&x));
                        }
                        for k in -100..=100 {
                            let (x, y) = sol.eval(k);
                            assert_eq!(a * x + b * y, c);
                            assert_eq!((l..=r).contains(&x), sol.k_range().contains(&k));
                        }
                        if sorted {
                            assert!(sol.k_range().map(|k| sol.eval(k).0).is_sorted());
                        }
                    }
                    // Same checks for restricting y.
                    rand!(rng, lr: [(-100i64..=100, -100i64..=100); 100]);
                    for (l, r) in lr {
                        let mut sol = sol.with_y_range(l..=r);
                        let sorted = rng.gen_bool(0.5);
                        if sorted {
                            sol = sol.with_y_order();
                        }
                        for k in sol.k_range() {
                            let (x, y) = sol.eval(k);
                            assert_eq!(a * x + b * y, c);
                            assert!((l..=r).contains(&y));
                        }
                        for k in -100..=100 {
                            let (x, y) = sol.eval(k);
                            assert_eq!(a * x + b * y, c);
                            assert_eq!((l..=r).contains(&y), sol.k_range().contains(&k));
                        }
                        if sorted {
                            assert!(sol.k_range().map(|k| sol.eval(k).1).is_sorted());
                        }
                    }
                } else {
                    // No solution: gcd(a, b) must not divide c.
                    let ExtendedGcd { g, .. } = a.extgcd(b);
                    assert!(!(c % g.signed()).is_zero());
                    for x in -100..=100 {
                        let y = (c - a * x).div_euclid(b);
                        assert_ne!(a * x + b * y, c);
                    }
                }
            }
        }
    }
}