fn strassen_rec<R: Ring>(
a: &[R::T],
b: &[R::T],
c: &mut [R::T],
n: usize,
stride_a: usize,
stride_b: usize,
)
Examples found in repository:
crates/competitive/src/math/matrix.rs (line 562)
482fn strassen_rec<R: Ring>(
483 a: &[R::T],
484 b: &[R::T],
485 c: &mut [R::T],
486 n: usize,
487 stride_a: usize,
488 stride_b: usize,
489) {
490 fn add_block<R: Ring>(
491 a: &[R::T],
492 b: &[R::T],
493 out: &mut [R::T],
494 n: usize,
495 stride_a: usize,
496 stride_b: usize,
497 ) {
498 for ((a, b), c) in a
499 .chunks(stride_a)
500 .zip(b.chunks(stride_b))
501 .zip(out.chunks_exact_mut(n))
502 {
503 for ((a, b), c) in a.iter().zip(b.iter()).zip(c.iter_mut()) {
504 *c = R::add(a, b);
505 }
506 }
507 }
508
509 fn sub_block<R: Ring>(
510 a: &[R::T],
511 b: &[R::T],
512 out: &mut [R::T],
513 n: usize,
514 stride_a: usize,
515 stride_b: usize,
516 ) {
517 for ((a, b), c) in a
518 .chunks(stride_a)
519 .zip(b.chunks(stride_b))
520 .zip(out.chunks_exact_mut(n))
521 {
522 for ((a, b), c) in a.iter().zip(b.iter()).zip(c.iter_mut()) {
523 *c = R::sub(a, b);
524 }
525 }
526 }
527
528 if n <= 64 {
529 for (a, c) in a.chunks(stride_a).zip(c.chunks_exact_mut(n)) {
530 for (a, b) in a.iter().zip(b.chunks(stride_b)).take(n) {
531 for (b, c) in b.iter().zip(c.iter_mut()) {
532 R::add_assign(c, &R::mul(a, b));
533 }
534 }
535 }
536 return;
537 }
538 let h = n / 2;
539 let a11 = 0;
540 let a12 = h;
541 let a21 = h * stride_a;
542 let a22 = a21 + h;
543 let b11 = 0;
544 let b12 = h;
545 let b21 = h * stride_b;
546 let b22 = b21 + h;
547
548 let block = h * h;
549 let mut buf = vec![R::zero(); block * 9];
550 let (s_buf, m_buf) = buf.split_at_mut(block * 2);
551 let (s1, s2) = s_buf.split_at_mut(block);
552 let (m1, rest) = m_buf.split_at_mut(block);
553 let (m2, rest) = rest.split_at_mut(block);
554 let (m3, rest) = rest.split_at_mut(block);
555 let (m4, rest) = rest.split_at_mut(block);
556 let (m5, rest) = rest.split_at_mut(block);
557 let (m6, m7) = rest.split_at_mut(block);
558
559 // (A11 + A22)(B11 + B22)
560 add_block::<R>(&a[a11..], &a[a22..], s1, h, stride_a, stride_a);
561 add_block::<R>(&b[b11..], &b[b22..], s2, h, stride_b, stride_b);
562 strassen_rec::<R>(s1, s2, m1, h, h, h);
563
564 // (A21 + A22) B11
565 add_block::<R>(&a[a21..], &a[a22..], s1, h, stride_a, stride_a);
566 strassen_rec::<R>(s1, &b[b11..], m2, h, h, stride_b);
567
568 // A11 (B12 - B22)
569 sub_block::<R>(&b[b12..], &b[b22..], s1, h, stride_b, stride_b);
570 strassen_rec::<R>(&a[a11..], s1, m3, h, stride_a, h);
571
572 // A22 (B21 - B11)
573 sub_block::<R>(&b[b21..], &b[b11..], s1, h, stride_b, stride_b);
574 strassen_rec::<R>(&a[a22..], s1, m4, h, stride_a, h);
575
576 // (A11 + A12) B22
577 add_block::<R>(&a[a11..], &a[a12..], s1, h, stride_a, stride_a);
578 strassen_rec::<R>(s1, &b[b22..], m5, h, h, stride_b);
579
580 // (A21 - A11)(B11 + B12)
581 sub_block::<R>(&a[a21..], &a[a11..], s1, h, stride_a, stride_a);
582 add_block::<R>(&b[b11..], &b[b12..], s2, h, stride_b, stride_b);
583 strassen_rec::<R>(s1, s2, m6, h, h, h);
584
585 // (A12 - A22)(B21 + B22)
586 sub_block::<R>(&a[a12..], &a[a22..], s1, h, stride_a, stride_a);
587 add_block::<R>(&b[b21..], &b[b22..], s2, h, stride_b, stride_b);
588 strassen_rec::<R>(s1, s2, m7, h, h, h);
589
590 let c11 = 0;
591 let c12 = h;
592 let c21 = h * n;
593 let c22 = c21 + h;
594 for ((((m1, m4), m5), m7), c) in m1
595 .iter()
596 .zip(m4.iter())
597 .zip(m5.iter())
598 .zip(m7.iter())
599 .zip(c[c11..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
600 {
601 *c = R::add(m1, m4);
602 R::sub_assign(c, m5);
603 R::add_assign(c, m7);
604 }
605 for ((m3, m5), c) in m3
606 .iter()
607 .zip(m5.iter())
608 .zip(c[c12..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
609 {
610 *c = R::add(m3, m5);
611 }
612 for ((m2, m4), c) in m2
613 .iter()
614 .zip(m4.iter())
615 .zip(c[c21..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
616 {
617 *c = R::add(m2, m4);
618 }
619 for ((((m1, m2), m3), m6), c) in m1
620 .iter()
621 .zip(m2.iter())
622 .zip(m3.iter())
623 .zip(m6.iter())
624 .zip(c[c22..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
625 {
626 *c = R::sub(m1, m2);
627 R::add_assign(c, m3);
628 R::add_assign(c, m6);
629 }
630}
631
632impl<R> Matrix<R>
633where
634 R: Ring,
635{
636 pub fn mul_strassen(&self, rhs: &Matrix<R>) -> Matrix<R> {
637 assert_eq!(self.shape.1, rhs.shape.0);
638 let (n, m) = self.shape;
639 let p = rhs.shape.1;
640 if n == 0 || m == 0 || p == 0 {
641 return Matrix::zeros((n, p));
642 }
643 let max_dim = n.max(m).max(p);
644 if max_dim <= 64 {
645 return self * rhs;
646 }
647 let size = max_dim.next_power_of_two();
648 let mut a = vec![R::zero(); size * size];
649 for (a, data) in a.chunks_exact_mut(size).zip(&self.data) {
650 a[..m].clone_from_slice(data);
651 }
652 let mut b = vec![R::zero(); size * size];
653 for (b, data) in b.chunks_exact_mut(size).zip(&rhs.data) {
654 b[..p].clone_from_slice(data);
655 }
656 let mut c = vec![R::zero(); size * size];
657 strassen_rec::<R>(&a, &b, &mut c, size, size, size);
658 let mut res = Matrix::zeros((n, p));
659 for (data, c) in res.data.iter_mut().zip(c.chunks_exact(size)) {
660 data.clone_from_slice(&c[..p]);
661 }
662 res
663 }