Skip to main content

add_vec_avx2

Function add_vec_avx2 

Source
unsafe fn add_vec_avx2<M>(
    a: __m256i,
    b: __m256i,
    mod_vec: __m256i,
    mod2_vec: __m256i,
    sign: __m256i,
) -> __m256i
Examples found in repository:
crates/competitive/src/math/number_theoretic_transform.rs (line 482)
// In-place forward number-theoretic transform (NTT) over a slice of
// Montgomery-form 32-bit modular integers, vectorized with AVX2
// (eight u32 lanes per __m256i).
//
// NOTE(review): the #[target_feature(enable = "avx2")] attribute for this
// function is presumably on the line just above this excerpt (its inverse
// counterpart below carries one) — confirm.
//
// Safety: the caller must guarantee AVX2 is available. The pointer cast
// below also assumes MInt<M> is layout-compatible with u32 (i.e.
// repr(transparent) over a single u32) — TODO confirm against MInt.
437    pub(super) unsafe fn ntt_avx2<M>(a: &mut [MInt<M>])
438    where
439        M: Montgomery32NttModulus,
440    {
441        let n = a.len();
           // Length-0/1 transforms are the identity.
442        if n <= 1 {
443            return;
444        }
           // Reinterpret the MInt slice as raw u32 words for SIMD access.
445        let ptr = a.as_mut_ptr() as *mut u32;
446        let a = std::slice::from_raw_parts_mut(ptr, n);
           // Broadcast per-modulus constants into all 8 lanes:
           //   mod_vec  = MOD, mod2_vec = 2*MOD (intermediate values are
           //            presumably kept in [0, 2*MOD) — verify in the
           //            add/sub helpers),
           //   r_vec    = -R, presumably the Montgomery reduction constant
           //            consumed only by mul_vec_avx2 — confirm,
           //   sign     = 0x8000_0000, presumably used to emulate unsigned
           //            compares with signed SIMD instructions — confirm.
447        let mod_vec = _mm256_set1_epi32(M::MOD as i32);
448        let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
449        let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
450        let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
           // root[2]: from its use in the butterfly below this is the
           // "imaginary unit" (a primitive 4th root of unity) — assumed.
451        let imag = M::INFO.root[2];
452        let imag_vec = _mm256_set1_epi32(imag as i32);
453
           // Radix-4 layers, largest stride first. Each block of length 2*v
           // is split into four quarters of length `half` at offsets
           // 0, half, v, v+half.
454        let mut v = n / 2;
455        while v > 1 {
456            let half = v >> 1;
               // Twiddle starts at one (Montgomery form N1) and is advanced
               // per block via the precomputed rate3 table below.
457            let mut w1 = M::N1;
458            for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
459                let base = block.as_mut_ptr();
460                let ll = base;
461                let lr = base.add(half);
462                let rl = base.add(v);
463                let rr = base.add(v + half);
464
                   // Per-block twiddles: w1, w1^2, w1^3.
465                let w2 = M::mod_mul(w1, w1);
466                let w3 = M::mod_mul(w2, w1);
467                let w1v = _mm256_set1_epi32(w1 as i32);
468                let w2v = _mm256_set1_epi32(w2 as i32);
469                let w3v = _mm256_set1_epi32(w3 as i32);
470
                   // Vectorized 4-point butterflies, 8 lanes at a time.
471                let mut i = 0;
472                while i + 8 <= half {
473                    let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
474                    let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
475                    let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
476                    let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);
477
                       // Apply twiddles to the three non-trivial inputs.
478                    let a1 = mul_vec_avx2::<M>(x1, w1v, r_vec, mod_vec, sign);
479                    let a2 = mul_vec_avx2::<M>(x2, w2v, r_vec, mod_vec, sign);
480                    let a3 = mul_vec_avx2::<M>(x3, w3v, r_vec, mod_vec, sign);
481
                       // Radix-4 combine (mirrors the scalar tail below):
                       //   y0 = (a0+a2) + (a1+a3)
                       //   y1 = (a0+a2) - (a1+a3)
                       //   y2 = (a0-a2) + imag*(a1-a3)
                       //   y3 = (a0-a2) - imag*(a1-a3)
482                    let a0pa2 = add_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
483                    let a0na2 = sub_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
484                    let a1pa3 = add_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
485                    let a1na3 = sub_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
486                    let a1na3imag = mul_vec_avx2::<M>(a1na3, imag_vec, r_vec, mod_vec, sign);
487
488                    let y0 = add_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
489                    let y1 = sub_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
490                    let y2 = add_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
491                    let y3 = sub_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
492
493                    _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
494                    _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
495                    _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
496                    _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
497                    i += 8;
498                }
                   // Scalar tail for the remaining (half % 8) butterflies —
                   // same arithmetic as the SIMD path above.
499                while i < half {
500                    let a0 = *ll.add(i);
501                    let a1 = M::mod_mul(*lr.add(i), w1);
502                    let a2 = M::mod_mul(*rl.add(i), w2);
503                    let a3 = M::mod_mul(*rr.add(i), w3);
504                    let a0pa2 = M::mod_add(a0, a2);
505                    let a0na2 = M::mod_sub(a0, a2);
506                    let a1pa3 = M::mod_add(a1, a3);
507                    let a1na3 = M::mod_sub(a1, a3);
508                    let a1na3imag = M::mod_mul(a1na3, imag);
509                    *ll.add(i) = M::mod_add(a0pa2, a1pa3);
510                    *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
511                    *rl.add(i) = M::mod_add(a0na2, a1na3imag);
512                    *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
513                    i += 1;
514                }
                   // Advance the twiddle via the rate3 table, indexed by the
                   // trailing-ones count of the block index (the standard
                   // incremental twiddle-update scheme).
515                w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
516            }
517            v >>= 2;
518        }
           // v starts at n/2 and is divided by 4 each pass, so v lands on 1
           // exactly when log2(n) is odd; finish with one radix-2 layer of
           // 2-point butterflies over adjacent pairs.
