competitive/math/matrix.rs

1use super::{One, Zero};
2use std::ops::{Add, Div, Index, IndexMut, Mul, Sub};
3
/// Dense row-major matrix.
///
/// `data[i][j]` is the entry in row `i`, column `j`. `shape` is `(rows, cols)`
/// and is expected to agree with `data` (`data.len() == rows`, every row of
/// length `cols`); both fields are public, so that agreement is the caller's
/// responsibility when building the struct by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
    /// `(rows, cols)`.
    pub shape: (usize, usize),
    /// Row-major storage: one inner `Vec` per row.
    pub data: Vec<Vec<T>>,
}
9
10impl<T: Clone> Matrix<T> {
11    pub fn new(shape: (usize, usize), z: T) -> Self {
12        Self {
13            shape,
14            data: vec![vec![z; shape.1]; shape.0],
15        }
16    }
17}
18impl<T> Matrix<T> {
19    pub fn from_vec(data: Vec<Vec<T>>) -> Self {
20        Self {
21            shape: (data.len(), data.first().map(Vec::len).unwrap_or_default()),
22            data,
23        }
24    }
25}
26impl<T> Matrix<T>
27where
28    T: Clone + Zero,
29{
30    pub fn zeros(shape: (usize, usize)) -> Self {
31        Self {
32            shape,
33            data: vec![vec![Zero::zero(); shape.1]; shape.0],
34        }
35    }
36}
37impl<T> Matrix<T>
38where
39    T: Clone + Zero + One,
40{
41    pub fn eye(shape: (usize, usize)) -> Self {
42        let mut data = vec![vec![Zero::zero(); shape.1]; shape.0];
43        for (i, d) in data.iter_mut().enumerate() {
44            d[i] = One::one();
45        }
46        Self { shape, data }
47    }
48}
49impl<T> Index<usize> for Matrix<T> {
50    type Output = Vec<T>;
51    fn index(&self, index: usize) -> &Self::Output {
52        &self.data[index]
53    }
54}
55impl<T> IndexMut<usize> for Matrix<T> {
56    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
57        &mut self.data[index]
58    }
59}
60impl<T> Index<(usize, usize)> for Matrix<T> {
61    type Output = T;
62    fn index(&self, index: (usize, usize)) -> &Self::Output {
63        &self.data[index.0][index.1]
64    }
65}
66impl<T> IndexMut<(usize, usize)> for Matrix<T> {
67    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
68        &mut self.data[index.0][index.1]
69    }
70}
71impl<T> Add for &Matrix<T>
72where
73    T: Copy + Zero + Add<Output = T>,
74{
75    type Output = Matrix<T>;
76    fn add(self, rhs: Self) -> Self::Output {
77        assert_eq!(self.shape, rhs.shape);
78        let mut res = Matrix::zeros(self.shape);
79        for i in 0..self.shape.0 {
80            for j in 0..self.shape.1 {
81                res[i][j] = self[i][j] + rhs[i][j];
82            }
83        }
84        res
85    }
86}
87impl<T> Sub for &Matrix<T>
88where
89    T: Copy + Zero + Sub<Output = T>,
90{
91    type Output = Matrix<T>;
92    fn sub(self, rhs: Self) -> Self::Output {
93        assert_eq!(self.shape, rhs.shape);
94        let mut res = Matrix::zeros(self.shape);
95        for i in 0..self.shape.0 {
96            for j in 0..self.shape.1 {
97                res[i][j] = self[i][j] - rhs[i][j];
98            }
99        }
100        res
101    }
102}
103impl<T> Mul for &Matrix<T>
104where
105    T: Copy + Zero + Add<Output = T> + Mul<Output = T>,
106{
107    type Output = Matrix<T>;
108    fn mul(self, rhs: Self) -> Self::Output {
109        assert_eq!(self.shape.1, rhs.shape.0);
110        let mut res = Matrix::zeros((self.shape.0, rhs.shape.1));
111        for i in 0..self.shape.0 {
112            for j in 0..rhs.shape.1 {
113                for k in 0..self.shape.1 {
114                    res[i][j] = res[i][j] + self[i][k] * rhs[k][j];
115                }
116            }
117        }
118        res
119    }
120}
121impl<T> Matrix<T>
122where
123    T: Copy + Zero + One + Add<Output = T> + Mul<Output = T>,
124{
125    pub fn pow(&self, mut n: usize) -> Self {
126        assert_eq!(self.shape.0, self.shape.1);
127        let mut x = self.clone();
128        let mut res = Matrix::eye(self.shape);
129        while n > 0 {
130            if n & 1 == 1 {
131                res = &res * &x;
132            }
133            x = &x * &x;
134            n >>= 1;
135        }
136        res
137    }
138}
139impl<T> Matrix<T>
140where
141    T: Copy + PartialEq + Zero + One + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
142{
143    pub fn row_reduction(&mut self, normalize: bool) {
144        let (n, m) = self.shape;
145        let mut c = 0;
146        for r in 0..n {
147            loop {
148                if c >= m {
149                    return;
150                }
151                if let Some(pivot) = (r..n).find(|&p| !self[p][c].is_zero()) {
152                    self.data.swap(r, pivot);
153                    break;
154                };
155                c += 1;
156            }
157            let d = T::one() / self[r][c];
158            if normalize {
159                for j in c..m {
160                    self[r][j] = self[r][j] * d;
161                }
162            }
163            for i in (0..n).filter(|&i| i != r) {
164                let mut e = self[i][c];
165                if !normalize {
166                    e = e * d;
167                }
168                for j in c..m {
169                    self[i][j] = self[i][j] - e * self[r][j];
170                }
171            }
172            c += 1;
173        }
174    }
175    pub fn rank(&mut self) -> usize {
176        let n = self.shape.0;
177        self.row_reduction(false);
178        (0..n)
179            .filter(|&i| !self.data[i].iter().all(|x| x.is_zero()))
180            .count()
181    }
182    pub fn determinant(&mut self) -> T {
183        assert_eq!(self.shape.0, self.shape.1);
184        self.row_reduction(false);
185        let mut d = T::one();
186        for i in 0..self.shape.0 {
187            d = d * self[i][i];
188        }
189        d
190    }
191    pub fn solve_system_of_linear_equations(&self, b: &[T]) -> Option<Vec<T>> {
192        assert_eq!(self.shape.0, b.len());
193        let (n, m) = self.shape;
194        let mut c = Matrix::<T>::zeros((n, m + 1));
195        for i in 0..n {
196            c[i][..m].clone_from_slice(&self[i]);
197            c[i][m] = b[i];
198        }
199        c.row_reduction(true);
200        let mut x = vec![T::zero(); m];
201        for i in 0..n {
202            let mut j = 0usize;
203            while j <= m && c[i][j].is_zero() {
204                j += 1;
205            }
206            if j == m {
207                return None;
208            }
209            if j < m {
210                x[j] = c[i][m];
211            }
212        }
213        Some(x)
214    }
215    pub fn inverse(&self) -> Option<Matrix<T>> {
216        assert_eq!(self.shape.0, self.shape.1);
217        let n = self.shape.0;
218        let mut c = Matrix::<T>::zeros((n, n * 2));
219        for i in 0..n {
220            c[i][..n].clone_from_slice(&self[i]);
221            c[i][n + i] = T::one();
222        }
223        c.row_reduction(true);
224        if (0..n).any(|i| c[i][i].is_zero()) {
225            None
226        } else {
227            Some(Self::from_vec(
228                c.data.into_iter().map(|r| r[n..].to_vec()).collect(),
229            ))
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        num::mint_basic::DynMIntU32,
        rand_value,
        tools::{RandomSpec, Xorshift},
    };

    /// Samples a random residue in `[0, modulus)` for the modulus currently
    /// installed via `DynMIntU32::set_mod`.
    struct D;
    impl RandomSpec<DynMIntU32> for D {
        fn rand(&self, rng: &mut Xorshift) -> DynMIntU32 {
            let modulus = DynMIntU32::get_mod();
            DynMIntU32::new_unchecked(rng.random(..modulus))
        }
    }

    /// Random square matrices over the prime moduli 2, 3 and 1_000_000_007.
    /// A matrix must be invertible exactly when it has full rank, and any
    /// inverse found must satisfy `A * A^-1 == I`.
    #[test]
    fn test_row_reduction() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();
        let moduli = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let modulus = moduli[rng.random(..moduli.len())];
            DynMIntU32::set_mod(modulus);
            let n = rng.random(2..=30);
            let mat = Matrix::from_vec(rand_value!(rng, [[D; n]; n]));
            let full_rank = mat.clone().rank() == n;
            match mat.inverse() {
                Some(inv) => {
                    assert!(full_rank);
                    assert_eq!(&mat * &inv, Matrix::eye((n, n)));
                }
                None => assert!(!full_rank),
            }
        }
    }
}