Skip to main content

normalize_avx512

Function normalize_avx512 

Source
unsafe fn normalize_avx512<M>(a: &mut [u32])
Examples found in repository:
crates/competitive/src/math/number_theoretic_transform.rs (line 718)
/// Forward number-theoretic transform of `a`, computed in place.
///
/// Splits the slice into blocks and applies 4-point butterflies, 16 u32
/// lanes at a time via AVX-512 intrinsics, with a scalar tail loop for
/// any remainder, plus a final 2-point pass when a single halving step
/// is left. Finishes by calling `normalize_avx512` on the whole slice.
///
/// NOTE(review): the sibling `intt_avx512` below carries a
/// `#[target_feature(enable = "avx512f,...")]` attribute; presumably this
/// function has the same attribute just above the shown excerpt — confirm.
/// Callers must guarantee those CPU features are available.
628    pub(super) unsafe fn ntt_avx512<M>(a: &mut [MInt<M>])
629    where
630        M: Montgomery32NttModulus,
631    {
632        let n = a.len();
           // Length 0 or 1: the transform is the identity.
633        if n <= 1 {
634            return;
635        }
           // Reinterpret the MInt<M> slice as raw u32 words so the AVX-512
           // integer intrinsics can load/store it directly (assumes MInt<M>
           // is a transparent wrapper over u32 — TODO confirm its layout).
636        let ptr = a.as_mut_ptr() as *mut u32;
637        let a = std::slice::from_raw_parts_mut(ptr, n);
           // Broadcast per-modulus constants to all 16 lanes: MOD, 2*MOD,
           // and -R. The -R value feeds mul_vec_avx512; together with
           // M::mod_mul this suggests Montgomery-form arithmetic with lazy
           // reduction into [0, 2*MOD) — confirm against those helpers.
638        let mod_vec = _mm512_set1_epi32(M::MOD as i32);
639        let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
640        let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
           // INFO.root[2], named `imag`: plays the role of the imaginary
           // unit inside the 4-point butterfly below.
641        let imag = M::INFO.root[2];
642        let imag_vec = _mm512_set1_epi32(imag as i32);
643
           // Main passes: each iteration walks blocks of length 2*v, splits
           // every block into four quarters of length `half`, applies a
           // 4-point butterfly, then shrinks v by a factor of 4.
644        let mut v = n / 2;
645        while v > 1 {
646            let half = v >> 1;
               // Running twiddle for this pass; starts at M::N1
               // (presumably the multiplicative identity in Montgomery
               // form — confirm).
647            let mut w1 = M::N1;
648            for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
                   // Base pointers to the four quarters of this block:
                   // [0, half), [half, v), [v, v+half), [v+half, 2v).
649                let base = block.as_mut_ptr();
650                let ll = base;
651                let lr = base.add(half);
652                let rl = base.add(v);
653                let rr = base.add(v + half);
                   // Powers of the block twiddle, broadcast for SIMD use.
654                let w2 = M::mod_mul(w1, w1);
655                let w3 = M::mod_mul(w2, w1);
656                let w1v = _mm512_set1_epi32(w1 as i32);
657                let w2v = _mm512_set1_epi32(w2 as i32);
658                let w3v = _mm512_set1_epi32(w3 as i32);
659
                   // Vectorized butterfly: 16 lanes per iteration.
660                let mut i = 0;
661                while i + 16 <= half {
662                    let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
663                    let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
664                    let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
665                    let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);
666
                       // Twiddle the three non-trivial inputs first, then
                       // combine: a 4-point transform expressed as two
                       // add/sub layers with `imag` on the cross term.
667                    let a1 = mul_vec_avx512::<M>(x1, w1v, r_vec, mod_vec);
668                    let a2 = mul_vec_avx512::<M>(x2, w2v, r_vec, mod_vec);
669                    let a3 = mul_vec_avx512::<M>(x3, w3v, r_vec, mod_vec);
670
671                    let a0pa2 = add_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
672                    let a0na2 = sub_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
673                    let a1pa3 = add_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
674                    let a1na3 = sub_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
675                    let a1na3imag = mul_vec_avx512::<M>(a1na3, imag_vec, r_vec, mod_vec);
676
677                    let y0 = add_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
678                    let y1 = sub_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
679                    let y2 = add_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);
680                    let y3 = sub_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);
681
682                    _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
683                    _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
684                    _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
685                    _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
686                    i += 16;
687                }
                   // Scalar tail: same butterfly for the lanes left over
                   // when `half` is not a multiple of 16.
