Skip to main content

strassen_rec

Function strassen_rec 

Source
/// Strassen matrix-multiplication kernel on `n x n` blocks.
///
/// `a` and `b` are read with row strides `stride_a` and `stride_b`. The product
/// is stored in `c`, a dense `n x n` buffer with row stride `n`; the visible
/// callers always pass a zero-filled `c`. Blocks with `n <= 64` are multiplied
/// directly, and larger ones are split into `n / 2` quadrants.
fn strassen_rec<R: Ring>(
    a: &[R::T],
    b: &[R::T],
    c: &mut [R::T],
    n: usize,
    stride_a: usize,
    stride_b: usize,
)
Examples found in repository?
crates/competitive/src/math/matrix.rs (line 562)
482fn strassen_rec<R: Ring>(
483    a: &[R::T],
484    b: &[R::T],
485    c: &mut [R::T],
486    n: usize,
487    stride_a: usize,
488    stride_b: usize,
489) {
490    fn add_block<R: Ring>(
491        a: &[R::T],
492        b: &[R::T],
493        out: &mut [R::T],
494        n: usize,
495        stride_a: usize,
496        stride_b: usize,
497    ) {
498        for ((a, b), c) in a
499            .chunks(stride_a)
500            .zip(b.chunks(stride_b))
501            .zip(out.chunks_exact_mut(n))
502        {
503            for ((a, b), c) in a.iter().zip(b.iter()).zip(c.iter_mut()) {
504                *c = R::add(a, b);
505            }
506        }
507    }
508
509    fn sub_block<R: Ring>(
510        a: &[R::T],
511        b: &[R::T],
512        out: &mut [R::T],
513        n: usize,
514        stride_a: usize,
515        stride_b: usize,
516    ) {
517        for ((a, b), c) in a
518            .chunks(stride_a)
519            .zip(b.chunks(stride_b))
520            .zip(out.chunks_exact_mut(n))
521        {
522            for ((a, b), c) in a.iter().zip(b.iter()).zip(c.iter_mut()) {
523                *c = R::sub(a, b);
524            }
525        }
526    }
527
528    if n <= 64 {
529        for (a, c) in a.chunks(stride_a).zip(c.chunks_exact_mut(n)) {
530            for (a, b) in a.iter().zip(b.chunks(stride_b)).take(n) {
531                for (b, c) in b.iter().zip(c.iter_mut()) {
532                    R::add_assign(c, &R::mul(a, b));
533                }
534            }
535        }
536        return;
537    }
538    let h = n / 2;
539    let a11 = 0;
540    let a12 = h;
541    let a21 = h * stride_a;
542    let a22 = a21 + h;
543    let b11 = 0;
544    let b12 = h;
545    let b21 = h * stride_b;
546    let b22 = b21 + h;
547
548    let block = h * h;
549    let mut buf = vec![R::zero(); block * 9];
550    let (s_buf, m_buf) = buf.split_at_mut(block * 2);
551    let (s1, s2) = s_buf.split_at_mut(block);
552    let (m1, rest) = m_buf.split_at_mut(block);
553    let (m2, rest) = rest.split_at_mut(block);
554    let (m3, rest) = rest.split_at_mut(block);
555    let (m4, rest) = rest.split_at_mut(block);
556    let (m5, rest) = rest.split_at_mut(block);
557    let (m6, m7) = rest.split_at_mut(block);
558
559    // (A11 + A22)(B11 + B22)
560    add_block::<R>(&a[a11..], &a[a22..], s1, h, stride_a, stride_a);
561    add_block::<R>(&b[b11..], &b[b22..], s2, h, stride_b, stride_b);
562    strassen_rec::<R>(s1, s2, m1, h, h, h);
563
564    // (A21 + A22) B11
565    add_block::<R>(&a[a21..], &a[a22..], s1, h, stride_a, stride_a);
566    strassen_rec::<R>(s1, &b[b11..], m2, h, h, stride_b);
567
568    // A11 (B12 - B22)
569    sub_block::<R>(&b[b12..], &b[b22..], s1, h, stride_b, stride_b);
570    strassen_rec::<R>(&a[a11..], s1, m3, h, stride_a, h);
571
572    // A22 (B21 - B11)
573    sub_block::<R>(&b[b21..], &b[b11..], s1, h, stride_b, stride_b);
574    strassen_rec::<R>(&a[a22..], s1, m4, h, stride_a, h);
575
576    // (A11 + A12) B22
577    add_block::<R>(&a[a11..], &a[a12..], s1, h, stride_a, stride_a);
578    strassen_rec::<R>(s1, &b[b22..], m5, h, h, stride_b);
579
580    // (A21 - A11)(B11 + B12)
581    sub_block::<R>(&a[a21..], &a[a11..], s1, h, stride_a, stride_a);
582    add_block::<R>(&b[b11..], &b[b12..], s2, h, stride_b, stride_b);
583    strassen_rec::<R>(s1, s2, m6, h, h, h);
584
585    // (A12 - A22)(B21 + B22)
586    sub_block::<R>(&a[a12..], &a[a22..], s1, h, stride_a, stride_a);
587    add_block::<R>(&b[b21..], &b[b22..], s2, h, stride_b, stride_b);
588    strassen_rec::<R>(s1, s2, m7, h, h, h);
589
590    let c11 = 0;
591    let c12 = h;
592    let c21 = h * n;
593    let c22 = c21 + h;
594    for ((((m1, m4), m5), m7), c) in m1
595        .iter()
596        .zip(m4.iter())
597        .zip(m5.iter())
598        .zip(m7.iter())
599        .zip(c[c11..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
600    {
601        *c = R::add(m1, m4);
602        R::sub_assign(c, m5);
603        R::add_assign(c, m7);
604    }
605    for ((m3, m5), c) in m3
606        .iter()
607        .zip(m5.iter())
608        .zip(c[c12..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
609    {
610        *c = R::add(m3, m5);
611    }
612    for ((m2, m4), c) in m2
613        .iter()
614        .zip(m4.iter())
615        .zip(c[c21..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
616    {
617        *c = R::add(m2, m4);
618    }
619    for ((((m1, m2), m3), m6), c) in m1
620        .iter()
621        .zip(m2.iter())
622        .zip(m3.iter())
623        .zip(m6.iter())
624        .zip(c[c22..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
625    {
626        *c = R::sub(m1, m2);
627        R::add_assign(c, m3);
628        R::add_assign(c, m6);
629    }
630}
631
632impl<R> Matrix<R>
633where
634    R: Ring,
635{
636    pub fn mul_strassen(&self, rhs: &Matrix<R>) -> Matrix<R> {
637        assert_eq!(self.shape.1, rhs.shape.0);
638        let (n, m) = self.shape;
639        let p = rhs.shape.1;
640        if n == 0 || m == 0 || p == 0 {
641            return Matrix::zeros((n, p));
642        }
643        let max_dim = n.max(m).max(p);
644        if max_dim <= 64 {
645            return self * rhs;
646        }
647        let size = max_dim.next_power_of_two();
648        let mut a = vec![R::zero(); size * size];
649        for (a, data) in a.chunks_exact_mut(size).zip(&self.data) {
650            a[..m].clone_from_slice(data);
651        }
652        let mut b = vec![R::zero(); size * size];
653        for (b, data) in b.chunks_exact_mut(size).zip(&rhs.data) {
654            b[..p].clone_from_slice(data);
655        }
656        let mut c = vec![R::zero(); size * size];
657        strassen_rec::<R>(&a, &b, &mut c, size, size, size);
658        let mut res = Matrix::zeros((n, p));
659        for (data, c) in res.data.iter_mut().zip(c.chunks_exact(size)) {
660            data.clone_from_slice(&c[..p]);
661        }
662        res
663    }