// competitive/math/mod_sqrt.rs
use super::{MInt, MIntConvert, One, Zero};
use std::cmp::Ordering;

4impl<M> MInt<M>
5where
6 M: MIntConvert<u32>,
7{
8 pub fn sqrt(self) -> Option<Self> {
9 fn jacobi<M>(mut x: u32) -> i8
10 where
11 M: MIntConvert<u32>,
12 {
13 let mut s = 1i8;
14 let mut m = M::mod_into();
15 while m > 1 {
16 x %= m;
17 if x == 0 {
18 return 0;
19 }
20 let k = x.trailing_zeros();
21 if k % 2 == 1 && (m + 2) & 4 != 0 {
22 s = -s;
23 }
24 x >>= k;
25 if x & m & 2 != 0 {
26 s = -s;
27 }
28 std::mem::swap(&mut x, &mut m);
29 }
30 s
31 }
32 if M::mod_into() == 2 {
33 return Some(self);
34 }
35 let j = jacobi::<M>(u32::from(self));
36 match j.cmp(&0) {
37 Ordering::Less => {
38 return None;
39 }
40 Ordering::Equal => {
41 return Some(Self::zero());
42 }
43 Ordering::Greater => {}
44 }
45 let mut r = 1;
46 let (mut f0, d) = loop {
47 r ^= r << 5;
48 r ^= r >> 17;
49 r ^= r << 11;
50 let b = Self::from(r);
51 let d = b * b - self;
52 if jacobi::<M>(u32::from(d)) == -1 {
53 break (b, d);
54 }
55 };
56 let (mut f1, mut g0, mut g1, mut e) = (
57 Self::one(),
58 Self::one(),
59 Self::zero(),
60 (M::mod_into() + 1) / 2,
61 );
62 while e > 0 {
63 if e % 2 == 1 {
64 let t = g0 * f0 + d * g1 * f1;
65 g1 = g0 * f1 + g1 * f0;
66 g0 = t;
67 }
68 let t = f0 * f0 + d * f1 * f1;
69 f1 = Self::from(2) * f0 * f1;
70 f0 = t;
71 e /= 2;
72 }
73 if u32::from(g0) > M::mod_into() - u32::from(g0) {
74 g0 = -g0;
75 }
76 Some(g0)
77 }
78}