competitive/math/
matrix.rs1use super::{One, Zero};
2use std::ops::{Add, Div, Index, IndexMut, Mul, Sub};
3
/// Dense two-dimensional matrix stored as a vector of row vectors.
///
/// `shape` is `(rows, columns)` and `data[i][j]` is the entry in row `i`, column `j`.
/// Both fields are public; nothing in this type enforces that `data` agrees with `shape`,
/// so code that mutates them directly is responsible for keeping the two consistent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
    /// `(number of rows, number of columns)`.
    pub shape: (usize, usize),
    /// Row-major storage: one inner `Vec` per row.
    pub data: Vec<Vec<T>>,
}
9
10impl<T: Clone> Matrix<T> {
11 pub fn new(shape: (usize, usize), z: T) -> Self {
12 Self {
13 shape,
14 data: vec![vec![z; shape.1]; shape.0],
15 }
16 }
17}
18impl<T> Matrix<T> {
19 pub fn from_vec(data: Vec<Vec<T>>) -> Self {
20 Self {
21 shape: (data.len(), data.first().map(Vec::len).unwrap_or_default()),
22 data,
23 }
24 }
25}
26impl<T> Matrix<T>
27where
28 T: Clone + Zero,
29{
30 pub fn zeros(shape: (usize, usize)) -> Self {
31 Self {
32 shape,
33 data: vec![vec![Zero::zero(); shape.1]; shape.0],
34 }
35 }
36}
37impl<T> Matrix<T>
38where
39 T: Clone + Zero + One,
40{
41 pub fn eye(shape: (usize, usize)) -> Self {
42 let mut data = vec![vec![Zero::zero(); shape.1]; shape.0];
43 for (i, d) in data.iter_mut().enumerate() {
44 d[i] = One::one();
45 }
46 Self { shape, data }
47 }
48}
49impl<T> Index<usize> for Matrix<T> {
50 type Output = Vec<T>;
51 fn index(&self, index: usize) -> &Self::Output {
52 &self.data[index]
53 }
54}
55impl<T> IndexMut<usize> for Matrix<T> {
56 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
57 &mut self.data[index]
58 }
59}
60impl<T> Index<(usize, usize)> for Matrix<T> {
61 type Output = T;
62 fn index(&self, index: (usize, usize)) -> &Self::Output {
63 &self.data[index.0][index.1]
64 }
65}
66impl<T> IndexMut<(usize, usize)> for Matrix<T> {
67 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
68 &mut self.data[index.0][index.1]
69 }
70}
71impl<T> Add for &Matrix<T>
72where
73 T: Copy + Zero + Add<Output = T>,
74{
75 type Output = Matrix<T>;
76 fn add(self, rhs: Self) -> Self::Output {
77 assert_eq!(self.shape, rhs.shape);
78 let mut res = Matrix::zeros(self.shape);
79 for i in 0..self.shape.0 {
80 for j in 0..self.shape.1 {
81 res[i][j] = self[i][j] + rhs[i][j];
82 }
83 }
84 res
85 }
86}
87impl<T> Sub for &Matrix<T>
88where
89 T: Copy + Zero + Sub<Output = T>,
90{
91 type Output = Matrix<T>;
92 fn sub(self, rhs: Self) -> Self::Output {
93 assert_eq!(self.shape, rhs.shape);
94 let mut res = Matrix::zeros(self.shape);
95 for i in 0..self.shape.0 {
96 for j in 0..self.shape.1 {
97 res[i][j] = self[i][j] - rhs[i][j];
98 }
99 }
100 res
101 }
102}
103impl<T> Mul for &Matrix<T>
104where
105 T: Copy + Zero + Add<Output = T> + Mul<Output = T>,
106{
107 type Output = Matrix<T>;
108 fn mul(self, rhs: Self) -> Self::Output {
109 assert_eq!(self.shape.1, rhs.shape.0);
110 let mut res = Matrix::zeros((self.shape.0, rhs.shape.1));
111 for i in 0..self.shape.0 {
112 for j in 0..rhs.shape.1 {
113 for k in 0..self.shape.1 {
114 res[i][j] = res[i][j] + self[i][k] * rhs[k][j];
115 }
116 }
117 }
118 res
119 }
120}
121impl<T> Matrix<T>
122where
123 T: Copy + Zero + One + Add<Output = T> + Mul<Output = T>,
124{
125 pub fn pow(&self, mut n: usize) -> Self {
126 assert_eq!(self.shape.0, self.shape.1);
127 let mut x = self.clone();
128 let mut res = Matrix::eye(self.shape);
129 while n > 0 {
130 if n & 1 == 1 {
131 res = &res * &x;
132 }
133 x = &x * &x;
134 n >>= 1;
135 }
136 res
137 }
138}
139impl<T> Matrix<T>
140where
141 T: Copy + PartialEq + Zero + One + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
142{
143 pub fn row_reduction(&mut self, normalize: bool) {
144 let (n, m) = self.shape;
145 let mut c = 0;
146 for r in 0..n {
147 loop {
148 if c >= m {
149 return;
150 }
151 if let Some(pivot) = (r..n).find(|&p| !self[p][c].is_zero()) {
152 self.data.swap(r, pivot);
153 break;
154 };
155 c += 1;
156 }
157 let d = T::one() / self[r][c];
158 if normalize {
159 for j in c..m {
160 self[r][j] = self[r][j] * d;
161 }
162 }
163 for i in (0..n).filter(|&i| i != r) {
164 let mut e = self[i][c];
165 if !normalize {
166 e = e * d;
167 }
168 for j in c..m {
169 self[i][j] = self[i][j] - e * self[r][j];
170 }
171 }
172 c += 1;
173 }
174 }
175 pub fn rank(&mut self) -> usize {
176 let n = self.shape.0;
177 self.row_reduction(false);
178 (0..n)
179 .filter(|&i| !self.data[i].iter().all(|x| x.is_zero()))
180 .count()
181 }
182 pub fn determinant(&mut self) -> T {
183 assert_eq!(self.shape.0, self.shape.1);
184 self.row_reduction(false);
185 let mut d = T::one();
186 for i in 0..self.shape.0 {
187 d = d * self[i][i];
188 }
189 d
190 }
191 pub fn solve_system_of_linear_equations(&self, b: &[T]) -> Option<Vec<T>> {
192 assert_eq!(self.shape.0, b.len());
193 let (n, m) = self.shape;
194 let mut c = Matrix::<T>::zeros((n, m + 1));
195 for i in 0..n {
196 c[i][..m].clone_from_slice(&self[i]);
197 c[i][m] = b[i];
198 }
199 c.row_reduction(true);
200 let mut x = vec![T::zero(); m];
201 for i in 0..n {
202 let mut j = 0usize;
203 while j <= m && c[i][j].is_zero() {
204 j += 1;
205 }
206 if j == m {
207 return None;
208 }
209 if j < m {
210 x[j] = c[i][m];
211 }
212 }
213 Some(x)
214 }
215 pub fn inverse(&self) -> Option<Matrix<T>> {
216 assert_eq!(self.shape.0, self.shape.1);
217 let n = self.shape.0;
218 let mut c = Matrix::<T>::zeros((n, n * 2));
219 for i in 0..n {
220 c[i][..n].clone_from_slice(&self[i]);
221 c[i][n + i] = T::one();
222 }
223 c.row_reduction(true);
224 if (0..n).any(|i| c[i][i].is_zero()) {
225 None
226 } else {
227 Some(Self::from_vec(
228 c.data.into_iter().map(|r| r[n..].to_vec()).collect(),
229 ))
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        num::mint_basic::DynMIntU32,
        rand_value,
        tools::{RandomSpec, Xorshift},
    };

    /// Random spec that draws a `DynMIntU32` from the range `..DynMIntU32::get_mod()`.
    struct D;
    impl RandomSpec<DynMIntU32> for D {
        fn rand(&self, rng: &mut Xorshift) -> DynMIntU32 {
            let modulus = DynMIntU32::get_mod();
            DynMIntU32::new_unchecked(rng.random(..modulus))
        }
    }

    /// Cross-checks `rank` against `inverse` on random square matrices: a matrix has an
    /// inverse exactly when its rank equals its size, and a returned inverse must
    /// multiply back to the identity.
    #[test]
    fn test_row_reduction() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();
        let moduli = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let modulus = moduli[rng.random(..moduli.len())];
            DynMIntU32::set_mod(modulus);
            let n = rng.random(2..=30);
            let mat = Matrix::from_vec(rand_value!(rng, [[D; n]; n]));
            // `rank` reduces in place, so run it on a copy.
            let rank = mat.clone().rank();
            let inv = mat.inverse();
            assert_eq!(rank == n, inv.is_some());
            if let Some(inv) = inv {
                let product = &mat * &inv;
                assert_eq!(product, Matrix::eye((n, n)));
            }
        }
    }
}
267}