competitive/math/
gcd.rs

/// Plain Euclidean gcd: repeatedly replaces `(x, y)` with `(y, x % y)`
/// until the second component is zero.
///
/// `gcd_loop(0, b) == b`, `gcd_loop(a, 0) == a` and `gcd_loop(0, 0) == 0`.
pub fn gcd_loop(a: u64, b: u64) -> u64 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let rem = x % y;
        x = y;
        y = rem;
    }
    x
}
8
9#[codesnip::entry]
/// Binary gcd (Stein's algorithm).
///
/// Uses only shifts and subtractions. The common power of two is factored out
/// first. The odd parts are then repeatedly replaced by `larger - smaller` with
/// its trailing zeros stripped, until the two values meet.
/// `gcd(0, b) == b`, `gcd(a, 0) == a` and `gcd(0, 0) == 0`.
pub fn gcd(a: u64, b: u64) -> u64 {
    if a == 0 || b == 0 {
        // One side is zero, so OR-ing yields the other side unchanged.
        return a | b;
    }
    // For nonzero inputs, tz(a | b) == min(tz(a), tz(b)): the shared power of two.
    let common_twos = (a | b).trailing_zeros();
    let mut x = a >> a.trailing_zeros();
    let mut y = b >> b.trailing_zeros();
    // x and y are odd here, so their difference is even (and nonzero until they meet).
    loop {
        if x == y {
            return x << common_twos;
        }
        let (big, small) = if x > y { (x, y) } else { (y, x) };
        let diff = big - small;
        x = diff >> diff.trailing_zeros();
        y = small;
    }
}
32
33#[codesnip::entry(include("gcd"))]
34pub fn lcm(a: u64, b: u64) -> u64 {
35    a / gcd(a, b) * b
36}
37
/// Extended Euclidean algorithm, written recursively.
///
/// Returns `(g, x, y)` such that `a * x + b * y == g`, where `g` is a gcd of
/// `a` and `b`. The sign of `g` is whatever the Euclidean steps produce; it is
/// not normalized (e.g. `extgcd_recurse(-4, 0) == (-4, 1, 0)`).
pub fn extgcd_recurse(a: i64, b: i64) -> (i64, i64, i64) {
    match b {
        0 => (a, 1, 0),
        _ => {
            // Solve for (b, a % b), then substitute a % b == a - (a / b) * b.
            let (g, s, t) = extgcd_recurse(b, a % b);
            (g, t, s - (a / b) * t)
        }
    }
}
48
49#[codesnip::entry]
/// Iterative extended Euclidean algorithm.
///
/// Returns `(g, x, y)` such that `a * x + b * y == g`, where `g` is a gcd of
/// `a` and `b` (its sign is not normalized for negative inputs).
/// `extgcd(0, 0) == (0, 0, 1)`.
pub fn extgcd(a: i64, b: i64) -> (i64, i64, i64) {
    // Two rows (r, x, y), each satisfying r == x * a + y * b.
    let (mut prev_r, mut prev_x, mut prev_y) = (b, 0i64, 1i64);
    let (mut cur_r, mut cur_x, mut cur_y) = (a, 1i64, 0i64);
    while cur_r != 0 {
        let q = prev_r / cur_r;
        // next row = prev row - q * cur row; this preserves the invariant.
        let next_x = prev_x - q * cur_x;
        let next_y = prev_y - q * cur_y;
        let next_r = prev_r - q * cur_r;
        prev_r = cur_r;
        prev_x = cur_x;
        prev_y = cur_y;
        cur_r = next_r;
        cur_x = next_x;
        cur_y = next_y;
    }
    (prev_r, prev_x, prev_y)
}
63
/// Extended gcd using the binary (Stein) algorithm.
///
/// Returns `(g, x, y)` where `g = gcd(a, b)` and `a * x + b * y == g`.
///
/// If either argument is `0`, the answer is returned directly: for example,
/// `extgcd_binary(a, 0) == (a, 1, 0)` and `extgcd_binary(0, b) == (b, 0, 1)`.
/// Otherwise, both arguments must be positive. With a negative input, the
/// subtraction loop below may never reach `a == b` (e.g. `(-3, 5)` spins
/// forever), so this is checked with `debug_assert!`.
pub fn extgcd_binary(mut a: i64, mut b: i64) -> (i64, i64, i64) {
    if b == 0 {
        // a * 1 + 0 * 0 == a
        return (a, 1, 0);
    } else if a == 0 {
        // 0 * 0 + b * 1 == b: the coefficient on `b` must be the 1.
        return (b, 0, 1);
    }
    debug_assert!(
        a > 0 && b > 0,
        "extgcd_binary: nonzero inputs must be positive"
    );
    // Remove the shared power of two; it is shifted back into the gcd at the end.
    let k = (a | b).trailing_zeros();
    a >>= k;
    b >>= k;
    // Reduced inputs; after removing the shared power of two they are not both even.
    let (c, d) = (a, b);
    // Invariants: a == u * c + v * d and b == s * c + t * d.
    let (mut u, mut v, mut s, mut t) = (1, 0, 0, 1);
    // Make `a` odd. Adding (d, -c) to a coefficient pair leaves the combination
    // unchanged and makes both coefficients even, so they can be halved exactly.
    while a & 1 == 0 {
        a /= 2;
        if u & 1 == 1 || v & 1 == 1 {
            u += d;
            v -= c;
        }
        u /= 2;
        v /= 2;
    }
    // Binary gcd on (a, b), carrying the coefficients along.
    while a != b {
        if b & 1 == 0 {
            b /= 2;
            if s & 1 == 1 || t & 1 == 1 {
                s += d;
                t -= c;
            }
            s /= 2;
            t /= 2;
        } else if b < a {
            std::mem::swap(&mut a, &mut b);
            std::mem::swap(&mut u, &mut s);
            std::mem::swap(&mut v, &mut t);
        } else {
            // Both odd with b > a: the difference is even and positive.
            b -= a;
            s -= u;
            t -= v;
        }
    }
    // Here a == b == gcd(c, d); scale back by the shared power of two.
    (a << k, s, t)
}
105
106pub fn modinv_recurse(a: u64, m: u64) -> u64 {
107    (extgcd_recurse(a as i64, m as i64).1 % m as i64 + m as i64) as u64 % m
108}
109
110#[codesnip::entry(include("extgcd"))]
/// Inverse of `a` modulo `m`, computed with an iterative extended Euclid
/// that only tracks the coefficient of `a`.
///
/// Returns `x` in `[0, m)` with `a * x ≡ 1 (mod m)` when `gcd(a, m) == 1`.
/// If `gcd(a, m) != 1` there is no inverse; the result then satisfies
/// `a * x ≡ gcd(a, m) (mod m)` instead.
///
/// `a` and `m` are reinterpreted as `i64`, so both should be below `2^63`.
/// `m` must be at least 1 (the final reduction panics on `m == 0`).
pub fn modinv(a: u64, m: u64) -> u64 {
    let (mut a, mut b) = (a as i64, m as i64);
    // Invariants (mod m): a ≡ u * a0 and b ≡ x * a0, where a0 is the input `a`.
    let (mut u, mut x) = (1i64, 0i64);
    while a != 0 {
        let k = b / a;
        x -= k * u;
        b -= k * a;
        std::mem::swap(&mut x, &mut u);
        std::mem::swap(&mut b, &mut a);
    }
    // Map the coefficient into [0, m). A bare `x + m` adjustment is not enough
    // for m == 1, where x can end up as 1 instead of the only residue 0.
    x.rem_euclid(m as i64) as u64
}
123
/// Inverse of `a` modulo `p` via binary extended gcd.
///
/// Requirements: `0 < a < p`, `gcd(a, p) == 1`, and `p` odd. Primality is not
/// needed. Oddness is what makes halving modulo `p` possible (`(x + p) / 2`
/// for odd `x`). The loop ends when both values reach `gcd(a, p) == 1`.
/// Coefficients stay in `[0, p)`, so `p <= 2^63` keeps `x + p` from overflowing.
pub fn modinv_extgcd_binary(mut a: u64, p: u64) -> u64 {
    // Halve a residue in [0, p) modulo the odd modulus p.
    let halve = |x: u64| if x & 1 == 1 { (x + p) / 2 } else { x / 2 };
    // Invariants (mod p): a ≡ coef_a * a0 and b ≡ coef_b * a0.
    let shift = a.trailing_zeros();
    a >>= shift;
    let mut coef_a = (0..shift).fold(1u64, |c, _| halve(c));
    let (mut b, mut coef_b) = (p, 0u64);
    while a != b {
        if a > b {
            std::mem::swap(&mut a, &mut b);
            std::mem::swap(&mut coef_a, &mut coef_b);
        }
        // a and b are both odd, so b - a is even and positive.
        b -= a;
        coef_b = if coef_b < coef_a {
            coef_b + p - coef_a
        } else {
            coef_b - coef_a
        };
        let tz = b.trailing_zeros();
        b >>= tz;
        coef_b = (0..tz).fold(coef_b, |c, _| halve(c));
    }
    coef_b
}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Number of random cases drawn per test.
    const Q: usize = 100_000;
    /// Bound on the magnitude of random `i64` inputs (about 1e18).
    const A: i64 = 1_000_000_007_000_000_007;

    /// `a * x + b * y`, evaluated in `i128` so it cannot overflow.
    fn linear_combination(a: i64, b: i64, x: i64, y: i64) -> i128 {
        i128::from(a) * i128::from(x) + i128::from(b) * i128::from(y)
    }

    #[test]
    fn test_gcd() {
        let mut rng = Xorshift::new();
        for (a, b) in rng.random_iter((0.., 0..)).take(Q) {
            assert_eq!(gcd(a, b), gcd_loop(a, b));
        }
        // Zero arguments are practically never sampled, so check them explicitly.
        for &(a, b) in [(0, 0), (0, 100)].iter() {
            assert_eq!(gcd(a, b), gcd_loop(a, b));
        }
    }

    #[test]
    fn test_extgcd() {
        let mut rng = Xorshift::new();
        for (a, b) in rng.random_iter((-A..=A, -A..=A)).take(Q) {
            let (g, x, y) = extgcd(a, b);
            assert_eq!(linear_combination(a, b, x, y), i128::from(g));
        }
    }

    #[test]
    fn test_extgcd_binary() {
        let mut rng = Xorshift::new();
        for (a, b) in rng.random_iter((0..=A, 0..=A)).take(Q) {
            let (g, x, y) = extgcd_binary(a, b);
            assert_eq!(linear_combination(a, b, x, y), i128::from(g));
        }
    }

    #[test]
    fn test_modinv() {
        let mut rng = Xorshift::new();
        for _ in 0..Q {
            let modulus = rng.random(1..=A as u64);
            let value = rng.random(1..modulus);
            // Divide out the gcd so that an inverse exists.
            let common = gcd(value, modulus);
            let (a, m) = (value / common, modulus / common);
            let inv = modinv(a, m);
            assert!(inv < m);
            assert_eq!(u128::from(a) * u128::from(inv) % u128::from(m), 1);
        }
    }

    #[test]
    fn test_modinv_extgcd_binary() {
        let mut rng = Xorshift::new();
        for _ in 0..Q {
            let raw = rng.random(1..=A as u64);
            // The binary variant needs an odd modulus.
            let modulus = raw >> raw.trailing_zeros();
            let value = rng.random(1..modulus);
            let common = gcd(value, modulus);
            let (a, m) = (value / common, modulus / common);
            let inv = modinv_extgcd_binary(a, m);
            assert!(inv < m);
            assert_eq!(u128::from(a) * u128::from(inv) % u128::from(m), 1);
        }
    }
}
222}