competitive/math/matrix.rs

1use super::{Field, Invertible, Ring, SemiRing, SerdeByteStr};
2use std::{
3    fmt::{self, Debug},
4    marker::PhantomData,
5    ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign},
6};
7
8pub struct Matrix<R>
9where
10    R: SemiRing,
11{
12    pub shape: (usize, usize),
13    pub data: Vec<Vec<R::T>>,
14    _marker: PhantomData<fn() -> R>,
15}
16
17impl<R> Debug for Matrix<R>
18where
19    R: SemiRing,
20    R::T: Debug,
21{
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        f.debug_struct("Matrix")
24            .field("shape", &self.shape)
25            .field("data", &self.data)
26            .field("_marker", &self._marker)
27            .finish()
28    }
29}
30
31impl<R> Clone for Matrix<R>
32where
33    R: SemiRing,
34{
35    fn clone(&self) -> Self {
36        Self {
37            shape: self.shape,
38            data: self.data.clone(),
39            _marker: self._marker,
40        }
41    }
42}
43
44impl<R> PartialEq for Matrix<R>
45where
46    R: SemiRing,
47    R::T: PartialEq,
48{
49    fn eq(&self, other: &Self) -> bool {
50        self.shape == other.shape && self.data == other.data
51    }
52}
53
54impl<R> Eq for Matrix<R>
55where
56    R: SemiRing,
57    R::T: Eq,
58{
59}
60
61impl<R> Matrix<R>
62where
63    R: SemiRing,
64{
65    pub fn new(shape: (usize, usize), z: R::T) -> Self {
66        Self {
67            shape,
68            data: vec![vec![z; shape.1]; shape.0],
69            _marker: PhantomData,
70        }
71    }
72
73    pub fn from_vec(data: Vec<Vec<R::T>>) -> Self {
74        let shape = (data.len(), data.first().map(Vec::len).unwrap_or_default());
75        assert!(data.iter().all(|r| r.len() == shape.1));
76        Self {
77            shape,
78            data,
79            _marker: PhantomData,
80        }
81    }
82
83    pub fn new_with(shape: (usize, usize), mut f: impl FnMut(usize, usize) -> R::T) -> Self {
84        let data = (0..shape.0)
85            .map(|i| (0..shape.1).map(|j| f(i, j)).collect())
86            .collect();
87        Self {
88            shape,
89            data,
90            _marker: PhantomData,
91        }
92    }
93
94    pub fn zeros(shape: (usize, usize)) -> Self {
95        Self {
96            shape,
97            data: vec![vec![R::zero(); shape.1]; shape.0],
98            _marker: PhantomData,
99        }
100    }
101
102    pub fn eye(shape: (usize, usize)) -> Self {
103        let mut data = vec![vec![R::zero(); shape.1]; shape.0];
104        for (i, d) in data.iter_mut().enumerate().take(shape.1) {
105            d[i] = R::one();
106        }
107        Self {
108            shape,
109            data,
110            _marker: PhantomData,
111        }
112    }
113
114    pub fn transpose(&self) -> Self {
115        Self::new_with((self.shape.1, self.shape.0), |i, j| self[j][i].clone())
116    }
117
118    pub fn map<S, F>(&self, mut f: F) -> Matrix<S>
119    where
120        S: SemiRing,
121        F: FnMut(&R::T) -> S::T,
122    {
123        Matrix::<S>::new_with(self.shape, |i, j| f(&self[i][j]))
124    }
125
126    pub fn add_row_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
127        self.data
128            .push((0..self.shape.1).map(|j| f(self.shape.0, j)).collect());
129        self.shape.0 += 1;
130    }
131
132    pub fn add_col_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
133        for i in 0..self.shape.0 {
134            self.data[i].push(f(i, self.shape.1));
135        }
136        self.shape.1 += 1;
137    }
138
139    pub fn pairwise_assign<F>(&mut self, other: &Self, mut f: F)
140    where
141        F: FnMut(&mut R::T, &R::T),
142    {
143        assert_eq!(self.shape, other.shape);
144        for i in 0..self.shape.0 {
145            for j in 0..self.shape.1 {
146                f(&mut self[i][j], &other[i][j]);
147            }
148        }
149    }
150}
151
152#[derive(Debug)]
153pub struct SystemOfLinearEquationsSolution<R>
154where
155    R: Field,
156    R::Additive: Invertible,
157    R::Multiplicative: Invertible,
158{
159    pub particular: Vec<R::T>,
160    pub basis: Vec<Vec<R::T>>,
161}
162
163impl<R> Matrix<R>
164where
165    R: Field,
166    R::Additive: Invertible,
167    R::Multiplicative: Invertible,
168    R::T: PartialEq,
169{
170    /// f: (row, pivot_row, col)
171    pub fn row_reduction_with<F>(&mut self, normalize: bool, mut f: F)
172    where
173        F: FnMut(usize, usize, usize),
174    {
175        let (n, m) = self.shape;
176        let mut c = 0;
177        for r in 0..n {
178            loop {
179                if c >= m {
180                    return;
181                }
182                if let Some(pivot) = (r..n).find(|&p| !R::is_zero(&self[p][c])) {
183                    f(r, pivot, c);
184                    self.data.swap(r, pivot);
185                    break;
186                };
187                c += 1;
188            }
189            let d = R::inv(&self[r][c]);
190            if normalize {
191                for j in c..m {
192                    R::mul_assign(&mut self[r][j], &d);
193                }
194            }
195            for i in (0..n).filter(|&i| i != r) {
196                let mut e = self[i][c].clone();
197                if !normalize {
198                    R::mul_assign(&mut e, &d);
199                }
200                for j in c..m {
201                    let e = R::mul(&e, &self[r][j]);
202                    R::sub_assign(&mut self[i][j], &e);
203                }
204            }
205            c += 1;
206        }
207    }
208
209    pub fn row_reduction(&mut self, normalize: bool) {
210        self.row_reduction_with(normalize, |_, _, _| {});
211    }
212
213    pub fn rank(&mut self) -> usize {
214        let n = self.shape.0;
215        self.row_reduction(false);
216        (0..n)
217            .filter(|&i| !self.data[i].iter().all(|x| R::is_zero(x)))
218            .count()
219    }
220
221    pub fn determinant(&mut self) -> R::T {
222        assert_eq!(self.shape.0, self.shape.1);
223        let mut neg = false;
224        self.row_reduction_with(false, |r, p, _| neg ^= r != p);
225        let mut d = R::one();
226        if neg {
227            d = R::neg(&d);
228        }
229        for i in 0..self.shape.0 {
230            R::mul_assign(&mut d, &self[i][i]);
231        }
232        d
233    }
234
235    pub fn solve_system_of_linear_equations(
236        &self,
237        b: &[R::T],
238    ) -> Option<SystemOfLinearEquationsSolution<R>> {
239        assert_eq!(self.shape.0, b.len());
240        let (n, m) = self.shape;
241        let mut c = Matrix::<R>::zeros((n, m + 1));
242        for i in 0..n {
243            c[i][..m].clone_from_slice(&self[i]);
244            c[i][m] = b[i].clone();
245        }
246        let mut reduced = vec![!0; m + 1];
247        c.row_reduction_with(true, |r, _, c| reduced[c] = r);
248        if reduced[m] != !0 {
249            return None;
250        }
251        let mut particular = vec![R::zero(); m];
252        let mut basis = vec![];
253        for j in 0..m {
254            if reduced[j] != !0 {
255                particular[j] = c[reduced[j]][m].clone();
256            } else {
257                let mut v = vec![R::zero(); m];
258                v[j] = R::one();
259                for i in 0..m {
260                    if reduced[i] != !0 {
261                        R::sub_assign(&mut v[i], &c[reduced[i]][j]);
262                    }
263                }
264                basis.push(v);
265            }
266        }
267        Some(SystemOfLinearEquationsSolution { particular, basis })
268    }
269
270    pub fn inverse(&self) -> Option<Matrix<R>> {
271        assert_eq!(self.shape.0, self.shape.1);
272        let n = self.shape.0;
273        let mut c = Matrix::<R>::zeros((n, n * 2));
274        for i in 0..n {
275            c[i][..n].clone_from_slice(&self[i]);
276            c[i][n + i] = R::one();
277        }
278        c.row_reduction(true);
279        if (0..n).any(|i| R::is_zero(&c[i][i])) {
280            None
281        } else {
282            Some(Self::from_vec(
283                c.data.into_iter().map(|r| r[n..].to_vec()).collect(),
284            ))
285        }
286    }
287
288    pub fn characteristic_polynomial(&mut self) -> Vec<R::T> {
289        let n = self.shape.0;
290        if n == 0 {
291            return vec![R::one()];
292        }
293        assert!(self.data.iter().all(|a| a.len() == n));
294        for j in 0..(n - 1) {
295            if let Some(x) = ((j + 1)..n).find(|&x| !R::is_zero(&self[x][j])) {
296                self.data.swap(j + 1, x);
297                self.data.iter_mut().for_each(|a| a.swap(j + 1, x));
298                let inv = R::inv(&self[j + 1][j]);
299                let mut v = vec![];
300                let src = std::mem::take(&mut self[j + 1]);
301                for a in self.data[(j + 2)..].iter_mut() {
302                    let mul = R::mul(&a[j], &inv);
303                    for (a, src) in a[j..].iter_mut().zip(src[j..].iter()) {
304                        R::sub_assign(a, &R::mul(&mul, src));
305                    }
306                    v.push(mul);
307                }
308                self[j + 1] = src;
309                for a in self.data.iter_mut() {
310                    let v = a[(j + 2)..]
311                        .iter()
312                        .zip(v.iter())
313                        .fold(R::zero(), |s, a| R::add(&s, &R::mul(a.0, a.1)));
314                    R::add_assign(&mut a[j + 1], &v);
315                }
316            }
317        }
318        let mut dp = vec![vec![R::one()]];
319        for i in 0..n {
320            let mut next = vec![R::zero(); i + 2];
321            for (j, dp) in dp[i].iter().enumerate() {
322                R::sub_assign(&mut next[j], &R::mul(dp, &self[i][i]));
323                R::add_assign(&mut next[j + 1], dp);
324            }
325            let mut mul = R::one();
326            for j in (0..i).rev() {
327                mul = R::mul(&mul, &self[j + 1][j]);
328                let c = R::mul(&mul, &self[j][i]);
329                for (next, dp) in next.iter_mut().zip(dp[j].iter()) {
330                    R::sub_assign(next, &R::mul(&c, dp));
331                }
332            }
333            dp.push(next);
334        }
335        dp.pop().unwrap()
336    }
337}
338
339impl<R> Index<usize> for Matrix<R>
340where
341    R: SemiRing,
342{
343    type Output = Vec<R::T>;
344    fn index(&self, index: usize) -> &Self::Output {
345        &self.data[index]
346    }
347}
348
349impl<R> IndexMut<usize> for Matrix<R>
350where
351    R: SemiRing,
352{
353    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
354        &mut self.data[index]
355    }
356}
357
358impl<R> Index<(usize, usize)> for Matrix<R>
359where
360    R: SemiRing,
361{
362    type Output = R::T;
363    fn index(&self, index: (usize, usize)) -> &Self::Output {
364        &self.data[index.0][index.1]
365    }
366}
367
368impl<R> IndexMut<(usize, usize)> for Matrix<R>
369where
370    R: SemiRing,
371{
372    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
373        &mut self.data[index.0][index.1]
374    }
375}
376
377macro_rules! impl_matrix_pairwise_binop {
378    ($imp:ident, $method:ident, $imp_assign:ident, $method_assign:ident $(where [$($clauses:tt)*])?) => {
379        impl<R> $imp_assign for Matrix<R>
380        where
381            R: SemiRing,
382            $($($clauses)*)?
383        {
384            fn $method_assign(&mut self, rhs: Self) {
385                self.pairwise_assign(&rhs, |a, b| R::$method_assign(a, b));
386            }
387        }
388        impl<R> $imp_assign<&Matrix<R>> for Matrix<R>
389        where
390            R: SemiRing,
391            $($($clauses)*)?
392        {
393            fn $method_assign(&mut self, rhs: &Self) {
394                self.pairwise_assign(rhs, |a, b| R::$method_assign(a, b));
395            }
396        }
397        impl<R> $imp for Matrix<R>
398        where
399            R: SemiRing,
400            $($($clauses)*)?
401        {
402            type Output = Matrix<R>;
403            fn $method(mut self, rhs: Self) -> Self::Output {
404                self.$method_assign(rhs);
405                self
406            }
407        }
408        impl<R> $imp<&Matrix<R>> for Matrix<R>
409        where
410            R: SemiRing,
411            $($($clauses)*)?
412        {
413            type Output = Matrix<R>;
414            fn $method(mut self, rhs: &Self) -> Self::Output {
415                self.$method_assign(rhs);
416                self
417            }
418        }
419        impl<R> $imp<Matrix<R>> for &Matrix<R>
420        where
421            R: SemiRing,
422            $($($clauses)*)?
423        {
424            type Output = Matrix<R>;
425            fn $method(self, mut rhs: Matrix<R>) -> Self::Output {
426                rhs.pairwise_assign(self, |a, b| *a = R::$method(b, a));
427                rhs
428            }
429        }
430        impl<R> $imp<&Matrix<R>> for &Matrix<R>
431        where
432            R: SemiRing,
433            $($($clauses)*)?
434        {
435            type Output = Matrix<R>;
436            fn $method(self, rhs: &Matrix<R>) -> Self::Output {
437                let mut this = self.clone();
438                this.$method_assign(rhs);
439                this
440            }
441        }
442    };
443}
444
445impl_matrix_pairwise_binop!(Add, add, AddAssign, add_assign);
446impl_matrix_pairwise_binop!(Sub, sub, SubAssign, sub_assign where [R::Additive: Invertible]);
447
448impl<R> Mul for Matrix<R>
449where
450    R: SemiRing,
451{
452    type Output = Matrix<R>;
453    fn mul(self, rhs: Self) -> Self::Output {
454        (&self).mul(&rhs)
455    }
456}
457impl<R> Mul<&Matrix<R>> for Matrix<R>
458where
459    R: SemiRing,
460{
461    type Output = Matrix<R>;
462    fn mul(self, rhs: &Matrix<R>) -> Self::Output {
463        (&self).mul(rhs)
464    }
465}
466impl<R> Mul<Matrix<R>> for &Matrix<R>
467where
468    R: SemiRing,
469{
470    type Output = Matrix<R>;
471    fn mul(self, rhs: Matrix<R>) -> Self::Output {
472        self.mul(&rhs)
473    }
474}
475impl<R> Mul<&Matrix<R>> for &Matrix<R>
476where
477    R: SemiRing,
478{
479    type Output = Matrix<R>;
480    fn mul(self, rhs: &Matrix<R>) -> Self::Output {
481        assert_eq!(self.shape.1, rhs.shape.0);
482        let mut res = Matrix::zeros((self.shape.0, rhs.shape.1));
483        for i in 0..self.shape.0 {
484            for k in 0..self.shape.1 {
485                for j in 0..rhs.shape.1 {
486                    R::add_assign(&mut res[i][j], &R::mul(&self[i][k], &rhs[k][j]));
487                }
488            }
489        }
490        res
491    }
492}
493
494impl<R> MulAssign<&R::T> for Matrix<R>
495where
496    R: SemiRing,
497{
498    fn mul_assign(&mut self, rhs: &R::T) {
499        for i in 0..self.shape.0 {
500            for j in 0..self.shape.1 {
501                R::mul_assign(&mut self[(i, j)], rhs);
502            }
503        }
504    }
505}
506
507impl<R> Matrix<R>
508where
509    R: SemiRing,
510{
511    pub fn pow(self, mut n: usize) -> Self {
512        assert_eq!(self.shape.0, self.shape.1);
513        let mut res = Matrix::eye(self.shape);
514        let mut x = self;
515        while n > 0 {
516            if n & 1 == 1 {
517                res = &res * &x;
518            }
519            x = &x * &x;
520            n >>= 1;
521        }
522        res
523    }
524}
525
526impl<R> SerdeByteStr for Matrix<R>
527where
528    R: SemiRing,
529    R::T: SerdeByteStr,
530{
531    fn serialize(&self, buf: &mut Vec<u8>) {
532        self.data.serialize(buf);
533    }
534
535    fn deserialize<I>(iter: &mut I) -> Self
536    where
537        I: Iterator<Item = u8>,
538    {
539        Self::from_vec(Vec::deserialize(iter))
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One, Zero, mint_basic::DynMIntU32},
        rand, rand_value,
        tools::{RandomSpec, Xorshift},
    };

    // NOTE(review): several tests call `DynMIntU32::set_mod`. If the modulus is
    // process-global rather than thread-local, tests run in parallel could see
    // each other's modulus. Confirm against the `DynMIntU32` definition.

    /// Uniformly random element for the current `DynMIntU32` modulus.
    struct D;
    impl RandomSpec<DynMIntU32> for D {
        fn rand(&self, rng: &mut Xorshift) -> DynMIntU32 {
            DynMIntU32::new_unchecked(rng.random(..DynMIntU32::get_mod()))
        }
    }

    /// Random matrix from one of three families:
    /// - dense uniform;
    /// - sparse with a random density;
    /// - dense with one row replaced by a multiple of another, which is usually singular.
    fn random_matrix(
        rng: &mut Xorshift,
        shape: (usize, usize),
    ) -> Matrix<AddMulOperation<DynMIntU32>> {
        if rng.gen_bool(0.5) {
            Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| rng.random(D))
        } else if rng.gen_bool(0.5) {
            let r = rng.randf();
            Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| {
                if rng.gen_bool(r) {
                    rng.random(D)
                } else {
                    DynMIntU32::zero()
                }
            })
        } else {
            let mut mat = Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| rng.random(D));
            let i0 = rng.random(0..shape.0);
            let i1 = rng.random(0..shape.0);
            let x = rng.random(D);
            for j in 0..shape.1 {
                mat[(i0, j)] = mat[(i1, j)] * x;
            }
            mat
        }
    }

    #[test]
    fn test_eye() {
        for n in 0..10 {
            for m in 0..10 {
                let result = Matrix::<AddMulOperation<DynMIntU32>>::eye((n, m));
                let expected = Matrix::new_with((n, m), |i, j| {
                    if i == j {
                        DynMIntU32::one()
                    } else {
                        DynMIntU32::zero()
                    }
                });
                assert_eq!(result, expected);
            }
        }
    }

    #[test]
    fn test_add() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            let b = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            assert_eq!(&a + &b, a.clone() + b.clone());
            assert_eq!(a.clone() + &b, a.clone() + b.clone());
            assert_eq!(&a + b.clone(), a.clone() + b.clone());
        }
    }

    #[test]
    fn test_sub() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            let b = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            assert_eq!(&a - &b, a.clone() - b.clone());
            assert_eq!(a.clone() - &b, a.clone() - b.clone());
            assert_eq!(&a - b.clone(), a.clone() - b.clone());
        }
    }

    #[test]
    fn test_mul() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30, l: 1..30);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            let b = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; l]; m]));
            assert_eq!(&a * &b, a.clone() * b.clone());
            assert_eq!(a.clone() * &b, a.clone() * b.clone());
            assert_eq!(&a * b.clone(), a.clone() * b.clone());
            assert_eq!(
                &a * &b,
                Matrix::new_with((n, l), |i, j| (0..m).map(|k| a[i][k] * b[k][j]).sum())
            );
            let c = rand_value!(rng, D);
            let mut ac = a.clone();
            ac *= &c;
            assert_eq!(ac, Matrix::new_with(a.shape, |i, j| a[i][j] * c));
        }
    }

    #[test]
    fn test_pow() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n = rng.random(1..=10);
            let k: usize = rng.random(0..=20);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; n]; n]));
            let mut expected = Matrix::<AddMulOperation<DynMIntU32>>::eye((n, n));
            for _ in 0..k {
                expected = &expected * &a;
            }
            assert_eq!(a.pow(k), expected);
        }
    }

    #[test]
    fn test_row_reduction() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(2..=30);
            let mat = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; n]; n]));
            let rank = mat.clone().rank();
            let inv = mat.inverse();
            assert_eq!(rank == n, inv.is_some());
            if let Some(inv) = inv {
                assert_eq!(&mat * &inv, Matrix::<AddMulOperation<_>>::eye((n, n)));
            }
        }
    }

    #[test]
    fn test_determinant() {
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..300 {
            DynMIntU32::set_mod(ps[rng.random(..ps.len())]);
            let n = rng.random(1..=10);
            let a = random_matrix(&mut rng, (n, n));
            let b = random_matrix(&mut rng, (n, n));
            let det_a = a.clone().determinant();
            let det_b = b.clone().determinant();
            // The determinant is multiplicative...
            assert_eq!((&a * &b).determinant(), det_a * det_b);
            // ...and vanishes exactly on rank-deficient matrices.
            assert_eq!(det_a != DynMIntU32::zero(), a.clone().rank() == n);
        }
    }

    #[test]
    fn test_characteristic_polynomial() {
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..100 {
            DynMIntU32::set_mod(ps[rng.random(..ps.len())]);
            let n = rng.random(1..=10);
            let a = random_matrix(&mut rng, (n, n));
            let p = a.clone().characteristic_polynomial();
            assert_eq!(p.len(), n + 1);
            assert_eq!(p[n], DynMIntU32::one());
            // p(0) = det(-A) = (-1)^n det(A)
            let det = a.clone().determinant();
            let expected = if n % 2 == 0 {
                det
            } else {
                DynMIntU32::zero() - det
            };
            assert_eq!(p[0], expected);
            // Cayley–Hamilton: p(A) = Σ p[k] A^k = 0.
            let mut acc = Matrix::<AddMulOperation<DynMIntU32>>::zeros((n, n));
            let mut power = Matrix::<AddMulOperation<DynMIntU32>>::eye((n, n));
            for coef in &p {
                let mut term = power.clone();
                term *= coef;
                acc += &term;
                power = &power * &a;
            }
            assert_eq!(acc, Matrix::zeros((n, n)));
        }
    }

    #[test]
    fn test_system_of_linear_equations() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=30);
            let m = rng.random(1..=30);
            let a = random_matrix(&mut rng, (n, m));
            let b = random_matrix(&mut rng, (1, n))
                .data
                .into_iter()
                .next()
                .unwrap();
            if let Some(sol) = a.solve_system_of_linear_equations(&b) {
                assert_eq!(
                    &a * Matrix::from_vec(vec![sol.particular.clone()]).transpose(),
                    Matrix::from_vec(vec![b.clone()]).transpose()
                );
                let c = rand_value!(rng, [D; sol.basis.len()]);
                let mut x = sol.particular.clone();
                for (c, v) in c.iter().zip(sol.basis.iter()) {
                    for (x, v) in x.iter_mut().zip(v.iter()) {
                        *x += *c * *v;
                    }
                }
                assert_eq!(
                    &a * Matrix::from_vec(vec![x]).transpose(),
                    Matrix::from_vec(vec![b]).transpose()
                );
            }
        }
    }
}