competitive/math/
miller_rabin.rs

1use super::BarrettReduction;
2
3macro_rules! impl_test_mr {
4    ($name:ident, $ty:ty, $upty:ty) => {
5        fn $name(n: $ty, br: &BarrettReduction<$upty>, a: $ty) -> bool {
6            if br.rem(a as $upty) == 0 {
7                return true;
8            }
9            let d = n - 1;
10            let k = d.trailing_zeros();
11            let mut d = d >> k;
12            let mut y = {
13                let mut a = a as $upty;
14                let mut y: $upty = 1;
15                while d > 0 {
16                    if d & 1 == 1 {
17                        y = br.rem(y * a);
18                    }
19                    a = br.rem(a * a);
20                    d >>= 1;
21                }
22                y as $ty
23            };
24            if y == 1 || y == n - 1 {
25                true
26            } else {
27                for _ in 0..k - 1 {
28                    y = br.rem(y as $upty * y as $upty) as $ty;
29                    if y == n - 1 {
30                        return true;
31                    }
32                }
33                false
34            }
35        }
36    };
37}
38impl_test_mr!(test_mr32, u32, u64);
39impl_test_mr!(test_mr64, u64, u128);
40
41/// http://miller-rabin.appspot.com/
42macro_rules! impl_mr {
43    ($name:ident, $test:ident, $ty:ty, $upty:ty, [$($th:expr, [$($a:expr),+]),+], |$n:ident, $br:ident|$last:expr) => {
44        fn $name($n: $ty, $br: &BarrettReduction<$upty>) -> bool {
45            $(
46                if $n >= $th {
47                    return $($test($n, $br, $a))&&+
48                }
49            )+
50            $last
51        }
52    };
53}
54impl_mr!(
55    mr32,
56    test_mr32,
57    u32,
58    u64,
59    [316349281, [2, 7, 61], 49141, [11000544, 31481107]],
60    |n, br| test_mr32(n, br, 921211727)
61);
62impl_mr!(
63    mr64,
64    test_mr64,
65    u64,
66    u128,
67    [
68        585226005592931977,
69        [2, 325, 9375, 28178, 450775, 9780504, 1795265022],
70        7999252175582851,
71        [
72            2,
73            123635709730000,
74            9233062284813009,
75            43835965440333360,
76            761179012939631437,
77            1263739024124850375
78        ],
79        55245642489451,
80        [
81            2,
82            4130806001517,
83            149795463772692060,
84            186635894390467037,
85            3967304179347715805
86        ],
87        350269456337,
88        [
89            2,
90            141889084524735,
91            1199124725622454117,
92            11096072698276303650
93        ],
94        1050535501,
95        [
96            4230279247111683200,
97            14694767155120705706,
98            16641139526367750375
99        ]
100    ],
101    |n, br| mr32(n as u32, &BarrettReduction::<u64>::new(n))
102);
103
104pub fn miller_rabin_with_br(n: u64, br: &BarrettReduction<u128>) -> bool {
105    if n.is_multiple_of(2) {
106        return n == 2;
107    }
108    if n.is_multiple_of(3) {
109        return n == 3;
110    }
111    if n.is_multiple_of(5) {
112        return n == 5;
113    }
114    if n.is_multiple_of(7) {
115        return n == 7;
116    }
117    if n < 121 { n > 2 } else { mr64(n, br) }
118}
119
120pub fn miller_rabin(n: u64) -> bool {
121    miller_rabin_with_br(n, &BarrettReduction::<u128>::new(n as u128))
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::PrimeTable;

    #[test]
    fn test_miller_rabin() {
        const N: u32 = 1_000_000;
        let primes = PrimeTable::new(N);
        for i in 1..=N {
            assert_eq!(primes.is_prime(i), miller_rabin(i as _), "{}", i);
        }
        assert!(miller_rabin(1_000_000_007));
        assert!(!miller_rabin(1_000_000_011));
    }

    // Every input above is below 1_050_535_501 and only reaches the 32-bit
    // fallback. These cases exercise each of the 64-bit base sets in `mr64`.
    // The composites are strong pseudoprimes to many small bases (OEIS A014233)
    // or explicit products, so they cannot be prime.
    #[test]
    fn test_miller_rabin_64bit_base_sets() {
        let cases: &[(u64, bool)] = &[
            // 1_050_535_501 <= n < 350_269_456_337 (3 bases)
            (2_147_483_647, true), // 2^31 - 1
            (4_294_967_291, true), // 2^32 - 5
            (3_215_031_751, false), // spsp to bases 2, 3, 5, 7
            (4_294_967_297, false), // 2^32 + 1 = 641 * 6700417
            // 350_269_456_337 <= n < 55_245_642_489_451 (4 bases)
            (999_999_999_989, true),
            (1_000_000_000_039, true),
            (2_152_302_898_747, false), // spsp to bases 2..=11
            (3_474_749_660_383, false), // spsp to bases 2..=13
            (999_983 * 1_000_003, false),
            // 55_245_642_489_451 <= n < 7_999_252_175_582_851 (5 bases)
            (1_000_000_000_000_037, true),
            (341_550_071_728_321, false), // spsp to bases 2..=17
            (10_000_019 * 10_000_079, false),
            // 7_999_252_175_582_851 <= n < 585_226_005_592_931_977 (6 bases)
            (9_007_199_254_740_881, true), // 2^53 - 111
            (100_000_007 * 100_000_037, false),
            // n >= 585_226_005_592_931_977 (7 bases)
            (2_305_843_009_213_693_951, true), // 2^61 - 1
            (18_446_744_073_709_551_557, true), // 2^64 - 59
            (3_825_123_056_546_413_051, false), // spsp to bases 2..=23
            (1_000_000_007 * 998_244_353, false),
            (2_147_483_647 * 2_147_483_647, false),
            (u64::MAX, false),
        ];
        for &(n, expected) in cases {
            assert_eq!(miller_rabin(n), expected, "{}", n);
        }
    }
}