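Function `modinv`

Source:
pub fn modinv(a: u64, m: u64) -> u64

Examples found in repository:
crates/competitive/src/math/discrete_logarithm.rs (line 26) — the call is in `solve_linear_congruence`; the excerpt below covers source lines 20–479 of that file.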
20fn solve_linear_congruence(a: u64, b: u64, m: u64) -> Option<(u64, u64)> {
21    let g = gcd(a, m);
22    if b % g != 0 {
23        return None;
24    }
25    let (a, b, m) = (a / g, b / g, m / g);
26    Some(((b as u128 * modinv(a, m) as u128 % m as u128) as _, m))
27}
28
29fn solve_linear_congruences<I>(abm: I) -> Option<(u64, u64)>
30where
31    I: IntoIterator<Item = (u64, u64, u64)>,
32{
33    let mut x = 0u64;
34    let mut m0 = 1u64;
35    for (a, b, m) in abm {
36        let mut b = b + m - a * x % m;
37        if b >= m {
38            b -= m;
39        }
40        let a = a * m0;
41        let g = gcd(a, m);
42        if b % g != 0 {
43            return None;
44        }
45        let (a, b, m) = (a / g, b / g, m / g);
46        x += (b as u128 * modinv(a, m) as u128 % m as u128 * m0 as u128) as u64;
47        m0 *= m;
48    }
49    Some((x, m0))
50}
51
52#[derive(Debug)]
53struct IndexCalculus {
54    primes: PrimeList,
55    br_primes: Vec<BarrettReduction<u64>>,
56    ic: HashMap<u64, IndexCalculusWithPrimitiveRoot>,
57}
58
59impl IndexCalculus {
60    fn new() -> Self {
61        Self {
62            primes: PrimeList::new(2),
63            br_primes: Default::default(),
64            ic: Default::default(),
65        }
66    }
67    fn discrete_logarithm(&mut self, a: u64, b: u64, p: u64) -> Option<(u64, u64)> {
68        let lim = ((((p as f64).log2() * (p as f64).log2().log2()).sqrt() / 2.0 + 1.).exp2() * 0.9)
69            as u64;
70        self.primes.reserve(lim);
71        let primes = self.primes.primes_lte(lim);
72        while self.br_primes.len() < primes.len() {
73            let br = BarrettReduction::<u64>::new(primes[self.br_primes.len()]);
74            self.br_primes.push(br);
75        }
76        let br_primes = &self.br_primes[..primes.len()];
77        self.ic
78            .entry(p)
79            .or_insert_with(|| IndexCalculusWithPrimitiveRoot::new(p, br_primes))
80            .discrete_logarithm(a, b, br_primes)
81    }
82}
83
/// Early-abort thresholds for `factorize_smooth`.
///
/// After the j-th factor-base prime has been divided out, a cofactor that is still at least
/// `2^A[j]` makes the candidate be abandoned. The entries are non-increasing. Indices past the
/// end of the table never trigger an early abort. (The values look empirically tuned; their
/// derivation is not recorded here.)
const A: [u32; 150] = [
    62, 61, 60, 60, 59, 58, 58, 58, 57, 56, // 0..10
    56, 56, 56, 55, 55, 55, 54, 54, 54, 53, // 10..20
    53, 53, 53, 52, 52, 52, 52, 52, 52, 51, // 20..30
    50, 50, 50, 50, 49, 49, 49, 48, 48, 48, // 30..40
    48, 48, 47, 47, 47, 47, 47, 47, 47, 47, // 40..50
    47, 47, 47, 47, 47, 47, 45, 42, 42, 41, // 50..60
    41, 41, 41, 41, 41, 41, 40, 40, 40, 40, // 60..70
    40, 40, 40, 40, 40, 40, 40, 40, 40, 40, // 70..80
    40, 40, 40, 40, 40, 40, 40, 38, 38, 38, // 80..90
    38, 38, 32, 32, 32, 32, 32, 32, 32, 32, // 90..100
    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, // 100..110
    32, 32, 32, 32, 32, 31, 31, 31, 31, 31, // 110..120
    31, 31, 31, 31, 31, 31, 31, 31, 31, 31, // 120..130
    31, 31, 31, 31, 31, 31, 31, 31, 31, 31, // 130..140
    22, 22, 22, 22, 22, 22, 22, 22, 22, 22, // 140..150
];

94fn factorize_smooth(mut x: u64, row: &mut [u64], br_primes: &[BarrettReduction<u64>]) -> bool {
95    for (j, (&br, r)) in br_primes.iter().zip(row).enumerate() {
96        *r = 0;
97        loop {
98            let (div, rem) = br.div_rem(x);
99            if rem != 0 {
100                break;
101            }
102            *r += 1;
103            x = div;
104        }
105        if j < 150 && x >= (1u64 << A[j]) {
106            break;
107        }
108    }
109    x == 1
110}
111
112#[derive(Debug)]
113struct QdrtPowPrec {
114    br_qdrt: BarrettReduction<u64>,
115    p0: Vec<u64>,
116    p1: Vec<u64>,
117    p2: Vec<u64>,
118    p3: Vec<u64>,
119}
120
121impl QdrtPowPrec {
122    fn new(a: u64, ord: u64, br: &BarrettReduction<u128>) -> Self {
123        let qdrt = (ord as f64).powf(0.25).ceil() as u64;
124        let br_qdrt = BarrettReduction::<u64>::new(qdrt);
125        let mut p0 = Vec::with_capacity(qdrt as usize);
126        let mut p1 = Vec::with_capacity(qdrt as usize);
127        let mut p2 = Vec::with_capacity(qdrt as usize);
128        let mut p3 = Vec::with_capacity(qdrt as usize);
129        let mut acc = 1u64;
130        for _ in 0..qdrt {
131            p0.push(acc);
132            acc = br.rem(acc as u128 * a as u128) as u64;
133        }
134        let a = acc;
135        acc = 1;
136        for _ in 0..qdrt {
137            p1.push(acc);
138            acc = br.rem(acc as u128 * a as u128) as u64;
139        }
140        let a = acc;
141        acc = 1;
142        for _ in 0..qdrt {
143            p2.push(acc);
144            acc = br.rem(acc as u128 * a as u128) as u64;
145        }
146        let a = acc;
147        acc = 1;
148        for _ in 0..qdrt {
149            p3.push(acc);
150            acc = br.rem(acc as u128 * a as u128) as u64;
151        }
152        Self {
153            br_qdrt,
154            p0,
155            p1,
156            p2,
157            p3,
158        }
159    }
160    fn pow(&self, mut k: u64, br: &BarrettReduction<u128>) -> u64 {
161        let (a, b) = self.br_qdrt.div_rem(k);
162        let mut x = self.p0[b as usize];
163        k = a;
164        if k > 0 {
165            let (a, b) = self.br_qdrt.div_rem(k);
166            x = br.rem(x as u128 * self.p1[b as usize] as u128) as u64;
167            k = a;
168        }
169        if k > 0 {
170            let (a, b) = self.br_qdrt.div_rem(k);
171            x = br.rem(x as u128 * self.p2[b as usize] as u128) as u64;
172            k = a;
173        }
174        if k > 0 {
175            let (_, b) = self.br_qdrt.div_rem(k);
176            x = br.rem(x as u128 * self.p3[b as usize] as u128) as u64;
177        }
178        x
179    }
180}
181
182fn index_calculus_for_primitive_root(
183    p: u64,
184    ord: u64,
185    br_primes: &[BarrettReduction<u64>],
186    prec: &QdrtPowPrec,
187) -> Vec<u64> {
188    let br_ord = BarrettReduction::<u128>::new(ord as u128);
189    let mul = |x: u64, y: u64| br_ord.rem(x as u128 * y as u128) as u64;
190    let sub = |x: u64, y: u64| if x < y { x + ord - y } else { x - y };
191
192    let pc = br_primes.len();
193    let mut mat: Vec<Vec<u64>> = vec![];
194    let mut rows: Vec<Vec<u64>> = vec![];
195
196    let mut rng = Xorshift::default();
197    let br = BarrettReduction::<u128>::new(p as u128);
198
199    for i in 0..pc {
200        for ri in 0usize.. {
201            let mut row = vec![0u64; pc + 1];
202            let mut kk = rng.rand(ord - 1) + 1;
203            let mut gkk = prec.pow(kk, &br);
204            let mut k = kk;
205            let mut gk = gkk;
206            while ri >= rows.len() {
207                row[pc] = k;
208                if factorize_smooth(gk, &mut row, br_primes) {
209                    rows.push(row);
210                    break;
211                }
212                if k + kk < ord {
213                    k += kk;
214                    gk = br.rem(gk as u128 * gkk as u128) as u64;
215                } else {
216                    kk = rng.rand(ord - 1) + 1;
217                    gkk = prec.pow(kk, &br);
218                    k = kk;
219                    gk = gkk;
220                }
221            }
222            let row = &mut rows[ri];
223            for j in 0..i {
224                if row[j] != 0 {
225                    let b = mul(modinv(mat[j][j], ord), row[j]);
226                    for (r, a) in row[j..].iter_mut().zip(&mat[j][j..]) {
227                        *r = sub(*r, mul(*a, b));
228                    }
229                }
230                assert_eq!(row[j], 0);
231            }
232            if gcd(row[i], ord) == 1 {
233                let last = rows.len() - 1;
234                rows.swap(ri, last);
235                mat.push(rows.pop().unwrap());
236                break;
237            }
238        }
239    }
240    for i in (0..pc).rev() {
241        for j in i + 1..pc {
242            mat[i][pc] = sub(mat[i][pc], mul(mat[i][j], mat[j][pc]));
243        }
244        mat[i][pc] = mul(mat[i][pc], modinv(mat[i][i], ord));
245    }
246    (0..pc).map(|i| (mat[i][pc])).collect()
247}
248
249#[derive(Debug)]
250struct IndexCalculusWithPrimitiveRoot {
251    p: u64,
252    ord: u64,
253    prec: QdrtPowPrec,
254    coeff: Vec<u64>,
255}
256
257impl IndexCalculusWithPrimitiveRoot {
258    fn new(p: u64, br_primes: &[BarrettReduction<u64>]) -> Self {
259        let ord = p - 1;
260        let g = primitive_root(p);
261        let br = BarrettReduction::<u128>::new(p as u128);
262        let prec = QdrtPowPrec::new(g, ord, &br);
263        let coeff = index_calculus_for_primitive_root(p, ord, br_primes, &prec);
264        Self {
265            p,
266            ord,
267            prec,
268            coeff,
269        }
270    }
271    fn index_calculus(&self, a: u64, br_primes: &[BarrettReduction<u64>]) -> Option<u64> {
272        let p = self.p;
273        let ord = self.ord;
274        let br = BarrettReduction::<u128>::new(p as u128);
275        let a = br.rem(a as _) as u64;
276        if a == 1 {
277            return Some(0);
278        }
279        if p == 2 {
280            return None;
281        }
282
283        let mut rng = Xorshift::new();
284        let mut row = vec![0u64; br_primes.len()];
285        let mut kk = rng.rand(ord - 1) + 1;
286        let mut gkk = self.prec.pow(kk, &br);
287        let mut k = kk;
288        let mut gk = br.rem(gkk as u128 * a as u128) as u64;
289        loop {
290            if factorize_smooth(gk, &mut row, br_primes) {
291                let mut res = ord - k;
292                for (&c, &r) in self.coeff.iter().zip(&row) {
293                    for _ in 0..r {
294                        res += c;
295                        if res >= ord {
296                            res -= ord;
297                        }
298                    }
299                }
300                return Some(res);
301            }
302            if k + kk < ord {
303                k += kk;
304                gk = br.rem(gk as u128 * gkk as u128) as u64;
305            } else {
306                kk = rng.rand(ord - 1) + 1;
307                gkk = self.prec.pow(kk, &br);
308                k = kk;
309                gk = br.rem(gkk as u128 * a as u128) as u64;
310            }
311        }
312    }
313    fn discrete_logarithm(
314        &self,
315        a: u64,
316        b: u64,
317        br_primes: &[BarrettReduction<u64>],
318    ) -> Option<(u64, u64)> {
319        let p = self.p;
320        let ord = self.ord;
321        let br = BarrettReduction::<u128>::new(p as u128);
322        let a = br.rem(a as _) as u64;
323        let b = br.rem(b as _) as u64;
324        if a == 0 {
325            return if b == 0 { Some((1, 1)) } else { None };
326        }
327        if b == 0 {
328            return None;
329        }
330
331        let x = self.index_calculus(a, br_primes)?;
332        let y = self.index_calculus(b, br_primes)?;
333        solve_linear_congruence(x, y, ord)
334    }
335}
336
337thread_local!(
338    static IC: UnsafeCell<IndexCalculus> = UnsafeCell::new(IndexCalculus::new());
339);
340
341pub fn discrete_logarithm_prime_mod(a: u64, b: u64, p: u64) -> Option<u64> {
342    IC.with(|ic| unsafe { &mut *ic.get() }.discrete_logarithm(a, b, p))
343        .map(|t| t.0)
344}
345
346/// a^x ≡ b (mod n), a has order p^e
347fn pohlig_hellman_prime_power_order(a: u64, b: u64, n: u64, p: u64, e: u32) -> Option<u64> {
348    let br = BarrettReduction::<u128>::new(n as u128);
349    let mul = |x: u64, y: u64| br.rem(x as u128 * y as u128) as u64;
350    let block_size = (p as f64).sqrt().ceil() as u64;
351    let mut baby = HashMap::<u64, u64>::new();
352    let g = pow(a, p.pow(e - 1), &br);
353    let mut xj = 1;
354    for j in 0..block_size {
355        baby.entry(xj).or_insert(j);
356        xj = mul(xj, g);
357    }
358    let xi = modinv(xj, n);
359    let mut t = 0u64;
360    for k in 0..e {
361        let mut h = pow(mul(modinv(pow(a, t, &br), n), b), p.pow(e - 1 - k), &br);
362        let mut ok = false;
363        for i in (0..block_size * block_size).step_by(block_size as usize) {
364            if let Some(j) = baby.get(&h) {
365                t += (i + j) * p.pow(k);
366                ok = true;
367                break;
368            }
369            h = mul(h, xi);
370        }
371        if !ok {
372            return None;
373        }
374    }
375    Some(t)
376}
377
378/// a^x ≡ b (mod p^e)
379fn discrete_logarithm_prime_power(a: u64, b: u64, p: u64, e: u32) -> Option<(u64, u64)> {
380    assert_ne!(p, 0);
381    assert_ne!(e, 0);
382    let n = p.pow(e);
383    assert!(a < n);
384    assert!(b < n);
385    assert_eq!(gcd(a, p), 1);
386    if p == 1 {
387        return Some((0, 1));
388    }
389    if a == 0 {
390        return if b == 0 { Some((1, 1)) } else { None };
391    }
392    if b == 0 {
393        return None;
394    }
395    if e == 1 {
396        return IC.with(|ic| unsafe { &mut *ic.get() }.discrete_logarithm(a, b, p));
397    }
398    let br = BarrettReduction::<u128>::new(n as _);
399    if p == 2 {
400        if e >= 3 {
401            if a % 4 == 1 && b % 4 != 1 {
402                return None;
403            }
404            let aa = if a % 4 == 1 { a } else { n - a };
405            let bb = if b % 4 == 1 { b } else { n - b };
406            let g = 5;
407            let ord = n / 4;
408            let x = pohlig_hellman_prime_power_order(g, aa, n, p, e - 2)?;
409            let y = pohlig_hellman_prime_power_order(g, bb, n, p, e - 2)?;
410            let t = solve_linear_congruence(x, y, ord)?;
411            match (a % 4 == 1, b % 4 == 1) {
412                (true, true) => Some(t),
413                (false, true) if t.0 % 2 == 0 => Some((t.0, lcm(t.1, 2))),
414                (false, false) if t.0 % 2 == 1 => Some((t.0, lcm(t.1, 2))),
415                (false, false) if a == b => Some((1, lcm(t.1, 2))),
416                _ => None,
417            }
418        } else if a == 1 {
419            if b == 1 { Some((0, 1)) } else { None }
420        } else {
421            assert_eq!(a, 3);
422            if b == 1 {
423                Some((0, 2))
424            } else if b == 3 {
425                Some((1, 2))
426            } else {
427                None
428            }
429        }
430    } else {
431        let ord = n - n / p;
432        let pf_ord = prime_factors(ord);
433        let g = (2..)
434            .find(|&g| check_primitive_root(g, ord, &br, &pf_ord))
435            .unwrap();
436        let mut pf_p = prime_factors(p - 1);
437        pf_p.push((p, e - 1));
438        let mut abm = vec![];
439        for (q, c) in pf_p {
440            let m = q.pow(c);
441            let d = ord / m;
442            let gg = pow(g, d, &br);
443            let aa = pow(a, d, &br);
444            let bb = pow(b, d, &br);
445            let x = pohlig_hellman_prime_power_order(gg, aa, n, q, c)?;
446            let y = pohlig_hellman_prime_power_order(gg, bb, n, q, c)?;
447            abm.push((x, y, m));
448        }
449        solve_linear_congruences(abm)
450    }
451}
452
453/// a^x ≡ b (mod n)
454pub fn discrete_logarithm(a: u64, b: u64, n: u64) -> Option<u64> {
455    let a = a % n;
456    let b = b % n;
457    let d = 2.max(64 - n.leading_zeros() as u64);
458    let mut pw = 1 % n;
459    for i in 0..d {
460        if pw == b {
461            return Some(i);
462        }
463        pw = (pw as u128 * a as u128 % n as u128) as u64;
464    }
465    let g = gcd(pw, n);
466    if b % g != 0 {
467        return None;
468    }
469    let n = n / g;
470    let b = (b as u128 * modinv(pw, n) as u128 % n as u128) as u64;
471    let pf = prime_factors(n);
472    let mut abm = vec![];
473    for (p, e) in pf {
474        let q = p.pow(e);
475        let x = discrete_logarithm_prime_power(a % q, b % q, p, e)?;
476        abm.push((1, x.0, x.1));
477    }
478    solve_linear_congruences(abm).map(|x| x.0 + d)
479}