competitive/math/
miller_rabin.rs

1#![allow(clippy::unreadable_literal)]
2use super::BarrettReduction;
3
4macro_rules! impl_test_mr {
5    ($name:ident, $ty:ty, $upty:ty) => {
6        fn $name(n: $ty, br: &BarrettReduction<$upty>, a: $ty) -> bool {
7            if br.rem(a as $upty) == 0 {
8                return true;
9            }
10            let d = n - 1;
11            let k = d.trailing_zeros();
12            let mut d = d >> k;
13            let mut y = {
14                let mut a = a as $upty;
15                let mut y: $upty = 1;
16                while d > 0 {
17                    if d & 1 == 1 {
18                        y = br.rem(y * a);
19                    }
20                    a = br.rem(a * a);
21                    d >>= 1;
22                }
23                y as $ty
24            };
25            if y == 1 || y == n - 1 {
26                true
27            } else {
28                for _ in 0..k - 1 {
29                    y = br.rem(y as $upty * y as $upty) as $ty;
30                    if y == n - 1 {
31                        return true;
32                    }
33                }
34                false
35            }
36        }
37    };
38}
39impl_test_mr!(test_mr32, u32, u64);
40impl_test_mr!(test_mr64, u64, u128);
41
/// Builds a primality check that picks its witness set from a threshold table.
///
/// Invocation shape:
/// `impl_mr!(name, witness_fn, Ty, WideTy, [bound, [bases, ...], ...], |n, br| fallback)`.
///
/// The generated `fn name(n: Ty, br: &BarrettReduction<WideTy>) -> bool` visits the
/// `bound, [bases]` pairs in the order they are written.  For the first pair with
/// `n >= bound` it returns whether `witness_fn(n, br, base)` holds for every base of
/// that pair, stopping at the first failure.  When no bound is reached, the
/// `fallback` expression (which may use `n` and `br`) is the result.
///
/// Reference: http://miller-rabin.appspot.com/
macro_rules! impl_mr {
    (
        $name:ident,
        $test:ident,
        $ty:ty,
        $upty:ty,
        [$($bound:expr, [$($base:expr),+]),+],
        |$n:ident, $br:ident| $fallback:expr
    ) => {
        fn $name($n: $ty, $br: &BarrettReduction<$upty>) -> bool {
            $(
                // First threshold reached decides the answer with its own bases.
                if $n >= $bound {
                    return $($test($n, $br, $base))&&+;
                }
            )+
            $fallback
        }
    };
}
55impl_mr!(
56    mr32,
57    test_mr32,
58    u32,
59    u64,
60    [316349281, [2, 7, 61], 49141, [11000544, 31481107]],
61    |n, br| test_mr32(n, br, 921211727)
62);
63impl_mr!(
64    mr64,
65    test_mr64,
66    u64,
67    u128,
68    [
69        585226005592931977,
70        [2, 325, 9375, 28178, 450775, 9780504, 1795265022],
71        7999252175582851,
72        [
73            2,
74            123635709730000,
75            9233062284813009,
76            43835965440333360,
77            761179012939631437,
78            1263739024124850375
79        ],
80        55245642489451,
81        [
82            2,
83            4130806001517,
84            149795463772692060,
85            186635894390467037,
86            3967304179347715805
87        ],
88        350269456337,
89        [
90            2,
91            141889084524735,
92            1199124725622454117,
93            11096072698276303650
94        ],
95        1050535501,
96        [
97            4230279247111683200,
98            14694767155120705706,
99            16641139526367750375
100        ]
101    ],
102    |n, br| mr32(n as u32, &BarrettReduction::<u64>::new(n))
103);
104
105pub fn miller_rabin_with_br(n: u64, br: &BarrettReduction<u128>) -> bool {
106    if n % 2 == 0 {
107        return n == 2;
108    }
109    if n % 3 == 0 {
110        return n == 3;
111    }
112    if n % 5 == 0 {
113        return n == 5;
114    }
115    if n % 7 == 0 {
116        return n == 7;
117    }
118    if n < 121 { n > 2 } else { mr64(n, br) }
119}
120
121pub fn miller_rabin(n: u64) -> bool {
122    miller_rabin_with_br(n, &BarrettReduction::<u128>::new(n as u128))
123}
124
#[test]
fn test_miller_rabin() {
    const N: u32 = 1_000_000;

    // Every value in 1..=N must get the same verdict from `PrimeTable` and from
    // `miller_rabin`.
    let table = super::PrimeTable::new(N);
    for candidate in 1..=N {
        let expected = table.is_prime(candidate);
        let actual = miller_rabin(u64::from(candidate));
        assert_eq!(expected, actual, "{}", candidate);
    }

    // Two spot checks above the table limit.
    assert!(miller_rabin(1_000_000_007));
    assert!(!miller_rabin(1_000_000_011));
}