688                while i < half {
689                    let a0 = *ll.add(i);
690                    let a1 = M::mod_mul(*lr.add(i), w1);
691                    let a2 = M::mod_mul(*rl.add(i), w2);
692                    let a3 = M::mod_mul(*rr.add(i), w3);
693                    let a0pa2 = M::mod_add(a0, a2);
694                    let a0na2 = M::mod_sub(a0, a2);
695                    let a1pa3 = M::mod_add(a1, a3);
696                    let a1na3 = M::mod_sub(a1, a3);
697                    let a1na3imag = M::mod_mul(a1na3, imag);
698                    *ll.add(i) = M::mod_add(a0pa2, a1pa3);
699                    *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
700                    *rl.add(i) = M::mod_add(a0na2, a1na3imag);
701                    *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
702                    i += 1;
703                }
                   // Incremental twiddle update indexed by the trailing-ones
                   // count of the block index `s` (AC-Library-style rate
                   // table — presumably INFO.rate3 is precomputed for this;
                   // confirm against its definition).
704                w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
705            }
706            v >>= 2;
707        }
           // One 2-point pass remains when the quartering loop lands on
           // v == 1 instead of v == 0 (depends on the parity of log2(n)).
708        if v == 1 {
709            let mut w1 = M::N1;
710            for (s, block) in a.chunks_exact_mut(2).enumerate() {
711                let a0 = *block.get_unchecked(0);
712                let a1 = M::mod_mul(*block.get_unchecked(1), w1);
713                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
714                *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
715                w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
716            }
717        }
           // Final per-element normalization pass over the raw u32 view
           // (see normalize_avx512 for what "normalize" means here).
