Skip to main content

competitive/math/
matrix.rs

1use super::{Field, Invertible, Ring, SemiRing, SerdeByteStr};
2use std::{
3    fmt::{self, Debug},
4    marker::PhantomData,
5    ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign},
6};
7
8pub struct Matrix<R>
9where
10    R: SemiRing,
11{
12    pub shape: (usize, usize),
13    pub data: Vec<Vec<R::T>>,
14    _marker: PhantomData<fn() -> R>,
15}
16
17impl<R> Debug for Matrix<R>
18where
19    R: SemiRing<T: Debug>,
20{
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        f.debug_struct("Matrix")
23            .field("shape", &self.shape)
24            .field("data", &self.data)
25            .field("_marker", &self._marker)
26            .finish()
27    }
28}
29
30impl<R> Clone for Matrix<R>
31where
32    R: SemiRing,
33{
34    fn clone(&self) -> Self {
35        Self {
36            shape: self.shape,
37            data: self.data.clone(),
38            _marker: self._marker,
39        }
40    }
41}
42
43impl<R> PartialEq for Matrix<R>
44where
45    R: SemiRing<T: PartialEq>,
46{
47    fn eq(&self, other: &Self) -> bool {
48        self.shape == other.shape && self.data == other.data
49    }
50}
51
52impl<R> Eq for Matrix<R> where R: SemiRing<T: Eq> {}
53
54impl<R> Matrix<R>
55where
56    R: SemiRing,
57{
58    pub fn new(shape: (usize, usize), z: R::T) -> Self {
59        Self {
60            shape,
61            data: vec![vec![z; shape.1]; shape.0],
62            _marker: PhantomData,
63        }
64    }
65
66    pub fn from_vec(data: Vec<Vec<R::T>>) -> Self {
67        let shape = (data.len(), data.first().map(Vec::len).unwrap_or_default());
68        assert!(data.iter().all(|r| r.len() == shape.1));
69        Self {
70            shape,
71            data,
72            _marker: PhantomData,
73        }
74    }
75
76    pub fn new_with(shape: (usize, usize), mut f: impl FnMut(usize, usize) -> R::T) -> Self {
77        let data = (0..shape.0)
78            .map(|i| (0..shape.1).map(|j| f(i, j)).collect())
79            .collect();
80        Self {
81            shape,
82            data,
83            _marker: PhantomData,
84        }
85    }
86
87    pub fn zeros(shape: (usize, usize)) -> Self {
88        Self {
89            shape,
90            data: vec![vec![R::zero(); shape.1]; shape.0],
91            _marker: PhantomData,
92        }
93    }
94
95    pub fn eye(shape: (usize, usize)) -> Self {
96        let mut data = vec![vec![R::zero(); shape.1]; shape.0];
97        for (i, d) in data.iter_mut().enumerate().take(shape.1) {
98            d[i] = R::one();
99        }
100        Self {
101            shape,
102            data,
103            _marker: PhantomData,
104        }
105    }
106
107    pub fn transpose(&self) -> Self {
108        Self::new_with((self.shape.1, self.shape.0), |i, j| self[j][i].clone())
109    }
110
111    pub fn map<S, F>(&self, mut f: F) -> Matrix<S>
112    where
113        S: SemiRing,
114        F: FnMut(&R::T) -> S::T,
115    {
116        Matrix::<S>::new_with(self.shape, |i, j| f(&self[i][j]))
117    }
118
119    pub fn add_row_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
120        self.data
121            .push((0..self.shape.1).map(|j| f(self.shape.0, j)).collect());
122        self.shape.0 += 1;
123    }
124
125    pub fn add_col_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
126        for i in 0..self.shape.0 {
127            self.data[i].push(f(i, self.shape.1));
128        }
129        self.shape.1 += 1;
130    }
131
132    pub fn pairwise_assign<F>(&mut self, other: &Self, mut f: F)
133    where
134        F: FnMut(&mut R::T, &R::T),
135    {
136        assert_eq!(self.shape, other.shape);
137        for i in 0..self.shape.0 {
138            for j in 0..self.shape.1 {
139                f(&mut self[i][j], &other[i][j]);
140            }
141        }
142    }
143}
144
145#[derive(Debug)]
146pub struct SystemOfLinearEquationsSolution<R>
147where
148    R: Field<Additive: Invertible, Multiplicative: Invertible>,
149{
150    pub particular: Vec<R::T>,
151    pub basis: Vec<Vec<R::T>>,
152}
153
154impl<R> Matrix<R>
155where
156    R: Field<T: PartialEq, Additive: Invertible, Multiplicative: Invertible>,
157{
158    /// f: (row, pivot_row, col)
159    pub fn row_reduction_with<F>(&mut self, normalize: bool, mut f: F)
160    where
161        F: FnMut(usize, usize, usize),
162    {
163        let (n, m) = self.shape;
164        let mut c = 0;
165        for r in 0..n {
166            loop {
167                if c >= m {
168                    return;
169                }
170                if let Some(pivot) = (r..n).find(|&p| !R::is_zero(&self[p][c])) {
171                    f(r, pivot, c);
172                    self.data.swap(r, pivot);
173                    break;
174                };
175                c += 1;
176            }
177            let d = R::inv(&self[r][c]);
178            if normalize {
179                for j in c..m {
180                    R::mul_assign(&mut self[r][j], &d);
181                }
182            }
183            for i in (0..n).filter(|&i| i != r) {
184                let mut e = self[i][c].clone();
185                if !normalize {
186                    R::mul_assign(&mut e, &d);
187                }
188                for j in c..m {
189                    let e = R::mul(&e, &self[r][j]);
190                    R::sub_assign(&mut self[i][j], &e);
191                }
192            }
193            c += 1;
194        }
195    }
196
197    pub fn row_reduction(&mut self, normalize: bool) {
198        self.row_reduction_with(normalize, |_, _, _| {});
199    }
200
201    pub fn rank(&mut self) -> usize {
202        let n = self.shape.0;
203        self.row_reduction(false);
204        (0..n)
205            .filter(|&i| !self.data[i].iter().all(|x| R::is_zero(x)))
206            .count()
207    }
208
209    pub fn determinant(&mut self) -> R::T {
210        assert_eq!(self.shape.0, self.shape.1);
211        let mut neg = false;
212        self.row_reduction_with(false, |r, p, _| neg ^= r != p);
213        let mut d = R::one();
214        if neg {
215            d = R::neg(&d);
216        }
217        for i in 0..self.shape.0 {
218            R::mul_assign(&mut d, &self[i][i]);
219        }
220        d
221    }
222
223    pub fn solve_system_of_linear_equations(
224        &self,
225        b: &[R::T],
226    ) -> Option<SystemOfLinearEquationsSolution<R>> {
227        assert_eq!(self.shape.0, b.len());
228        let (n, m) = self.shape;
229        let mut c = Matrix::<R>::zeros((n, m + 1));
230        for i in 0..n {
231            c[i][..m].clone_from_slice(&self[i]);
232            c[i][m] = b[i].clone();
233        }
234        let mut reduced = vec![!0; m + 1];
235        c.row_reduction_with(true, |r, _, c| reduced[c] = r);
236        if reduced[m] != !0 {
237            return None;
238        }
239        let mut particular = vec![R::zero(); m];
240        let mut basis = vec![];
241        for j in 0..m {
242            if reduced[j] != !0 {
243                particular[j] = c[reduced[j]][m].clone();
244            } else {
245                let mut v = vec![R::zero(); m];
246                v[j] = R::one();
247                for i in 0..m {
248                    if reduced[i] != !0 {
249                        R::sub_assign(&mut v[i], &c[reduced[i]][j]);
250                    }
251                }
252                basis.push(v);
253            }
254        }
255        Some(SystemOfLinearEquationsSolution { particular, basis })
256    }
257
258    pub fn inverse(&self) -> Option<Matrix<R>> {
259        assert_eq!(self.shape.0, self.shape.1);
260        let n = self.shape.0;
261        let mut c = Matrix::<R>::zeros((n, n * 2));
262        for i in 0..n {
263            c[i][..n].clone_from_slice(&self[i]);
264            c[i][n + i] = R::one();
265        }
266        c.row_reduction(true);
267        if (0..n).any(|i| R::is_zero(&c[i][i])) {
268            None
269        } else {
270            Some(Self::from_vec(
271                c.data.into_iter().map(|r| r[n..].to_vec()).collect(),
272            ))
273        }
274    }
275
276    pub fn characteristic_polynomial(&mut self) -> Vec<R::T> {
277        let n = self.shape.0;
278        if n == 0 {
279            return vec![R::one()];
280        }
281        assert!(self.data.iter().all(|a| a.len() == n));
282        for j in 0..(n - 1) {
283            if let Some(x) = ((j + 1)..n).find(|&x| !R::is_zero(&self[x][j])) {
284                self.data.swap(j + 1, x);
285                self.data.iter_mut().for_each(|a| a.swap(j + 1, x));
286                let inv = R::inv(&self[j + 1][j]);
287                let mut v = vec![];
288                let src = std::mem::take(&mut self[j + 1]);
289                for a in self.data[(j + 2)..].iter_mut() {
290                    let mul = R::mul(&a[j], &inv);
291                    for (a, src) in a[j..].iter_mut().zip(src[j..].iter()) {
292                        R::sub_assign(a, &R::mul(&mul, src));
293                    }
294                    v.push(mul);
295                }
296                self[j + 1] = src;
297                for a in self.data.iter_mut() {
298                    let v = a[(j + 2)..]
299                        .iter()
300                        .zip(v.iter())
301                        .fold(R::zero(), |s, a| R::add(&s, &R::mul(a.0, a.1)));
302                    R::add_assign(&mut a[j + 1], &v);
303                }
304            }
305        }
306        let mut dp = vec![vec![R::one()]];
307        for i in 0..n {
308            let mut next = vec![R::zero(); i + 2];
309            for (j, dp) in dp[i].iter().enumerate() {
310                R::sub_assign(&mut next[j], &R::mul(dp, &self[i][i]));
311                R::add_assign(&mut next[j + 1], dp);
312            }
313            let mut mul = R::one();
314            for j in (0..i).rev() {
315                mul = R::mul(&mul, &self[j + 1][j]);
316                let c = R::mul(&mul, &self[j][i]);
317                for (next, dp) in next.iter_mut().zip(dp[j].iter()) {
318                    R::sub_assign(next, &R::mul(&c, dp));
319                }
320            }
321            dp.push(next);
322        }
323        dp.pop().unwrap()
324    }
325}
326
327impl<R> Index<usize> for Matrix<R>
328where
329    R: SemiRing,
330{
331    type Output = Vec<R::T>;
332    fn index(&self, index: usize) -> &Self::Output {
333        &self.data[index]
334    }
335}
336
337impl<R> IndexMut<usize> for Matrix<R>
338where
339    R: SemiRing,
340{
341    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
342        &mut self.data[index]
343    }
344}
345
346impl<R> Index<(usize, usize)> for Matrix<R>
347where
348    R: SemiRing,
349{
350    type Output = R::T;
351    fn index(&self, index: (usize, usize)) -> &Self::Output {
352        &self.data[index.0][index.1]
353    }
354}
355
356impl<R> IndexMut<(usize, usize)> for Matrix<R>
357where
358    R: SemiRing,
359{
360    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
361        &mut self.data[index.0][index.1]
362    }
363}
364
/// Generates entry-wise binary operator impls for `Matrix<R>`.
///
/// Arguments:
/// - `$imp`, `$method`: the operator trait and its method (e.g. `Add`, `add`);
/// - `$imp_assign`, `$method_assign`: the compound-assignment trait and method (e.g. `AddAssign`, `add_assign`);
/// - an optional `where [..]` list of extra bounds appended to every generated impl.
///
/// The generated impls cover `M op= M`, `M op= &M`, `M op M`, `M op &M`, `&M op M` and `&M op &M`.
/// All of them go through `pairwise_assign`, which asserts that the shapes match.
macro_rules! impl_matrix_pairwise_binop {
    ($imp:ident, $method:ident, $imp_assign:ident, $method_assign:ident $(where [$($clauses:tt)*])?) => {
        // `lhs op= rhs` with an owned right-hand side.
        impl<R> $imp_assign for Matrix<R>
        where
            R: SemiRing,
            $($($clauses)*)?
        {
            fn $method_assign(&mut self, rhs: Self) {
                self.pairwise_assign(&rhs, |a, b| R::$method_assign(a, b));
            }
        }
        // `lhs op= &rhs`.
        impl<R> $imp_assign<&Matrix<R>> for Matrix<R>
        where
            R: SemiRing,
            $($($clauses)*)?
        {
            fn $method_assign(&mut self, rhs: &Self) {
                self.pairwise_assign(rhs, |a, b| R::$method_assign(a, b));
            }
        }
        // `lhs op rhs`, both owned: reuses the left operand's storage.
        impl<R> $imp for Matrix<R>
        where
            R: SemiRing,
            $($($clauses)*)?
        {
            type Output = Matrix<R>;
            fn $method(mut self, rhs: Self) -> Self::Output {
                self.$method_assign(rhs);
                self
            }
        }
        // `lhs op &rhs`: reuses the left operand's storage.
        impl<R> $imp<&Matrix<R>> for Matrix<R>
        where
            R: SemiRing,
            $($($clauses)*)?
        {
            type Output = Matrix<R>;
            fn $method(mut self, rhs: &Self) -> Self::Output {
                self.$method_assign(rhs);
                self
            }
        }
        // `&lhs op rhs`: reuses the right operand's storage. Inside the closure `a` is the
        // rhs entry and `b` is the lhs entry, so the call is `R::$method(b, a)` to keep
        // lhs-op-rhs order (this matters for non-commutative ops such as `sub`).
        impl<R> $imp<Matrix<R>> for &Matrix<R>
        where
            R: SemiRing,
            $($($clauses)*)?
        {
            type Output = Matrix<R>;
            fn $method(self, mut rhs: Matrix<R>) -> Self::Output {
                rhs.pairwise_assign(self, |a, b| *a = R::$method(b, a));
                rhs
            }
        }
        // `&lhs op &rhs`: clones the left operand, then applies the assign form.
        impl<R> $imp<&Matrix<R>> for &Matrix<R>
        where
            R: SemiRing,
            $($($clauses)*)?
        {
            type Output = Matrix<R>;
            fn $method(self, rhs: &Matrix<R>) -> Self::Output {
                let mut this = self.clone();
                this.$method_assign(rhs);
                this
            }
        }
    };
}
432
433impl_matrix_pairwise_binop!(Add, add, AddAssign, add_assign);
434impl_matrix_pairwise_binop!(Sub, sub, SubAssign, sub_assign where [R: SemiRing<Additive: Invertible>]);
435
436impl<R> Mul for Matrix<R>
437where
438    R: SemiRing,
439{
440    type Output = Matrix<R>;
441    fn mul(self, rhs: Self) -> Self::Output {
442        (&self).mul(&rhs)
443    }
444}
445impl<R> Mul<&Matrix<R>> for Matrix<R>
446where
447    R: SemiRing,
448{
449    type Output = Matrix<R>;
450    fn mul(self, rhs: &Matrix<R>) -> Self::Output {
451        (&self).mul(rhs)
452    }
453}
454impl<R> Mul<Matrix<R>> for &Matrix<R>
455where
456    R: SemiRing,
457{
458    type Output = Matrix<R>;
459    fn mul(self, rhs: Matrix<R>) -> Self::Output {
460        self.mul(&rhs)
461    }
462}
463impl<R> Mul<&Matrix<R>> for &Matrix<R>
464where
465    R: SemiRing,
466{
467    type Output = Matrix<R>;
468    fn mul(self, rhs: &Matrix<R>) -> Self::Output {
469        assert_eq!(self.shape.1, rhs.shape.0);
470        let mut res = Matrix::zeros((self.shape.0, rhs.shape.1));
471        for (a, c) in self.data.iter().zip(res.data.iter_mut()) {
472            for (a, b) in a.iter().zip(rhs.data.iter()) {
473                for (b, c) in b.iter().zip(c.iter_mut()) {
474                    R::add_assign(c, &R::mul(a, b));
475                }
476            }
477        }
478        res
479    }
480}
481
/// One level of Strassen multiplication: writes `A · B` into `c`.
///
/// - `a` and `b` are views of `n × n` blocks stored row-major, with row strides
///   `stride_a` and `stride_b`. Row `i` of the A block starts at `a[i * stride_a]`.
/// - `c` is a dense `n × n` buffer with row stride `n`.
/// - `n` is expected to be a power of two when `n > 64` (the caller `mul_strassen` pads
///   to one). The recursion halves `n` and ignores any odd remainder.
///
/// Below the cutoff the schoolbook kernel *accumulates* into `c` (`+=`), so `c` must be
/// zeroed on entry. Above it, all four quadrants of `c` are overwritten with `=`.
fn strassen_rec<R: Ring>(
    a: &[R::T],
    b: &[R::T],
    c: &mut [R::T],
    n: usize,
    stride_a: usize,
    stride_b: usize,
) {
    /// out (dense n × n) = A block + B block, with A and B read at their own strides.
    fn add_block<R: Ring>(
        a: &[R::T],
        b: &[R::T],
        out: &mut [R::T],
        n: usize,
        stride_a: usize,
        stride_b: usize,
    ) {
        // `out` has exactly n rows of n, which caps both the row count and the row width.
        for ((a, b), c) in a
            .chunks(stride_a)
            .zip(b.chunks(stride_b))
            .zip(out.chunks_exact_mut(n))
        {
            for ((a, b), c) in a.iter().zip(b.iter()).zip(c.iter_mut()) {
                *c = R::add(a, b);
            }
        }
    }

    /// out (dense n × n) = A block - B block, with A and B read at their own strides.
    fn sub_block<R: Ring>(
        a: &[R::T],
        b: &[R::T],
        out: &mut [R::T],
        n: usize,
        stride_a: usize,
        stride_b: usize,
    ) {
        for ((a, b), c) in a
            .chunks(stride_a)
            .zip(b.chunks(stride_b))
            .zip(out.chunks_exact_mut(n))
        {
            for ((a, b), c) in a.iter().zip(b.iter()).zip(c.iter_mut()) {
                *c = R::sub(a, b);
            }
        }
    }

    // Base case: i-k-j schoolbook product accumulated into `c`.
    if n <= 64 {
        for (a, c) in a.chunks(stride_a).zip(c.chunks_exact_mut(n)) {
            for (a, b) in a.iter().zip(b.chunks(stride_b)).take(n) {
                for (b, c) in b.iter().zip(c.iter_mut()) {
                    R::add_assign(c, &R::mul(a, b));
                }
            }
        }
        return;
    }
    // Offsets of the four h × h quadrants inside the A and B views.
    let h = n / 2;
    let a11 = 0;
    let a12 = h;
    let a21 = h * stride_a;
    let a22 = a21 + h;
    let b11 = 0;
    let b12 = h;
    let b21 = h * stride_b;
    let b22 = b21 + h;

    // One zeroed allocation, split into two scratch operands (s1, s2) and the seven
    // Strassen products (m1..m7). Each piece is a dense h × h block.
    let block = h * h;
    let mut buf = vec![R::zero(); block * 9];
    let (s_buf, m_buf) = buf.split_at_mut(block * 2);
    let (s1, s2) = s_buf.split_at_mut(block);
    let (m1, rest) = m_buf.split_at_mut(block);
    let (m2, rest) = rest.split_at_mut(block);
    let (m3, rest) = rest.split_at_mut(block);
    let (m4, rest) = rest.split_at_mut(block);
    let (m5, rest) = rest.split_at_mut(block);
    let (m6, m7) = rest.split_at_mut(block);

    // Dense scratch blocks are passed with stride h; quadrants of A/B keep the parent stride.

    // (A11 + A22)(B11 + B22)
    add_block::<R>(&a[a11..], &a[a22..], s1, h, stride_a, stride_a);
    add_block::<R>(&b[b11..], &b[b22..], s2, h, stride_b, stride_b);
    strassen_rec::<R>(s1, s2, m1, h, h, h);

    // (A21 + A22) B11
    add_block::<R>(&a[a21..], &a[a22..], s1, h, stride_a, stride_a);
    strassen_rec::<R>(s1, &b[b11..], m2, h, h, stride_b);

    // A11 (B12 - B22)
    sub_block::<R>(&b[b12..], &b[b22..], s1, h, stride_b, stride_b);
    strassen_rec::<R>(&a[a11..], s1, m3, h, stride_a, h);

    // A22 (B21 - B11)
    sub_block::<R>(&b[b21..], &b[b11..], s1, h, stride_b, stride_b);
    strassen_rec::<R>(&a[a22..], s1, m4, h, stride_a, h);

    // (A11 + A12) B22
    add_block::<R>(&a[a11..], &a[a12..], s1, h, stride_a, stride_a);
    strassen_rec::<R>(s1, &b[b22..], m5, h, h, stride_b);

    // (A21 - A11)(B11 + B12)
    sub_block::<R>(&a[a21..], &a[a11..], s1, h, stride_a, stride_a);
    add_block::<R>(&b[b11..], &b[b12..], s2, h, stride_b, stride_b);
    strassen_rec::<R>(s1, s2, m6, h, h, h);

    // (A12 - A22)(B21 + B22)
    sub_block::<R>(&a[a12..], &a[a22..], s1, h, stride_a, stride_a);
    add_block::<R>(&b[b21..], &b[b22..], s2, h, stride_b, stride_b);
    strassen_rec::<R>(s1, s2, m7, h, h, h);

    // Offsets of the quadrants of `c` (row stride n). Each loop below walks an h × h
    // product block in order and pairs it with the first h entries of each row of
    // the matching quadrant.
    let c11 = 0;
    let c12 = h;
    let c21 = h * n;
    let c22 = c21 + h;
    // C11 = M1 + M4 - M5 + M7
    for ((((m1, m4), m5), m7), c) in m1
        .iter()
        .zip(m4.iter())
        .zip(m5.iter())
        .zip(m7.iter())
        .zip(c[c11..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
    {
        *c = R::add(m1, m4);
        R::sub_assign(c, m5);
        R::add_assign(c, m7);
    }
    // C12 = M3 + M5
    for ((m3, m5), c) in m3
        .iter()
        .zip(m5.iter())
        .zip(c[c12..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
    {
        *c = R::add(m3, m5);
    }
    // C21 = M2 + M4
    for ((m2, m4), c) in m2
        .iter()
        .zip(m4.iter())
        .zip(c[c21..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
    {
        *c = R::add(m2, m4);
    }
    // C22 = M1 - M2 + M3 + M6
    for ((((m1, m2), m3), m6), c) in m1
        .iter()
        .zip(m2.iter())
        .zip(m3.iter())
        .zip(m6.iter())
        .zip(c[c22..].chunks_mut(n).flat_map(|c| c.iter_mut().take(h)))
    {
        *c = R::sub(m1, m2);
        R::add_assign(c, m3);
        R::add_assign(c, m6);
    }
}
631
632impl<R> Matrix<R>
633where
634    R: Ring,
635{
636    pub fn mul_strassen(&self, rhs: &Matrix<R>) -> Matrix<R> {
637        assert_eq!(self.shape.1, rhs.shape.0);
638        let (n, m) = self.shape;
639        let p = rhs.shape.1;
640        if n == 0 || m == 0 || p == 0 {
641            return Matrix::zeros((n, p));
642        }
643        let max_dim = n.max(m).max(p);
644        if max_dim <= 64 {
645            return self * rhs;
646        }
647        let size = max_dim.next_power_of_two();
648        let mut a = vec![R::zero(); size * size];
649        for (a, data) in a.chunks_exact_mut(size).zip(&self.data) {
650            a[..m].clone_from_slice(data);
651        }
652        let mut b = vec![R::zero(); size * size];
653        for (b, data) in b.chunks_exact_mut(size).zip(&rhs.data) {
654            b[..p].clone_from_slice(data);
655        }
656        let mut c = vec![R::zero(); size * size];
657        strassen_rec::<R>(&a, &b, &mut c, size, size, size);
658        let mut res = Matrix::zeros((n, p));
659        for (data, c) in res.data.iter_mut().zip(c.chunks_exact(size)) {
660            data.clone_from_slice(&c[..p]);
661        }
662        res
663    }
664}
665
666impl<R> MulAssign<&R::T> for Matrix<R>
667where
668    R: SemiRing,
669{
670    fn mul_assign(&mut self, rhs: &R::T) {
671        for i in 0..self.shape.0 {
672            for j in 0..self.shape.1 {
673                R::mul_assign(&mut self[(i, j)], rhs);
674            }
675        }
676    }
677}
678
679impl<R> Neg for Matrix<R>
680where
681    R: SemiRing<Additive: Invertible>,
682{
683    type Output = Self;
684
685    fn neg(self) -> Self::Output {
686        self.map(|x| R::neg(x))
687    }
688}
689
690impl<R> Neg for &Matrix<R>
691where
692    R: SemiRing<Additive: Invertible>,
693{
694    type Output = Matrix<R>;
695
696    fn neg(self) -> Self::Output {
697        self.map(|x| R::neg(x))
698    }
699}
700
701impl<R> Matrix<R>
702where
703    R: SemiRing,
704{
705    pub fn pow(self, mut n: usize) -> Self {
706        assert_eq!(self.shape.0, self.shape.1);
707        let mut res = Matrix::eye(self.shape);
708        let mut x = self;
709        while n > 0 {
710            if n & 1 == 1 {
711                res = &res * &x;
712            }
713            x = &x * &x;
714            n >>= 1;
715        }
716        res
717    }
718}
719
720impl<R> Matrix<R>
721where
722    R: Ring,
723{
724    pub fn pow_strassen(self, mut n: usize) -> Self {
725        assert_eq!(self.shape.0, self.shape.1);
726        let mut res = Matrix::eye(self.shape);
727        let mut x = self;
728        while n > 0 {
729            if n & 1 == 1 {
730                res = res.mul_strassen(&x);
731            }
732            x = x.mul_strassen(&x);
733            n >>= 1;
734        }
735        res
736    }
737}
738
739impl<R> SerdeByteStr for Matrix<R>
740where
741    R: SemiRing<T: SerdeByteStr>,
742{
743    fn serialize(&self, buf: &mut Vec<u8>) {
744        self.data.serialize(buf);
745    }
746
747    fn deserialize<I>(iter: &mut I) -> Self
748    where
749        I: Iterator<Item = u8>,
750    {
751        Self::from_vec(Vec::deserialize(iter))
752    }
753}
754
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One, Zero, mint_basic::DynMIntU32},
        rand, rand_value,
        tools::Xorshift,
    };

    type R = AddMulOperation<DynMIntU32>;

    /// Random matrix drawn from one of three families:
    /// - dense uniform;
    /// - sparse, with a random density;
    /// - dense, but with one row overwritten by a multiple of another row.
    ///
    /// The third family often produces rank-deficient inputs.
    /// Requires `shape.0 >= 1` (it picks random row indices in the third case).
    fn random_matrix(rng: &mut Xorshift, shape: (usize, usize)) -> Matrix<R> {
        if rng.gen_bool(0.5) {
            Matrix::new_with(shape, |_, _| rng.random(..))
        } else if rng.gen_bool(0.5) {
            let r = rng.randf();
            Matrix::new_with(shape, |_, _| {
                if rng.gen_bool(r) {
                    rng.random(..)
                } else {
                    DynMIntU32::zero()
                }
            })
        } else {
            let mut mat = Matrix::new_with(shape, |_, _| rng.random(..));
            let i0 = rng.random(0..shape.0);
            let i1 = rng.random(0..shape.0);
            let x: DynMIntU32 = rng.random(..);
            for j in 0..shape.1 {
                mat[(i0, j)] = mat[(i1, j)] * x;
            }
            mat
        }
    }

    #[test]
    fn test_eye() {
        for n in 0..10 {
            for m in 0..10 {
                let result = Matrix::<R>::eye((n, m));
                let expected = Matrix::<R>::new_with((n, m), |i, j| {
                    if i == j {
                        DynMIntU32::one()
                    } else {
                        DynMIntU32::zero()
                    }
                });
                assert_eq!(result, expected);
            }
        }
    }

    #[test]
    fn test_add() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            assert_eq!(&a + &b, a.clone() + b.clone());
            assert_eq!(a.clone() + &b, a.clone() + b.clone());
            assert_eq!(&a + b.clone(), a.clone() + b.clone());
        }
    }

    #[test]
    fn test_sub() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            assert_eq!(&a - &b, a.clone() - b.clone());
            assert_eq!(a.clone() - &b, a.clone() - b.clone());
            assert_eq!(&a - b.clone(), a.clone() - b.clone());
        }
    }

    #[test]
    fn test_mul() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30, l: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((m, l), |_, _| rng.random(..));
            assert_eq!(&a * &b, a.clone() * b.clone());
            assert_eq!(a.clone() * &b, a.clone() * b.clone());
            assert_eq!(&a * b.clone(), a.clone() * b.clone());
            assert_eq!(
                &a * &b,
                Matrix::new_with((n, l), |i, j| (0..m).map(|k| a[i][k] * b[k][j]).sum())
            );
            let c = rng.random(..);
            let mut ac = a.clone();
            ac *= &c;
            assert_eq!(ac, Matrix::new_with(a.shape, |i, j| a[i][j] * c));
        }
    }

    #[test]
    fn test_row_reduction() {
        const Q: usize = 1000;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let m = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(m);
            let n = rng.random(2..=30);
            let mat = Matrix::<R>::new_with((n, n), |_, _| rng.random(..));
            let rank = mat.clone().rank();
            let det = mat.clone().determinant();
            let inv = mat.inverse();
            assert_eq!(rank == n, inv.is_some());
            // A square matrix is invertible exactly when its determinant is nonzero.
            assert_eq!(det != DynMIntU32::zero(), inv.is_some());
            if let Some(inv) = inv {
                assert_eq!(&mat * &inv, Matrix::eye((n, n)));
            }
        }
    }

    #[test]
    fn test_system_of_linear_equations() {
        const Q: usize = 1000;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=30);
            let m = rng.random(1..=30);
            let a = random_matrix(&mut rng, (n, m));
            let b = random_matrix(&mut rng, (1, n))
                .data
                .into_iter()
                .next()
                .unwrap();
            if let Some(sol) = a.solve_system_of_linear_equations(&b) {
                assert_eq!(
                    &a * Matrix::from_vec(vec![sol.particular.clone()]).transpose(),
                    Matrix::from_vec(vec![b.clone()]).transpose()
                );
                let c: Vec<DynMIntU32> = rand_value!(rng, [..; sol.basis.len()]);
                let mut x = sol.particular.clone();
                for (c, v) in c.iter().zip(sol.basis.iter()) {
                    for (x, v) in x.iter_mut().zip(v.iter()) {
                        *x += *c * *v;
                    }
                }
                assert_eq!(
                    &a * Matrix::from_vec(vec![x]).transpose(),
                    Matrix::from_vec(vec![b]).transpose()
                );
            } else {
                // Rouché–Capelli: the system has no solution iff appending `b` as a
                // column raises the rank.
                let mut aug = a.clone();
                aug.add_col_with(|i, _| b[i]);
                assert!(a.clone().rank() < aug.rank());
            }
        }
    }

    #[test]
    fn test_characteristic_polynomial() {
        const Q: usize = 300;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=20);
            let a = random_matrix(&mut rng, (n, n));
            let cp = a.clone().characteristic_polynomial();
            // det(xI - A) is monic of degree n.
            assert_eq!(cp.len(), n + 1);
            assert_eq!(cp[n], DynMIntU32::one());
            // The coefficient of x^(n-1) is -trace(A).
            let trace: DynMIntU32 = (0..n).map(|i| a[i][i]).sum();
            assert_eq!(cp[n - 1] + trace, DynMIntU32::zero());
            // The constant term is det(-A) = (-1)^n det(A).
            let det = a.clone().determinant();
            if n % 2 == 0 {
                assert_eq!(cp[0], det);
            } else {
                assert_eq!(cp[0] + det, DynMIntU32::zero());
            }
        }
    }
}