const fn reduce(z: u64, p: u32, r: u32) -> u32
Examples found in repository?
crates/competitive/src/math/number_theoretic_transform.rs (line 35)
34const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
35 reduce(x as u64 * y as u64, p, r)
36}
37const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
38 while y > 0 {
39 if y & 1 == 1 {
40 z = mod_mul(z, x, p, r);
41 }
42 x = mod_mul(x, x, p, r);
43 y >>= 1;
44 }
45 z
46}
47
48pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
49 const PRIMITIVE_ROOT: u32 = {
50 let mut g = 3u32;
51 loop {
52 let mut ok = true;
53 let mut d = 1u32;
54 while d * d < Self::MOD {
55 if (Self::MOD - 1) % d == 0 {
56 let ds = [d, (Self::MOD - 1) / d];
57 let mut i = 0;
58 while i < 2 {
59 ok &= ds[i] == Self::MOD - 1
60 || mod_pow(
61 reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
62 ds[i],
63 Self::MOD,
64 Self::R,
65 Self::N1,
66 ) != Self::N1;
67 i += 1;
68 }
69 }
70 d += 1;
71 }
72 if ok {
73 break;
74 }
75 g += 2;
76 }
77 g
78 };
79 const RANK: u32 = (Self::MOD - 1).trailing_zeros();
80 const INFO: NttInfo = NttInfo::new::<Self>();
81}
82
/// Precomputed root-of-unity tables for one NTT modulus, filled in by
/// `NttInfo::new`.
///
/// Each table holds 32 `u32` slots. Only the prefix determined by the
/// modulus' `RANK` is populated, and the remaining slots are zero.
#[derive(Debug, PartialEq)]
pub struct NttInfo {
    /// `root[i] = g^((MOD - 1) >> i)` for `i <= RANK`, with `g` the primitive root.
    root: [u32; 32],
    /// Inverses of the entries of `root`.
    inv_root: [u32; 32],
    /// `rate2[i] = root[i + 2] * inv_root[2] * ... * inv_root[i + 1]`.
    rate2: [u32; 32],
    /// `inv_rate2[i] = inv_root[i + 2] * root[2] * ... * root[i + 1]`.
    inv_rate2: [u32; 32],
    /// `rate3[i] = root[i + 3] * inv_root[3] * ... * inv_root[i + 2]`.
    rate3: [u32; 32],
    /// `inv_rate3[i] = inv_root[i + 3] * root[3] * ... * root[i + 2]`.
    inv_rate3: [u32; 32],
}
92impl NttInfo {
93 const fn new<M>() -> Self
94 where
95 M: Montgomery32NttModulus,
96 {
97 let mut root = [0; 32];
98 let mut inv_root = [0; 32];
99 let mut rate2 = [0; 32];
100 let mut inv_rate2 = [0; 32];
101 let mut rate3 = [0; 32];
102 let mut inv_rate3 = [0; 32];
103 let rank = M::RANK as usize;
104
105 let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
106 root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
107 inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
108 let mut i = rank - 1;
109 loop {
110 root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
111 inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
112 if i == 0 {
113 break;
114 }
115 i -= 1;
116 }
117
118 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
119 while i < rank - 1 {
120 rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
121 inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
122 prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
123 inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
124 i += 1;
125 }
126
127 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
128 while i < rank - 2 {
129 rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
130 inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
131 prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
132 inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
133 i += 1;
134 }
135
136 NttInfo {
137 root,
138 inv_root,
139 rate2,
140 inv_rate2,
141 rate3,
142 inv_rate3,
143 }
144 }