519        if v == 1 {
520            let mut w1 = M::N1;
521            for (s, block) in a.chunks_exact_mut(2).enumerate() {
522                let a0 = *block.get_unchecked(0);
523                let a1 = M::mod_mul(*block.get_unchecked(1), w1);
524                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
525                *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
526                w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
527            }
528        }
           // Presumably reduces every element back into canonical range
           // [0, MOD) — confirm against normalize_avx2 (not shown here).
529        normalize_avx2::<M>(a);
530    }
531
// In-place inverse number-theoretic transform, the mirror of ntt_avx2:
// Montgomery-form 32-bit modular integers processed with AVX2
// (eight u32 lanes per __m256i), using the inverse root / rate tables.
//
// NOTE(review): no 1/n scaling is visible in this function — presumably
// the caller applies it; confirm.
//
// Safety: caller must guarantee AVX2 is available; the pointer cast below
// assumes MInt<M> is layout-compatible with u32 — TODO confirm.
532    #[target_feature(enable = "avx2")]
533    pub(super) unsafe fn intt_avx2<M>(a: &mut [MInt<M>])
534    where
535        M: Montgomery32NttModulus,
536    {
537        let n = a.len();
           // Length-0/1 transforms are the identity.
538        if n <= 1 {
539            return;
540        }
           // Reinterpret the MInt slice as raw u32 words for SIMD access.
541        let ptr = a.as_mut_ptr() as *mut u32;
542        let a = std::slice::from_raw_parts_mut(ptr, n);
           // Broadcast constants (same roles as in the forward transform):
           // MOD, 2*MOD, -R (presumably the Montgomery reduction constant),
           // and the 0x8000_0000 sign-bit mask — presumably for emulating
           // unsigned SIMD compares; confirm in the helpers.
543        let mod_vec = _mm256_set1_epi32(M::MOD as i32);
544        let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
545        let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
546        let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
           // inv_root[2]: inverse of the forward transform's 4th root of
           // unity ("imaginary unit") — assumed from its use below.
547        let iimag = M::INFO.inv_root[2];
548        let iimag_vec = _mm256_set1_epi32(iimag as i32);
549
           // Layers run smallest stride first (opposite order to ntt_avx2).
550        let mut v = 1;
           // When log2(n) is odd, undo the forward pass's trailing radix-2
           // layer first: note the twiddle multiplies AFTER the subtract,
           // inverting the forward butterfly.
551        if n.trailing_zeros() & 1 == 1 {
552            let mut w1 = M::N1;
553            for (s, block) in a.chunks_exact_mut(2).enumerate() {
554                let a0 = *block.get_unchecked(0);
555                let a1 = *block.get_unchecked(1);
556                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
557                *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
558                w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
559            }
560            v <<= 1;
561        }
           // Radix-4 layers with growing stride: each block of length 4*v is
           // split into quarters at offsets 0, v, 2v, 3v.
562        while v < n {
563            let mut w1 = M::N1;
564            for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
565                let base = block.as_mut_ptr();
566                let ll = base;
567                let lr = base.add(v);
568                let rl = base.add(v << 1);
569                let rr = base.add(v * 3);
570
                   // Per-block twiddles: w1, w1^2, w1^3.
571                let w2 = M::mod_mul(w1, w1);
572                let w3 = M::mod_mul(w2, w1);
573                let w1v = _mm256_set1_epi32(w1 as i32);
574                let w2v = _mm256_set1_epi32(w2 as i32);
575                let w3v = _mm256_set1_epi32(w3 as i32);
576
                   // Vectorized inverse 4-point butterflies, 8 lanes at a
                   // time; twiddles are applied to the OUTPUTS (y1..y3),
                   // mirroring the forward transform where they are applied
                   // to the inputs.
577                let mut i = 0;
578                while i + 8 <= v {
579                    let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
580                    let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
581                    let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
582                    let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);
583
                       // Combine (mirrors the scalar tail below):
                       //   y0 = (a0+a1) + (a2+a3)
                       //   y1 = ((a0-a1) + iimag*(a2-a3)) * w1
                       //   y2 = ((a0+a1) - (a2+a3)) * w2
                       //   y3 = ((a0-a1) - iimag*(a2-a3)) * w3
584                    let a0pa1 = add_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
585                    let a0na1 = sub_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
586                    let a2pa3 = add_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
587                    let a2na3 = sub_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
588                    let a2na3iimag = mul_vec_avx2::<M>(a2na3, iimag_vec, r_vec, mod_vec, sign);
589
590                    let y0 = add_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
591                    let y1 = add_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
592                    let y2 = sub_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
593                    let y3 = sub_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
594
595                    let y1 = mul_vec_avx2::<M>(y1, w1v, r_vec, mod_vec, sign);
596                    let y2 = mul_vec_avx2::<M>(y2, w2v, r_vec, mod_vec, sign);
597                    let y3 = mul_vec_avx2::<M>(y3, w3v, r_vec, mod_vec, sign);
598
599                    _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
600                    _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
601                    _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
602                    _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
603                    i += 8;
604                }
                   // Scalar tail for the remaining (v % 8) butterflies —
                   // same arithmetic as the SIMD path above.
605                while i < v {
606                    let a0 = *ll.add(i);
607                    let a1 = *lr.add(i);
608                    let a2 = *rl.add(i);
609                    let a3 = *rr.add(i);
610                    let a0pa1 = M::mod_add(a0, a1);
611                    let a0na1 = M::mod_sub(a0, a1);
612                    let a2pa3 = M::mod_add(a2, a3);
613                    let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
614                    *ll.add(i) = M::mod_add(a0pa1, a2pa3);
615                    *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
616                    *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
617                    *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
618                    i += 1;
619                }
                   // Advance the twiddle via the inv_rate3 table, indexed by
                   // the trailing-ones count of the block index.
620                w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
621            }
622            v <<= 2;
623        }
           // Presumably reduces every element back into canonical range
           // [0, MOD) — confirm against normalize_avx2 (not shown here).
624        normalize_avx2::<M>(a);
625    }