competitive/num/mint/
montgomery.rs1use super::*;
2
3impl<M> MIntBase for M
4where
5 M: MontgomeryReduction32,
6{
7 type Inner = u32;
8 #[inline]
9 fn get_mod() -> Self::Inner {
10 <Self as MontgomeryReduction32>::MOD
11 }
12 #[inline]
13 fn mod_zero() -> Self::Inner {
14 0
15 }
16 #[inline]
17 fn mod_one() -> Self::Inner {
18 Self::N1
19 }
20 #[inline]
21 fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
22 let z = x + y;
23 let m = Self::get_mod();
24 if z >= m { z - m } else { z }
25 }
26 #[inline]
27 fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
28 if x < y {
29 x + Self::get_mod() - y
30 } else {
31 x - y
32 }
33 }
34 #[inline]
35 fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
36 Self::reduce(x as u64 * y as u64)
37 }
38 #[inline]
39 fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
40 Self::mod_mul(x, Self::mod_inv(y))
41 }
42 #[inline]
43 fn mod_neg(x: Self::Inner) -> Self::Inner {
44 if x == 0 { 0 } else { Self::get_mod() - x }
45 }
46 fn mod_inv(x: Self::Inner) -> Self::Inner {
47 let p = Self::get_mod() as i32;
48 let (mut a, mut b) = (x as i32, p);
49 let (mut u, mut x) = (1, 0);
50 while a != 0 {
51 let k = b / a;
52 x -= k * u;
53 b -= k * a;
54 std::mem::swap(&mut x, &mut u);
55 std::mem::swap(&mut b, &mut a);
56 }
57 Self::reduce((if x < 0 { x + p } else { x }) as u64 * Self::N3 as u64)
58 }
59 fn mod_inner(x: Self::Inner) -> Self::Inner {
60 Self::reduce(x as u64)
61 }
62}
63impl<M> MIntConvert<u32> for M
64where
65 M: MontgomeryReduction32,
66{
67 #[inline]
68 fn from(x: u32) -> Self::Inner {
69 Self::reduce(x as u64 * Self::N2 as u64)
70 }
71 #[inline]
72 fn into(x: Self::Inner) -> u32 {
73 Self::reduce(x as u64)
74 }
75 #[inline]
76 fn mod_into() -> u32 {
77 <Self as MIntBase>::get_mod()
78 }
79}
80impl<M> MIntConvert<u64> for M
81where
82 M: MontgomeryReduction32,
83{
84 #[inline]
85 fn from(x: u64) -> Self::Inner {
86 Self::reduce(x % Self::get_mod() as u64 * Self::N2 as u64)
87 }
88 #[inline]
89 fn into(x: Self::Inner) -> u64 {
90 Self::reduce(x as u64) as u64
91 }
92 #[inline]
93 fn mod_into() -> u64 {
94 <Self as MIntBase>::get_mod() as u64
95 }
96}
97impl<M> MIntConvert<usize> for M
98where
99 M: MontgomeryReduction32,
100{
101 #[inline]
102 fn from(x: usize) -> Self::Inner {
103 Self::reduce(x as u64 % Self::get_mod() as u64 * Self::N2 as u64)
104 }
105 #[inline]
106 fn into(x: Self::Inner) -> usize {
107 Self::reduce(x as u64) as usize
108 }
109 #[inline]
110 fn mod_into() -> usize {
111 <Self as MIntBase>::get_mod() as usize
112 }
113}
114impl<M> MIntConvert<i32> for M
115where
116 M: MontgomeryReduction32,
117{
118 #[inline]
119 fn from(x: i32) -> Self::Inner {
120 let x = x % <Self as MIntBase>::get_mod() as i32;
121 let x = if x < 0 {
122 (x + <Self as MIntBase>::get_mod() as i32) as u64
123 } else {
124 x as u64
125 };
126 Self::reduce(x * Self::N2 as u64)
127 }
128 #[inline]
129 fn into(x: Self::Inner) -> i32 {
130 Self::reduce(x as u64) as i32
131 }
132 #[inline]
133 fn mod_into() -> i32 {
134 <Self as MIntBase>::get_mod() as i32
135 }
136}
137impl<M> MIntConvert<i64> for M
138where
139 M: MontgomeryReduction32,
140{
141 #[inline]
142 fn from(x: i64) -> Self::Inner {
143 let x = x % <Self as MIntBase>::get_mod() as i64;
144 let x = if x < 0 {
145 (x + <Self as MIntBase>::get_mod() as i64) as u64
146 } else {
147 x as u64
148 };
149 Self::reduce(x * Self::N2 as u64)
150 }
151 #[inline]
152 fn into(x: Self::Inner) -> i64 {
153 Self::reduce(x as u64) as i64
154 }
155 #[inline]
156 fn mod_into() -> i64 {
157 <Self as MIntBase>::get_mod() as i64
158 }
159}
160impl<M> MIntConvert<isize> for M
161where
162 M: MontgomeryReduction32,
163{
164 #[inline]
165 fn from(x: isize) -> Self::Inner {
166 let x = x % <Self as MIntBase>::get_mod() as isize;
167 let x = if x < 0 {
168 (x + <Self as MIntBase>::get_mod() as isize) as u64
169 } else {
170 x as u64
171 };
172 Self::reduce(x * Self::N2 as u64)
173 }
174 #[inline]
175 fn into(x: Self::Inner) -> isize {
176 Self::reduce(x as u64) as isize
177 }
178 #[inline]
179 fn mod_into() -> isize {
180 <Self as MIntBase>::get_mod() as isize
181 }
182}
/// Compile-time parameters and the core reduction routine for 32-bit Montgomery arithmetic.
///
/// An implementor supplies only [`MontgomeryReduction32::MOD`]; every other constant is derived
/// from it. Residues are kept in Montgomery form `a * 2^32 mod MOD`.
///
/// # Requirements on `MOD`
///
/// * `MOD` must be odd. Otherwise `MOD` has no inverse modulo `2^32` and `R` cannot be formed.
/// * `MOD` must be below `2^31`. `reduce` evaluates `x + t * MOD` in `u64` for inputs up to
///   `MOD * 2^32`, which only fits when `MOD < 2^31`. The `MIntBase` arithmetic built on this
///   trait also adds two residues in `u32` and casts `MOD` to `i32`.
///
/// Both requirements are asserted in the initializer of [`MontgomeryReduction32::R`]. Since
/// `reduce` reads `R`, instantiating it for a bad modulus fails at build time instead of
/// silently producing wrong results.
pub trait MontgomeryReduction32 {
    /// The modulus.
    const MOD: u32;
    /// `-MOD^{-1} mod 2^32`, i.e. the value with `R * MOD ≡ -1 (mod 2^32)`.
    const R: u32 = {
        let m = Self::MOD;
        assert!(m % 2 == 1, "MontgomeryReduction32: MOD must be odd");
        assert!(m < 1 << 31, "MontgomeryReduction32: MOD must be less than 2^31");
        // Build `r` one bit at a time; `t` tracks `(r * m) >> i`. Whenever bit `i` of `r * m`
        // is still 0, setting bit `i` of `r` adds `m << i` and (as `m` is odd) flips it to 1.
        // After 32 steps `r * m ≡ 2^32 - 1 ≡ -1 (mod 2^32)`.
        let mut r = 0;
        let mut t = 0;
        let mut i = 0;
        while i < 32 {
            if t % 2 == 0 {
                t += m;
                r += 1 << i;
            }
            t /= 2;
            i += 1;
        }
        r
    };
    /// `2^32 mod MOD`: the Montgomery form of 1.
    const N1: u32 = ((1u64 << 32) % Self::MOD as u64) as _;
    /// `2^64 mod MOD`: `reduce(v * N2)` converts a plain value `v` into Montgomery form.
    const N2: u32 = (Self::N1 as u64 * Self::N1 as u64 % Self::MOD as u64) as _;
    /// `2^96 mod MOD`.
    const N3: u32 = (Self::N1 as u64 * Self::N2 as u64 % Self::MOD as u64) as _;

