1use super::*;
2
3impl<M> MIntBase for M
4where
5 M: MontgomeryReduction32,
6{
7 type Inner = u32;
8 fn get_mod() -> Self::Inner {
9 <Self as MontgomeryReduction32>::MOD
10 }
11 fn mod_zero() -> Self::Inner {
12 0
13 }
14 fn mod_one() -> Self::Inner {
15 Self::N1
16 }
17 fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
18 let z = x + y;
19 let m = Self::get_mod();
20 if z >= m { z - m } else { z }
21 }
22 fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
23 if x < y {
24 x + Self::get_mod() - y
25 } else {
26 x - y
27 }
28 }
29 fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
30 Self::reduce(x as u64 * y as u64)
31 }
32 fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
33 Self::mod_mul(x, Self::mod_inv(y))
34 }
35 fn mod_neg(x: Self::Inner) -> Self::Inner {
36 if x == 0 { 0 } else { Self::get_mod() - x }
37 }
38 fn mod_inv(x: Self::Inner) -> Self::Inner {
39 let p = Self::get_mod() as i32;
40 let (mut a, mut b) = (x as i32, p);
41 let (mut u, mut x) = (1, 0);
42 while a != 0 {
43 let k = b / a;
44 x -= k * u;
45 b -= k * a;
46 std::mem::swap(&mut x, &mut u);
47 std::mem::swap(&mut b, &mut a);
48 }
49 Self::reduce((if x < 0 { x + p } else { x }) as u64 * Self::N3 as u64)
50 }
51 fn mod_inner(x: Self::Inner) -> Self::Inner {
52 Self::reduce(x as u64)
53 }
54}
55impl<M> MIntConvert<u32> for M
56where
57 M: MontgomeryReduction32,
58{
59 fn from(x: u32) -> Self::Inner {
60 Self::reduce(x as u64 * Self::N2 as u64)
61 }
62 fn into(x: Self::Inner) -> u32 {
63 Self::reduce(x as u64)
64 }
65 fn mod_into() -> u32 {
66 <Self as MIntBase>::get_mod()
67 }
68}
69impl<M> MIntConvert<u64> for M
70where
71 M: MontgomeryReduction32,
72{
73 fn from(x: u64) -> Self::Inner {
74 Self::reduce(x % Self::get_mod() as u64 * Self::N2 as u64)
75 }
76 fn into(x: Self::Inner) -> u64 {
77 Self::reduce(x as u64) as u64
78 }
79 fn mod_into() -> u64 {
80 <Self as MIntBase>::get_mod() as u64
81 }
82}
83impl<M> MIntConvert<usize> for M
84where
85 M: MontgomeryReduction32,
86{
87 fn from(x: usize) -> Self::Inner {
88 Self::reduce(x as u64 % Self::get_mod() as u64 * Self::N2 as u64)
89 }
90 fn into(x: Self::Inner) -> usize {
91 Self::reduce(x as u64) as usize
92 }
93 fn mod_into() -> usize {
94 <Self as MIntBase>::get_mod() as usize
95 }
96}
97impl<M> MIntConvert<i32> for M
98where
99 M: MontgomeryReduction32,
100{
101 fn from(x: i32) -> Self::Inner {
102 let x = x % <Self as MIntBase>::get_mod() as i32;
103 let x = if x < 0 {
104 (x + <Self as MIntBase>::get_mod() as i32) as u64
105 } else {
106 x as u64
107 };
108 Self::reduce(x * Self::N2 as u64)
109 }
110 fn into(x: Self::Inner) -> i32 {
111 Self::reduce(x as u64) as i32
112 }
113 fn mod_into() -> i32 {
114 <Self as MIntBase>::get_mod() as i32
115 }
116}
117impl<M> MIntConvert<i64> for M
118where
119 M: MontgomeryReduction32,
120{
121 fn from(x: i64) -> Self::Inner {
122 let x = x % <Self as MIntBase>::get_mod() as i64;
123 let x = if x < 0 {
124 (x + <Self as MIntBase>::get_mod() as i64) as u64
125 } else {
126 x as u64
127 };
128 Self::reduce(x * Self::N2 as u64)
129 }
130 fn into(x: Self::Inner) -> i64 {
131 Self::reduce(x as u64) as i64
132 }
133 fn mod_into() -> i64 {
134 <Self as MIntBase>::get_mod() as i64
135 }
136}
137impl<M> MIntConvert<isize> for M
138where
139 M: MontgomeryReduction32,
140{
141 fn from(x: isize) -> Self::Inner {
142 let x = x % <Self as MIntBase>::get_mod() as isize;
143 let x = if x < 0 {
144 (x + <Self as MIntBase>::get_mod() as isize) as u64
145 } else {
146 x as u64
147 };
148 Self::reduce(x * Self::N2 as u64)
149 }
150 fn into(x: Self::Inner) -> isize {
151 Self::reduce(x as u64) as isize
152 }
153 fn mod_into() -> isize {
154 <Self as MIntBase>::get_mod() as isize
155 }
156}
/// 32-bit Montgomery reduction for an odd modulus `MOD` (all moduli used in
/// this file are below 2^31).
///
/// Values are carried around as `x * R mod MOD` with `R = 2^32`; a single
/// `reduce` call removes one factor of `R` from a product.
pub trait MontgomeryReduction32 {
    /// The (odd) modulus.
    const MOD: u32;
    /// `-MOD^{-1} mod 2^32`, built one bit per iteration by Hensel lifting:
    /// after step `bit`, `R * MOD == -1 (mod 2^(bit+1))`.
    const R: u32 = {
        let m = Self::MOD;
        let mut inv = 0;
        let mut acc = 0;
        let mut bit = 0;
        while bit < 32 {
            // `acc` tracks (inv * m + 1) >> bit; when it is even the current
            // bit of `inv` must be set to keep the invariant.
            if acc % 2 == 0 {
                acc += m;
                inv += 1 << bit;
            }
            acc /= 2;
            bit += 1;
        }
        inv
    };
    /// `N1 = 2^32 mod MOD` — the Montgomery form of 1.
    const N1: u32 = ((1u64 << 32) % Self::MOD as u64) as _;
    /// `N2 = 2^64 mod MOD` — multiply by this and reduce to enter Montgomery form.
    const N2: u32 = (Self::N1 as u64 * Self::N1 as u64 % Self::MOD as u64) as _;
    /// `N3 = 2^96 mod MOD` — used when inverting Montgomery-form values.
    const N3: u32 = (Self::N1 as u64 * Self::N2 as u64 % Self::MOD as u64) as _;
    /// Montgomery reduction: returns `x * 2^{-32} mod MOD` for `x < MOD * 2^32`.
    fn reduce(x: u64) -> u32 {
        let m = Self::MOD;
        // q = lo(x) * (-MOD^{-1}) mod 2^32 makes x + q*m divisible by 2^32.
        let q = Self::R.wrapping_mul(x as u32);
        let mut t = ((x + q as u64 * m as u64) >> 32) as u32;
        // The quotient lies in [0, 2*MOD); one conditional subtract canonicalizes.
        if t >= m {
            t -= m;
        }
        t
    }
}
/// For each `[marker, modulus, alias]` triple, declares an uninhabited
/// marker type implementing `MontgomeryReduction32` with the given modulus,
/// plus a convenient `MInt` type alias.
macro_rules! define_montgomery_reduction_32 {
    ($([$marker:ident, $modulus:expr, $alias:ident $(,)?]),* $(,)?) => {
        $(
            pub enum $marker {}
            impl MontgomeryReduction32 for $marker {
                const MOD: u32 = $modulus;
            }
            pub type $alias = MInt<$marker>;
        )*
    };
}
// Marker types and `MInt` aliases for commonly used 32-bit prime moduli
// (998244353 and friends — primes of the form k * 2^n + 1).
define_montgomery_reduction_32!(
    [Modulo998244353, 998_244_353, MInt998244353],
    [Modulo2113929217, 2_113_929_217, MInt2113929217],
    [Modulo1811939329, 1_811_939_329, MInt1811939329],
    [Modulo2013265921, 2_013_265_921, MInt2013265921],
);
211
/// x86_64 SIMD kernels for Montgomery arithmetic on packed 32-bit lanes
/// (8 lanes with AVX2, 16 with AVX-512).
///
/// The `montgomery_*` functions operate in a "lazy" domain: products come
/// out congruent mod MOD but only bounded by `2 * MOD` (hence the
/// `mod2_vec` parameters on add/sub); the `*_canon` variants fold results
/// back into `[0, MOD)`.
///
/// Every function is `unsafe`: per `#[target_feature]` rules, the caller
/// must ensure the CPU actually supports the enabled feature set before
/// calling.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)] pub mod simd32 {
    use std::arch::x86_64::*;

    /// Lane-wise low 32 bits of the 32x32 product. `_mm256_mullo_epi32`
    /// suffices because the low half is identical for signed and unsigned.
    #[target_feature(enable = "avx2")]
    unsafe fn my256_mullo_epu32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_mullo_epi32(a, b)
    }

    /// Lane-wise high 32 bits of the unsigned 32x32 product.
    ///
    /// AVX2 has no unsigned 32-bit "mulhi", so the odd lanes are copied into
    /// even positions (shuffle 0xF5 = lanes 3,3,1,1), both streams are
    /// widened with `_mm256_mul_epu32`, and the 64-bit results are
    /// re-interleaved so each lane holds the high half of its own product.
    #[target_feature(enable = "avx2")]
    unsafe fn my256_mulhi_epu32(a: __m256i, b: __m256i) -> __m256i {
        let a13 = _mm256_shuffle_epi32(a, 0xF5);
        let b13 = _mm256_shuffle_epi32(b, 0xF5);
        let prod02 = _mm256_mul_epu32(a, b);
        let prod13 = _mm256_mul_epu32(a13, b13);
        let t0 = _mm256_unpacklo_epi32(prod02, prod13);
        let t1 = _mm256_unpackhi_epi32(prod02, prod13);
        _mm256_unpackhi_epi64(t0, t1)
    }

    /// Lazy Montgomery product per lane: `hi(a*b) + MOD - hi(q*MOD)` with
    /// `q = lo(a*b) * r mod 2^32`, which is congruent to `a*b*R^{-1}`
    /// mod MOD but not reduced into `[0, MOD)` — use
    /// `montgomery_mul_256_canon` when a canonical result is needed.
    /// `r_vec` holds `-MOD^{-1} mod 2^32` (the scalar trait's `R`) and
    /// `mod_vec` holds `MOD` in every lane.
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_mul_256(
        a: __m256i,
        b: __m256i,
        r_vec: __m256i,
        mod_vec: __m256i,
    ) -> __m256i {
        let hi = my256_mulhi_epu32(a, b);
        let lo = my256_mullo_epu32(a, b);
        let lo = my256_mullo_epu32(lo, r_vec);
        let lo = my256_mulhi_epu32(lo, mod_vec);
        _mm256_sub_epi32(_mm256_add_epi32(hi, mod_vec), lo)
    }

    /// Lane-wise `(a + b) mod mod_vec` for inputs already below the modulus.
    /// AVX2 only has signed compares, so both sides are XORed with `sign`
    /// (assumed to hold 0x8000_0000 per lane — confirm at call sites) to
    /// emulate an unsigned `>`; OR-ing in the equality mask makes the test
    /// `sum >= mod`, which selects the conditional subtract.
    #[target_feature(enable = "avx2")]
    pub unsafe fn add_mod_256(a: __m256i, b: __m256i, mod_vec: __m256i, sign: __m256i) -> __m256i {
        let sum = _mm256_add_epi32(a, b);
        let sum_x = _mm256_xor_si256(sum, sign);
        let mod_x = _mm256_xor_si256(mod_vec, sign);
        let gt = _mm256_cmpgt_epi32(sum_x, mod_x);
        let eq = _mm256_cmpeq_epi32(sum, mod_vec);
        let mask = _mm256_or_si256(gt, eq);
        let sub = _mm256_and_si256(mod_vec, mask);
        _mm256_sub_epi32(sum, sub)
    }

    /// Lane-wise `(a - b) mod mod_vec`: compute the wrapping difference and
    /// add the modulus back wherever `b > a` (unsigned compare via the same
    /// sign-bit-flip trick as `add_mod_256`).
    #[target_feature(enable = "avx2")]
    pub unsafe fn sub_mod_256(a: __m256i, b: __m256i, mod_vec: __m256i, sign: __m256i) -> __m256i {
        let diff = _mm256_sub_epi32(a, b);
        let a_x = _mm256_xor_si256(a, sign);
        let b_x = _mm256_xor_si256(b, sign);
        let mask = _mm256_cmpgt_epi32(b_x, a_x);
        let add = _mm256_and_si256(mod_vec, mask);
        _mm256_add_epi32(diff, add)
    }

    /// Montgomery product canonicalized into `[0, MOD)`: the lazy result of
    /// `montgomery_mul_256` is folded down by an add-of-zero modulo MOD
    /// (i.e. a conditional subtract of MOD).
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_mul_256_canon(
        a: __m256i,
        b: __m256i,
        r_vec: __m256i,
        mod_vec: __m256i,
        sign: __m256i,
    ) -> __m256i {
        let x = montgomery_mul_256(a, b, r_vec, mod_vec);
        add_mod_256(x, _mm256_setzero_si256(), mod_vec, sign)
    }

    /// Lazy-domain addition: same reduction as `add_mod_256` but against
    /// `mod2_vec` (expected to hold `2 * MOD` per lane), keeping values in
    /// the `[0, 2*MOD)` domain used between Montgomery multiplications.
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_add_256(
        a: __m256i,
        b: __m256i,
        mod2_vec: __m256i,
        sign: __m256i,
    ) -> __m256i {
        let sum = _mm256_add_epi32(a, b);
        let sum_x = _mm256_xor_si256(sum, sign);
        let mod_x = _mm256_xor_si256(mod2_vec, sign);
        let gt = _mm256_cmpgt_epi32(sum_x, mod_x);
        let eq = _mm256_cmpeq_epi32(sum, mod2_vec);
        let mask = _mm256_or_si256(gt, eq);
        let sub = _mm256_and_si256(mod2_vec, mask);
        _mm256_sub_epi32(sum, sub)
    }

    /// Lazy-domain subtraction: wrapping difference plus `2 * MOD` wherever
    /// `b > a` (unsigned).
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_sub_256(
        a: __m256i,
        b: __m256i,
        mod2_vec: __m256i,
        sign: __m256i,
    ) -> __m256i {
        let diff = _mm256_sub_epi32(a, b);
        let a_x = _mm256_xor_si256(a, sign);
        let b_x = _mm256_xor_si256(b, sign);
        let mask = _mm256_cmpgt_epi32(b_x, a_x);
        let add = _mm256_and_si256(mod2_vec, mask);
        _mm256_add_epi32(diff, add)
    }

    /// 512-bit analogue of `my256_mullo_epu32` (low half of 32x32 products).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    unsafe fn my512_mullo_epu32(a: __m512i, b: __m512i) -> __m512i {
        _mm512_mullo_epi32(a, b)
    }

    /// 512-bit analogue of `my256_mulhi_epu32`: widen even/odd lane pairs
    /// with `_mm512_mul_epu32` and re-interleave the high halves.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    unsafe fn my512_mulhi_epu32(a: __m512i, b: __m512i) -> __m512i {
        let a13 = _mm512_shuffle_epi32(a, 0xF5);
        let b13 = _mm512_shuffle_epi32(b, 0xF5);
        let prod02 = _mm512_mul_epu32(a, b);
        let prod13 = _mm512_mul_epu32(a13, b13);
        let t0 = _mm512_unpacklo_epi32(prod02, prod13);
        let t1 = _mm512_unpackhi_epi32(prod02, prod13);
        _mm512_unpackhi_epi64(t0, t1)
    }

    /// Lazy Montgomery product on 16 lanes; same formula as
    /// `montgomery_mul_256`, result not reduced into `[0, MOD)`.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_mul_512(
        a: __m512i,
        b: __m512i,
        r_vec: __m512i,
        mod_vec: __m512i,
    ) -> __m512i {
        let hi = my512_mulhi_epu32(a, b);
        let lo = my512_mullo_epu32(a, b);
        let lo = my512_mullo_epu32(lo, r_vec);
        let lo = my512_mulhi_epu32(lo, mod_vec);
        _mm512_sub_epi32(_mm512_add_epi32(hi, mod_vec), lo)
    }

    /// Lane-wise `(a + b) mod mod_vec`. AVX-512 has native unsigned compares
    /// into mask registers, so the sign-bit trick is unnecessary: the
    /// negated `sum < mod` mask selects lanes for the conditional subtract.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn add_mod_512(a: __m512i, b: __m512i, mod_vec: __m512i) -> __m512i {
        let sum = _mm512_add_epi32(a, b);
        let mask = !_mm512_cmp_epu32_mask(sum, mod_vec, _MM_CMPINT_LT);
        _mm512_mask_sub_epi32(sum, mask, sum, mod_vec)
    }

    /// Lane-wise `(a - b) mod mod_vec`: add the modulus back where `a < b`.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn sub_mod_512(a: __m512i, b: __m512i, mod_vec: __m512i) -> __m512i {
        let diff = _mm512_sub_epi32(a, b);
        let mask = _mm512_cmp_epu32_mask(a, b, _MM_CMPINT_LT);
        _mm512_mask_add_epi32(diff, mask, diff, mod_vec)
    }

    /// Montgomery product canonicalized into `[0, MOD)` (512-bit version).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_mul_512_canon(
        a: __m512i,
        b: __m512i,
        r_vec: __m512i,
        mod_vec: __m512i,
    ) -> __m512i {
        let x = montgomery_mul_512(a, b, r_vec, mod_vec);
        add_mod_512(x, _mm512_setzero_si512(), mod_vec)
    }

    /// Lazy-domain addition against `mod2_vec` (expected `2 * MOD` per lane).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_add_512(a: __m512i, b: __m512i, mod2_vec: __m512i) -> __m512i {
        let sum = _mm512_add_epi32(a, b);
        let mask = !_mm512_cmp_epu32_mask(sum, mod2_vec, _MM_CMPINT_LT);
        _mm512_mask_sub_epi32(sum, mask, sum, mod2_vec)
    }

    /// Lazy-domain subtraction against `mod2_vec` (expected `2 * MOD` per lane).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_sub_512(a: __m512i, b: __m512i, mod2_vec: __m512i) -> __m512i {
        let diff = _mm512_sub_epi32(a, b);
        let mask = _mm512_cmp_epu32_mask(a, b, _MM_CMPINT_LT);
        _mm512_mask_add_epi32(diff, mask, diff, mod2_vec)
    }
}
382
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `M` aliases `MInt998244353` through its full crate path.
    // If `crate::num::montgomery` is this very module, every `M` vs
    // `MInt998244353` comparison below is reflexive — confirm the alias is
    // meant to point at an independent reference implementation.
    use crate::num::montgomery::MInt998244353 as M;
    use crate::tools::Xorshift;

    #[test]
    fn test_mint998244353() {
        let mut rng = Xorshift::default();
        // Number of randomized trials per property.
        const Q: usize = 1000;
        assert_eq!(0, MInt998244353::zero().inner());
        assert_eq!(1, MInt998244353::one().inner());
        // The power ladder collapses one step per reduction:
        // reduce(N3) = N2, reduce(N2) = N1, reduce(N1) = 1.
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N3 as u64),
            Modulo998244353::N2
        );
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N2 as u64),
            Modulo998244353::N1
        );
        assert_eq!(Modulo998244353::reduce(Modulo998244353::N1 as u64), 1);
        // Round-trips, negation, and double-inverse on random residues.
        for _ in 0..Q {
            let x = rng.random(..MInt998244353::get_mod());
            assert_eq!(x, MInt998244353::new(x).inner());
            assert_eq!((-M::new(x)).inner(), (-MInt998244353::new(x)).inner());
            assert_eq!(x, MInt998244353::new(x).inv().inv().inner());
            assert_eq!(M::new(x).inv().inner(), MInt998244353::new(x).inv().inner());
        }

        // Binary operators and pow agree with the reference alias.
        for _ in 0..Q {
            let x = rng.random(..MInt998244353::get_mod());
            let y = rng.random(..MInt998244353::get_mod());
            assert_eq!(
                (M::new(x) + M::new(y)).inner(),
                (MInt998244353::new(x) + MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) - M::new(y)).inner(),
                (MInt998244353::new(x) - MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) * M::new(y)).inner(),
                (MInt998244353::new(x) * MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) / M::new(y)).inner(),
                (MInt998244353::new(x) / MInt998244353::new(y)).inner()
            );
            assert_eq!(
                M::new(x).pow(y as usize).inner(),
                MInt998244353::new(x).pow(y as usize).inner()
            );
        }

        // `from` conversions agree for every supported integer type,
        // including values far outside [0, MOD) and negative casts.
        for _ in 0..Q {
            let x = rng.rand64();
            assert_eq!(
                M::from(x as u32).inner(),
                MInt998244353::from(x as u32).inner()
            );
            assert_eq!(M::from(x).inner(), MInt998244353::from(x).inner());
            assert_eq!(
                M::from(x as usize).inner(),
                MInt998244353::from(x as usize).inner()
            );
            assert_eq!(
                M::from(x as i32).inner(),
                MInt998244353::from(x as i32).inner()
            );
            assert_eq!(
                M::from(x as i64).inner(),
                MInt998244353::from(x as i64).inner()
            );
            assert_eq!(
                M::from(x as isize).inner(),
                MInt998244353::from(x as isize).inner()
            );
        }
    }
}
462}