unsafe fn sub_vec_avx512<M>(
a: __m512i,
b: __m512i,
mod_vec: __m512i,
mod2_vec: __m512i,
) -> __m512i
where
    M: Montgomery32NttModulus,

Examples found in repository:
crates/competitive/src/math/number_theoretic_transform.rs (line 672)
/// Forward number-theoretic transform over a Montgomery-form 32-bit modulus,
/// vectorized with AVX-512 (16 u32 lanes per iteration, scalar tail fallback).
///
/// Uses radix-4 decimation-in-frequency passes, followed by one radix-2 pass
/// when log2(n) is odd (i.e. when the radix-4 loop bottoms out at v == 1).
///
/// # Safety
/// Caller must ensure the AVX-512 target features are available and that `a`
/// is a valid `MInt<M>` slice (it is reinterpreted as raw `u32` words).
pub(super) unsafe fn ntt_avx512<M>(a: &mut [MInt<M>])
where
    M: Montgomery32NttModulus,
{
    let n = a.len();
    if n <= 1 {
        return;
    }
    // Reinterpret the MInt<M> slice as raw u32 words for SIMD loads/stores.
    let ptr = a.as_mut_ptr() as *mut u32;
    let a = std::slice::from_raw_parts_mut(ptr, n);
    let mod_vec = _mm512_set1_epi32(M::MOD as i32);
    // 2*MOD broadcast — add/sub helpers appear to keep lanes in [0, 2*MOD)
    // between stages; final reduction happens in normalize_avx512.
    let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
    // -R broadcast, consumed by mul_vec_avx512 for Montgomery reduction
    // (exact role defined by that helper — not visible here).
    let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
    // root[2]: presumably the primitive 4th root of unity ("imaginary unit")
    // used by the radix-4 butterfly — confirm against M::INFO's definition.
    let imag = M::INFO.root[2];
    let imag_vec = _mm512_set1_epi32(imag as i32);

    // v is the current half-block length; each radix-4 pass shrinks it by 4.
    let mut v = n / 2;
    while v > 1 {
        let half = v >> 1;
        // Running twiddle factor; N1 is presumably 1 in Montgomery form.
        let mut w1 = M::N1;
        for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
            // Four quarter-section base pointers of the 2v-length block.
            let base = block.as_mut_ptr();
            let ll = base;
            let lr = base.add(half);
            let rl = base.add(v);
            let rr = base.add(v + half);
            // Powers of the block twiddle: w2 = w1^2, w3 = w1^3.
            let w2 = M::mod_mul(w1, w1);
            let w3 = M::mod_mul(w2, w1);
            let w1v = _mm512_set1_epi32(w1 as i32);
            let w2v = _mm512_set1_epi32(w2 as i32);
            let w3v = _mm512_set1_epi32(w3 as i32);

            // Vectorized butterflies: 16 lanes at a time.
            let mut i = 0;
            while i + 16 <= half {
                let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
                let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
                let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
                let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);

                // Pre-multiply three inputs by the twiddle powers.
                let a1 = mul_vec_avx512::<M>(x1, w1v, r_vec, mod_vec);
                let a2 = mul_vec_avx512::<M>(x2, w2v, r_vec, mod_vec);
                let a3 = mul_vec_avx512::<M>(x3, w3v, r_vec, mod_vec);

                // Radix-4 butterfly: combine sums/differences, rotating the
                // (a1 - a3) branch by the 4th root of unity.
                let a0pa2 = add_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
                let a0na2 = sub_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
                let a1pa3 = add_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
                let a1na3 = sub_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
                let a1na3imag = mul_vec_avx512::<M>(a1na3, imag_vec, r_vec, mod_vec);

                let y0 = add_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
                let y1 = sub_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
                let y2 = add_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);
                let y3 = sub_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);

                _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
                _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
                _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
                _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
                i += 16;
            }
            // Scalar tail: identical butterfly for the remaining < 16 lanes.
            while i < half {
                let a0 = *ll.add(i);
                let a1 = M::mod_mul(*lr.add(i), w1);
                let a2 = M::mod_mul(*rl.add(i), w2);
                let a3 = M::mod_mul(*rr.add(i), w3);
                let a0pa2 = M::mod_add(a0, a2);
                let a0na2 = M::mod_sub(a0, a2);
                let a1pa3 = M::mod_add(a1, a3);
                let a1na3 = M::mod_sub(a1, a3);
                let a1na3imag = M::mod_mul(a1na3, imag);
                *ll.add(i) = M::mod_add(a0pa2, a1pa3);
                *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
                *rl.add(i) = M::mod_add(a0na2, a1na3imag);
                *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
                i += 1;
            }
            // Advance the twiddle using the precomputed rate table, indexed by
            // the number of trailing one-bits of the block index s
            // (AC-Library-style incremental twiddle update).
            w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
        }
        v >>= 2;
    }
    // One final radix-2 pass when log2(n) is odd.
    if v == 1 {
        let mut w1 = M::N1;
        for (s, block) in a.chunks_exact_mut(2).enumerate() {
            let a0 = *block.get_unchecked(0);
            let a1 = M::mod_mul(*block.get_unchecked(1), w1);
            *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
            *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
            w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
        }
    }
    // Reduce all lanes into canonical range [0, MOD).
    normalize_avx512::<M>(a);
}
720
#[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
/// Inverse number-theoretic transform, mirroring `ntt_avx512`: an optional
/// leading radix-2 pass (when log2(n) is odd) followed by radix-4 passes with
/// the block size growing by 4 each time. Uses the inverse roots/rate tables.
///
/// # Safety
/// Caller must ensure the AVX-512 target features are available and that `a`
/// is a valid `MInt<M>` slice (it is reinterpreted as raw `u32` words).
pub(super) unsafe fn intt_avx512<M>(a: &mut [MInt<M>])
where
    M: Montgomery32NttModulus,
{
    let n = a.len();
    if n <= 1 {
        return;
    }
    // Reinterpret the MInt<M> slice as raw u32 words for SIMD loads/stores.
    let ptr = a.as_mut_ptr() as *mut u32;
    let a = std::slice::from_raw_parts_mut(ptr, n);
    let mod_vec = _mm512_set1_epi32(M::MOD as i32);
    // 2*MOD broadcast — lanes stay in [0, 2*MOD) until normalize_avx512.
    let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
    // -R broadcast for Montgomery reduction inside mul_vec_avx512.
    let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
    // inv_root[2]: presumably the inverse of the primitive 4th root of unity.
    let iimag = M::INFO.inv_root[2];
    let iimag_vec = _mm512_set1_epi32(iimag as i32);

    let mut v = 1;
    // When log2(n) is odd, undo the forward transform's trailing radix-2 pass
    // first (scalar; blocks of length 2).
    if n.trailing_zeros() & 1 == 1 {
        let mut w1 = M::N1;
        for (s, block) in a.chunks_exact_mut(2).enumerate() {
            let a0 = *block.get_unchecked(0);
            let a1 = *block.get_unchecked(1);
            *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
            // Inverse butterfly: twiddle multiplies AFTER the difference.
            *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
            w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
        }
        v <<= 1;
    }
    // Radix-4 passes: v is the quarter-block length and grows by 4 each pass.
    while v < n {
        let mut w1 = M::N1;
        for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
            // Four quarter-section base pointers of the 4v-length block.
            let base = block.as_mut_ptr();
            let ll = base;
            let lr = base.add(v);
            let rl = base.add(v << 1);
            let rr = base.add(v * 3);
            // Powers of the block twiddle: w2 = w1^2, w3 = w1^3.
            let w2 = M::mod_mul(w1, w1);
            let w3 = M::mod_mul(w2, w1);
            let w1v = _mm512_set1_epi32(w1 as i32);
            let w2v = _mm512_set1_epi32(w2 as i32);
            let w3v = _mm512_set1_epi32(w3 as i32);

            // Vectorized inverse butterflies: 16 lanes at a time.
            let mut i = 0;
            while i + 16 <= v {
                let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
                let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
                let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
                let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);

                let a0pa1 = add_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
                let a0na1 = sub_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
                let a2pa3 = add_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
                let a2na3 = sub_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
                // Rotate the (a2 - a3) branch by the inverse 4th root.
                let a2na3iimag = mul_vec_avx512::<M>(a2na3, iimag_vec, r_vec, mod_vec);

                let y0 = add_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
                let y1 = add_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);
                let y2 = sub_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
                let y3 = sub_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);

                // Twiddle multiplies come after the butterfly on the inverse
                // side (y0 needs none).
                let y1 = mul_vec_avx512::<M>(y1, w1v, r_vec, mod_vec);
                let y2 = mul_vec_avx512::<M>(y2, w2v, r_vec, mod_vec);
                let y3 = mul_vec_avx512::<M>(y3, w3v, r_vec, mod_vec);

                _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
                _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
                _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
                _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
                i += 16;
            }
            // Scalar tail: identical inverse butterfly for the remaining lanes.
            while i < v {
                let a0 = *ll.add(i);
                let a1 = *lr.add(i);
                let a2 = *rl.add(i);
                let a3 = *rr.add(i);
                let a0pa1 = M::mod_add(a0, a1);
                let a0na1 = M::mod_sub(a0, a1);
                let a2pa3 = M::mod_add(a2, a3);
                let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
                *ll.add(i) = M::mod_add(a0pa1, a2pa3);
                *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
                *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
                *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
                i += 1;
            }
            // Advance the twiddle via the inverse rate table, indexed by the
            // trailing one-bits of the block index s.
            w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
        }
        v <<= 2;
    }
    // Reduce all lanes into canonical range [0, MOD).
    // NOTE(review): no 1/n scaling is visible here — presumably applied by the
    // caller or folded into the tables; confirm against the convolution driver.
    normalize_avx512::<M>(a);
}