    /// Montgomery reduction: returns `x * 2^{-32} mod MOD`, fully reduced into `[0, MOD)`.
    ///
    /// `x` must be below `MOD * 2^32`. This holds for any product of two values below `MOD`,
    /// and for `v * N2` with `v: u32`. Then `x + t * MOD` fits in `u64` and the shifted value
    /// is below `2 * MOD`, so one conditional subtraction suffices.
    #[inline]
    fn reduce(x: u64) -> u32 {
        let m: u32 = Self::MOD;
        let r = Self::R;
        // t = (x mod 2^32) * R mod 2^32 makes x + t * m divisible by 2^32.
        let mut x = ((x + r.wrapping_mul(x as u32) as u64 * m as u64) >> 32) as u32;
        if x >= m {
            x -= m;
        }
        x
    }
}
/// Declares Montgomery moduli together with their `MInt` aliases.
///
/// Each `[Marker, modulus, Alias]` entry (an optional trailing comma is allowed inside the
/// brackets and after the last entry) expands to:
/// - an uninhabited marker type `pub enum Marker {}`,
/// - `impl MontgomeryReduction32 for Marker` with `MOD = modulus`,
/// - `pub type Alias = MInt<Marker>;`.
macro_rules! define_montgomery_reduction_32 {
    ($([$marker:ident, $modulus:expr, $alias:ident $(,)?]),* $(,)?) => {
        $(
            pub type $alias = MInt<$marker>;

            pub enum $marker {}

            impl MontgomeryReduction32 for $marker {
                const MOD: u32 = $modulus;
            }
        )*
    };
}
231define_montgomery_reduction_32!(
232 [Modulo998244353, 998_244_353, MInt998244353],
233 [Modulo2113929217, 2_113_929_217, MInt2113929217],
234 [Modulo1811939329, 1_811_939_329, MInt1811939329],
235 [Modulo2013265921, 2_013_265_921, MInt2013265921],
236);
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::montgomery::MInt998244353 as M;
    use crate::tools::Xorshift;

    /// Cross-checks `MInt998244353` against the reference implementation imported as `M`.
    #[test]
    fn test_mint998244353() {
        let mut rng = Xorshift::new();
        const Q: usize = 1000;
        assert_eq!(0, MInt998244353::zero().inner());
        assert_eq!(1, MInt998244353::one().inner());
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N3 as u64),
            Modulo998244353::N2
        );
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N2 as u64),
            Modulo998244353::N1
        );
        assert_eq!(Modulo998244353::reduce(Modulo998244353::N1 as u64), 1);
        for _ in 0..Q {
            let x = rng.random(..MInt998244353::get_mod());
            assert_eq!(x, MInt998244353::new(x).inner());
            assert_eq!((-M::new(x)).inner(), (-MInt998244353::new(x)).inner());
            assert_eq!(x, MInt998244353::new(x).inv().inv().inner());
            assert_eq!(M::new(x).inv().inner(), MInt998244353::new(x).inv().inner());
        }

        for _ in 0..Q {
            let x = rng.random(..MInt998244353::get_mod());
            let y = rng.random(..MInt998244353::get_mod());
            assert_eq!(
                (M::new(x) + M::new(y)).inner(),
                (MInt998244353::new(x) + MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) - M::new(y)).inner(),
                (MInt998244353::new(x) - MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) * M::new(y)).inner(),
                (MInt998244353::new(x) * MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) / M::new(y)).inner(),
                (MInt998244353::new(x) / MInt998244353::new(y)).inner()
            );
            assert_eq!(
                M::new(x).pow(y as usize).inner(),
                MInt998244353::new(x).pow(y as usize).inner()
            );
        }

        for _ in 0..Q {
            let x = rng.rand64();
            assert_eq!(
                M::from(x as u32).inner(),
                MInt998244353::from(x as u32).inner()
            );
            assert_eq!(M::from(x).inner(), MInt998244353::from(x).inner());
            assert_eq!(
                M::from(x as usize).inner(),
                MInt998244353::from(x as usize).inner()
            );
            assert_eq!(
                M::from(x as i32).inner(),
                MInt998244353::from(x as i32).inner()
            );
            assert_eq!(
                M::from(x as i64).inner(),
                MInt998244353::from(x as i64).inner()
            );
            assert_eq!(
                M::from(x as isize).inner(),
                MInt998244353::from(x as isize).inner()
            );
        }
    }

    /// Checks every predefined modulus against plain `u64` arithmetic.
    ///
    /// The moduli close to `2^31` matter most here. For them the `u32` sum in `mod_add`, the
    /// `u64` intermediate in `reduce`, and `from(u32::MAX)` come closest to overflowing. The test
    /// above only covers 998244353.
    #[test]
    fn test_all_moduli_against_u64() {
        macro_rules! check {
            ($modulo:ident, $mint:ident) => {{
                let mut rng = Xorshift::new();
                let m = $modulo::MOD;
                let m64 = m as u64;
                // R * MOD ≡ -1 (mod 2^32); N1 is the Montgomery form of 1.
                assert_eq!($modulo::R.wrapping_mul(m), u32::MAX);
                assert_eq!($modulo::reduce($modulo::N1 as u64), 1);
                // Largest input `reduce` must handle: MOD * 2^32 - 1.
                let big = (m64 << 32) - 1;
                let reduced = $modulo::reduce(big);
                assert!(reduced < m);
                assert_eq!(((reduced as u64) << 32) % m64, big % m64);
                // Conversions at the edges of the source integer ranges.
                assert_eq!($mint::from(u32::MAX).inner() as u64, u32::MAX as u64 % m64);
                assert_eq!($mint::from(u64::MAX).inner() as u64, u64::MAX % m64);
                assert_eq!(
                    $mint::from(i64::MIN).inner() as i64,
                    i64::MIN.rem_euclid(m as i64)
                );
                for _ in 0..1000 {
                    let x = rng.random(..m);
                    let y = rng.random(..m);
                    let (x64, y64) = (x as u64, y as u64);
                    assert_eq!($mint::new(x).inner(), x);
                    assert_eq!(
                        ($mint::new(x) + $mint::new(y)).inner() as u64,
                        (x64 + y64) % m64
                    );
                    assert_eq!(
                        ($mint::new(x) - $mint::new(y)).inner() as u64,
                        (x64 + m64 - y64) % m64
                    );
                    assert_eq!(
                        ($mint::new(x) * $mint::new(y)).inner() as u64,
                        x64 * y64 % m64
                    );
                    assert_eq!((-$mint::new(x)).inner() as u64, (m64 - x64) % m64);
                    if y != 0 {
                        assert_eq!(($mint::new(y).inv() * $mint::new(y)).inner(), 1);
                        assert_eq!(
                            (($mint::new(x) / $mint::new(y)) * $mint::new(y)).inner(),
                            x
                        );
                    }
                }
            }};
        }
        check!(Modulo998244353, MInt998244353);
        check!(Modulo1811939329, MInt1811939329);
        check!(Modulo2013265921, MInt2013265921);
        check!(Modulo2113929217, MInt2113929217);
    }
}