718        normalize_avx512::<M>(a);
719    }
720
/// Inverse number-theoretic transform of `a`, computed in place.
///
/// Mirrors `ntt_avx512` run backwards: an optional 2-point pass first
/// (taken when `n.trailing_zeros()` is odd), then 4-point butterflies
/// with the span `v` growing by a factor of 4 each pass, using the
/// inverse twiddle tables (`inv_root`, `inv_rate2`, `inv_rate3`) and
/// applying the twiddles to the butterfly *outputs* rather than inputs.
/// Ends by calling `normalize_avx512` on the whole slice.
///
/// NOTE(review): no division by `n` is visible in this excerpt;
/// presumably the 1/n scaling happens in `normalize_avx512` or at the
/// caller — confirm before relying on the output scale.
///
/// # Safety
/// Caller must ensure the AVX-512 features listed in the attribute are
/// available on the executing CPU.
721    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
722    pub(super) unsafe fn intt_avx512<M>(a: &mut [MInt<M>])
723    where
724        M: Montgomery32NttModulus,
725    {
726        let n = a.len();
           // Length 0 or 1: the inverse transform is the identity.
727        if n <= 1 {
728            return;
729        }
           // Raw u32 view of the MInt<M> slice for direct SIMD load/store
           // (assumes MInt<M> is a transparent u32 wrapper — TODO confirm).
730        let ptr = a.as_mut_ptr() as *mut u32;
731        let a = std::slice::from_raw_parts_mut(ptr, n);
           // Lane-broadcast constants: MOD, 2*MOD, -R (fed to
           // mul_vec_avx512; Montgomery-style reduction presumed — confirm
           // against that helper).
732        let mod_vec = _mm512_set1_epi32(M::MOD as i32);
733        let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
734        let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
           // INFO.inv_root[2], named `iimag`: inverse of the `imag` twiddle
           // used by the forward transform's 4-point butterfly.
735        let iimag = M::INFO.inv_root[2];
736        let iimag_vec = _mm512_set1_epi32(iimag as i32);
737
           // When log2(n) is odd, one 2-point pass is needed up front so the
           // remaining work divides evenly into 4-point passes.
738        let mut v = 1;
739        if n.trailing_zeros() & 1 == 1 {
740            let mut w1 = M::N1;
741            for (s, block) in a.chunks_exact_mut(2).enumerate() {
742                let a0 = *block.get_unchecked(0);
743                let a1 = *block.get_unchecked(1);
                   // Inverse butterfly: twiddle multiplies the difference
                   // output (contrast with the forward pass, which twiddles
                   // the input).
744                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
745                *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
746                w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
747            }
748            v <<= 1;
749        }
           // Main passes: blocks of length 4*v, four quarters of length v,
           // with v growing by a factor of 4 until it reaches n.
750        while v < n {
751            let mut w1 = M::N1;
752            for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
                   // Base pointers to the four length-v quarters.
753                let base = block.as_mut_ptr();
754                let ll = base;
755                let lr = base.add(v);
756                let rl = base.add(v << 1);
757                let rr = base.add(v * 3);
                   // Powers of the block twiddle, broadcast for SIMD use.
758                let w2 = M::mod_mul(w1, w1);
759                let w3 = M::mod_mul(w2, w1);
760                let w1v = _mm512_set1_epi32(w1 as i32);
761                let w2v = _mm512_set1_epi32(w2 as i32);
762                let w3v = _mm512_set1_epi32(w3 as i32);
763
                   // Vectorized inverse butterfly: 16 lanes per iteration.
764                let mut i = 0;
765                while i + 16 <= v {
766                    let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
767                    let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
768                    let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
769                    let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);
770
                       // Two add/sub layers with `iimag` on the cross term;
                       // the inverse of the forward 4-point transform.
771                    let a0pa1 = add_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
772                    let a0na1 = sub_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
773                    let a2pa3 = add_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
774                    let a2na3 = sub_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
775                    let a2na3iimag = mul_vec_avx512::<M>(a2na3, iimag_vec, r_vec, mod_vec);
776
777                    let y0 = add_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
778                    let y1 = add_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);
779                    let y2 = sub_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
780                    let y3 = sub_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);
781
                       // Twiddles are applied to the outputs here (inverse
                       // direction), not to the inputs as in ntt_avx512.
782                    let y1 = mul_vec_avx512::<M>(y1, w1v, r_vec, mod_vec);
783                    let y2 = mul_vec_avx512::<M>(y2, w2v, r_vec, mod_vec);
784                    let y3 = mul_vec_avx512::<M>(y3, w3v, r_vec, mod_vec);
785
786                    _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
787                    _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
788                    _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
789                    _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
790                    i += 16;
791                }
                   // Scalar tail for the lanes left over when v is not a
                   // multiple of 16.
792                while i < v {
793                    let a0 = *ll.add(i);
794                    let a1 = *lr.add(i);
795                    let a2 = *rl.add(i);
796                    let a3 = *rr.add(i);
797                    let a0pa1 = M::mod_add(a0, a1);
798                    let a0na1 = M::mod_sub(a0, a1);
799                    let a2pa3 = M::mod_add(a2, a3);
800                    let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
801                    *ll.add(i) = M::mod_add(a0pa1, a2pa3);
802                    *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
803                    *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
804                    *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
805                    i += 1;
806                }
                   // Incremental twiddle update via the trailing-ones count
                   // of the block index (inverse rate table, mirroring the
                   // forward pass's INFO.rate3 update).
807                w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
808            }
809            v <<= 2;
810        }
           // Final per-element normalization pass over the raw u32 view.
811        normalize_avx512::<M>(a);
812    }