unsafe fn add_vec_avx2<M>(
a: __m256i,
b: __m256i,
mod_vec: __m256i,
mod2_vec: __m256i,
sign: __m256i,
) -> __m256i
where
    M: Montgomery32NttModulus,
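The body is not shown here, but the parameter set suggests a lazy-reduction add: lanes hold residues in [0, 2*MOD), and a conditional subtraction of 2*MOD keeps them there. Below is a minimal sketch under that assumption (the name add_vec_avx2_sketch is hypothetical, and mod_vec, which this variant does not need, is omitted):

use std::arch::x86_64::*;

// Hypothetical sketch, not the crate's actual implementation.
// Assumes lanes hold residues lazily reduced to [0, 2*MOD).
#[target_feature(enable = "avx2")]
unsafe fn add_vec_avx2_sketch(
    a: __m256i,
    b: __m256i,
    mod2_vec: __m256i, // 2*MOD broadcast to every lane
    sign: __m256i,     // 0x8000_0000 broadcast to every lane
) -> __m256i {
    // t = a + b, lanes now in [0, 4*MOD); no u32 overflow for MOD < 2^30.
    let t = _mm256_add_epi32(a, b);
    // AVX2 has no unsigned 32-bit compare, so bias both operands by
    // 0x8000_0000 and compare as signed: keep is all-ones where t < 2*MOD.
    let keep = _mm256_cmpgt_epi32(
        _mm256_xor_si256(mod2_vec, sign),
        _mm256_xor_si256(t, sign),
    );
    // Keep t where it is already below 2*MOD, otherwise take t - 2*MOD.
    _mm256_blendv_epi8(_mm256_sub_epi32(t, mod2_vec), t, keep)
}

A matching sub_vec_avx2 would compute a - b + 2*MOD and apply the same conditional reduction. Equivalently, _mm256_min_epu32(t, _mm256_sub_epi32(t, mod2_vec)) selects the reduced value without the sign bias; the compare-and-blend form above is shown only because the signature threads a sign constant through.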
Examples found in repository

crates/competitive/src/math/number_theoretic_transform.rs (line 482)
437 pub(super) unsafe fn ntt_avx2<M>(a: &mut [MInt<M>])
438 where
439 M: Montgomery32NttModulus,
440 {
441 let n = a.len();
442 if n <= 1 {
443 return;
444 }
445 let ptr = a.as_mut_ptr() as *mut u32;
446 let a = std::slice::from_raw_parts_mut(ptr, n);
447 let mod_vec = _mm256_set1_epi32(M::MOD as i32);
448 let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
449 let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
450 let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
451 let imag = M::INFO.root[2];
452 let imag_vec = _mm256_set1_epi32(imag as i32);
453
454 let mut v = n / 2;
455 while v > 1 {
456 let half = v >> 1;
457 let mut w1 = M::N1;
458 for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
459 let base = block.as_mut_ptr();
460 let ll = base;
461 let lr = base.add(half);
462 let rl = base.add(v);
463 let rr = base.add(v + half);
464
465 let w2 = M::mod_mul(w1, w1);
466 let w3 = M::mod_mul(w2, w1);
467 let w1v = _mm256_set1_epi32(w1 as i32);
468 let w2v = _mm256_set1_epi32(w2 as i32);
469 let w3v = _mm256_set1_epi32(w3 as i32);
470
471 let mut i = 0;
472 while i + 8 <= half {
473 let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
474 let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
475 let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
476 let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);
477
478 let a1 = mul_vec_avx2::<M>(x1, w1v, r_vec, mod_vec, sign);
479 let a2 = mul_vec_avx2::<M>(x2, w2v, r_vec, mod_vec, sign);
480 let a3 = mul_vec_avx2::<M>(x3, w3v, r_vec, mod_vec, sign);
481
482 let a0pa2 = add_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
483 let a0na2 = sub_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
484 let a1pa3 = add_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
485 let a1na3 = sub_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
486 let a1na3imag = mul_vec_avx2::<M>(a1na3, imag_vec, r_vec, mod_vec, sign);
487
488 let y0 = add_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
489 let y1 = sub_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
490 let y2 = add_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
491 let y3 = sub_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
492
493 _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
494 _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
495 _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
496 _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
497 i += 8;
498 }
499 while i < half {
500 let a0 = *ll.add(i);
501 let a1 = M::mod_mul(*lr.add(i), w1);
502 let a2 = M::mod_mul(*rl.add(i), w2);
503 let a3 = M::mod_mul(*rr.add(i), w3);
504 let a0pa2 = M::mod_add(a0, a2);
505 let a0na2 = M::mod_sub(a0, a2);
506 let a1pa3 = M::mod_add(a1, a3);
507 let a1na3 = M::mod_sub(a1, a3);
508 let a1na3imag = M::mod_mul(a1na3, imag);
509 *ll.add(i) = M::mod_add(a0pa2, a1pa3);
510 *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
511 *rl.add(i) = M::mod_add(a0na2, a1na3imag);
512 *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
513 i += 1;
514 }
515 w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
516 }
517 v >>= 2;
518 }
519 if v == 1 {
520 let mut w1 = M::N1;
521 for (s, block) in a.chunks_exact_mut(2).enumerate() {
522 let a0 = *block.get_unchecked(0);
523 let a1 = M::mod_mul(*block.get_unchecked(1), w1);
524 *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
525 *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
526 w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
527 }
528 }
529 normalize_avx2::<M>(a);
530 }
531
532 #[target_feature(enable = "avx2")]
533 pub(super) unsafe fn intt_avx2<M>(a: &mut [MInt<M>])
534 where
535 M: Montgomery32NttModulus,
536 {
537 let n = a.len();
538 if n <= 1 {
539 return;
540 }
541 let ptr = a.as_mut_ptr() as *mut u32;
542 let a = std::slice::from_raw_parts_mut(ptr, n);
543 let mod_vec = _mm256_set1_epi32(M::MOD as i32);
544 let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
545 let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
546 let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
547 let iimag = M::INFO.inv_root[2];
548 let iimag_vec = _mm256_set1_epi32(iimag as i32);
549
550 let mut v = 1;
551 if n.trailing_zeros() & 1 == 1 {
552 let mut w1 = M::N1;
553 for (s, block) in a.chunks_exact_mut(2).enumerate() {
554 let a0 = *block.get_unchecked(0);
555 let a1 = *block.get_unchecked(1);
556 *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
557 *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
558 w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
559 }
560 v <<= 1;
561 }
562 while v < n {
563 let mut w1 = M::N1;
564 for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
565 let base = block.as_mut_ptr();
566 let ll = base;
567 let lr = base.add(v);
568 let rl = base.add(v << 1);
569 let rr = base.add(v * 3);
570
571 let w2 = M::mod_mul(w1, w1);
572 let w3 = M::mod_mul(w2, w1);
573 let w1v = _mm256_set1_epi32(w1 as i32);
574 let w2v = _mm256_set1_epi32(w2 as i32);
575 let w3v = _mm256_set1_epi32(w3 as i32);
576
577 let mut i = 0;
578 while i + 8 <= v {
579 let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
580 let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
581 let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
582 let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);
583
584 let a0pa1 = add_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
585 let a0na1 = sub_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
586 let a2pa3 = add_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
587 let a2na3 = sub_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
588 let a2na3iimag = mul_vec_avx2::<M>(a2na3, iimag_vec, r_vec, mod_vec, sign);
589
590 let y0 = add_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
591 let y1 = add_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
592 let y2 = sub_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
593 let y3 = sub_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
594
595 let y1 = mul_vec_avx2::<M>(y1, w1v, r_vec, mod_vec, sign);
596 let y2 = mul_vec_avx2::<M>(y2, w2v, r_vec, mod_vec, sign);
597 let y3 = mul_vec_avx2::<M>(y3, w3v, r_vec, mod_vec, sign);
598
599 _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
600 _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
601 _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
602 _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
603 i += 8;
604 }
605 while i < v {
606 let a0 = *ll.add(i);
607 let a1 = *lr.add(i);
608 let a2 = *rl.add(i);
609 let a3 = *rr.add(i);
610 let a0pa1 = M::mod_add(a0, a1);
611 let a0na1 = M::mod_sub(a0, a1);
612 let a2pa3 = M::mod_add(a2, a3);
613 let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
614 *ll.add(i) = M::mod_add(a0pa1, a2pa3);
615 *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
616 *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
617 *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
618 i += 1;
619 }
620 w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
621 }
622 v <<= 2;
623 }
624 normalize_avx2::<M>(a);
625 }
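The example above leans just as heavily on mul_vec_avx2. The following is a minimal sketch of an 8-lane Montgomery product with the same parameter shape, assuming M::R is MOD^{-1} mod 2^32 (so r_vec broadcasts -MOD^{-1} mod 2^32) and inputs lazily reduced to [0, 2*MOD); the real helper also receives the sign constant, which this variant does not need:

use std::arch::x86_64::*;

// Hypothetical sketch, not the crate's actual implementation.
// Per lane: t = a*b (64-bit), q = low32(t) * (-MOD^{-1}) mod 2^32,
// result = (t + q*MOD) >> 32, which lies in [0, 2*MOD) for lazy inputs.
#[target_feature(enable = "avx2")]
unsafe fn mul_vec_avx2_sketch(
    a: __m256i,
    b: __m256i,
    r_vec: __m256i,   // -MOD^{-1} mod 2^32 broadcast to every lane
    mod_vec: __m256i, // MOD broadcast to every lane
) -> __m256i {
    // 64-bit products of the even dwords (lanes 0, 2, 4, 6) ...
    let prod02 = _mm256_mul_epu32(a, b);
    // ... and of the odd dwords, copied into even positions first.
    let prod13 = _mm256_mul_epu32(
        _mm256_shuffle_epi32::<0xF5>(a),
        _mm256_shuffle_epi32::<0xF5>(b),
    );
    // q = low32(t) * (-MOD^{-1}) mod 2^32; only the even dwords matter,
    // because _mm256_mul_epu32 below reads only those.
    let q02 = _mm256_mullo_epi32(prod02, r_vec);
    let q13 = _mm256_mullo_epi32(prod13, r_vec);
    // t + q*MOD is divisible by 2^32, so its high half is the result.
    let sum02 = _mm256_add_epi64(prod02, _mm256_mul_epu32(q02, mod_vec));
    let sum13 = _mm256_add_epi64(prod13, _mm256_mul_epu32(q13, mod_vec));
    // Shift sum02's results down into the even dwords; sum13's results
    // already sit in the odd dwords (the high half of each 64-bit lane).
    _mm256_blend_epi32::<0b10101010>(_mm256_srli_epi64::<32>(sum02), sum13)
}

With inputs below 2*MOD and MOD < 2^30, the high half (t + q*MOD) >> 32 is already below 2*MOD, so no extra compare is needed. The normalize_avx2 pass at the end of ntt_avx2 and intt_avx2 would then map each lane from [0, 2*MOD) back to [0, MOD), for example with _mm256_min_epu32(x, _mm256_sub_epi32(x, mod_vec)).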