competitive/math/mod_sqrt.rs

1use super::{MInt, MIntConvert, One, Zero};
2use std::cmp::Ordering;
3
4impl<M> MInt<M>
5where
6    M: MIntConvert<u32>,
7{
8    pub fn sqrt(self) -> Option<Self> {
9        fn jacobi<M>(mut x: u32) -> i8
10        where
11            M: MIntConvert<u32>,
12        {
13            let mut s = 1i8;
14            let mut m = M::mod_into();
15            while m > 1 {
16                x %= m;
17                if x == 0 {
18                    return 0;
19                }
20                let k = x.trailing_zeros();
21                if k % 2 == 1 && (m + 2) & 4 != 0 {
22                    s = -s;
23                }
24                x >>= k;
25                if x & m & 2 != 0 {
26                    s = -s;
27                }
28                std::mem::swap(&mut x, &mut m);
29            }
30            s
31        }
32        if M::mod_into() == 2 {
33            return Some(self);
34        }
35        let j = jacobi::<M>(u32::from(self));
36        match j.cmp(&0) {
37            Ordering::Less => {
38                return None;
39            }
40            Ordering::Equal => {
41                return Some(Self::zero());
42            }
43            Ordering::Greater => {}
44        }
45        let mut r = 1;
46        let (mut f0, d) = loop {
47            r ^= r << 5;
48            r ^= r >> 17;
49            r ^= r << 11;
50            let b = Self::from(r);
51            let d = b * b - self;
52            if jacobi::<M>(u32::from(d)) == -1 {
53                break (b, d);
54            }
55        };
56        let (mut f1, mut g0, mut g1, mut e) = (
57            Self::one(),
58            Self::one(),
59            Self::zero(),
60            (M::mod_into() + 1) / 2,
61        );
62        while e > 0 {
63            if e % 2 == 1 {
64                let t = g0 * f0 + d * g1 * f1;
65                g1 = g0 * f1 + g1 * f0;
66                g0 = t;
67            }
68            let t = f0 * f0 + d * f1 * f1;
69            f1 = Self::from(2) * f0 * f1;
70            f0 = t;
71            e /= 2;
72        }
73        if u32::from(g0) > M::mod_into() - u32::from(g0) {
74            g0 = -g0;
75        }
76        Some(g0)
77    }
78}