1use super::{Field, Invertible, Ring, SemiRing, SerdeByteStr};
2use std::{
3 fmt::{self, Debug},
4 marker::PhantomData,
5 ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign},
6};
7
8pub struct Matrix<R>
9where
10 R: SemiRing,
11{
12 pub shape: (usize, usize),
13 pub data: Vec<Vec<R::T>>,
14 _marker: PhantomData<fn() -> R>,
15}
16
17impl<R> Debug for Matrix<R>
18where
19 R: SemiRing<T: Debug>,
20{
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 f.debug_struct("Matrix")
23 .field("shape", &self.shape)
24 .field("data", &self.data)
25 .field("_marker", &self._marker)
26 .finish()
27 }
28}
29
30impl<R> Clone for Matrix<R>
31where
32 R: SemiRing,
33{
34 fn clone(&self) -> Self {
35 Self {
36 shape: self.shape,
37 data: self.data.clone(),
38 _marker: self._marker,
39 }
40 }
41}
42
43impl<R> PartialEq for Matrix<R>
44where
45 R: SemiRing<T: PartialEq>,
46{
47 fn eq(&self, other: &Self) -> bool {
48 self.shape == other.shape && self.data == other.data
49 }
50}
51
52impl<R> Eq for Matrix<R> where R: SemiRing<T: Eq> {}
53
54impl<R> Matrix<R>
55where
56 R: SemiRing,
57{
58 pub fn new(shape: (usize, usize), z: R::T) -> Self {
59 Self {
60 shape,
61 data: vec![vec![z; shape.1]; shape.0],
62 _marker: PhantomData,
63 }
64 }
65
66 pub fn from_vec(data: Vec<Vec<R::T>>) -> Self {
67 let shape = (data.len(), data.first().map(Vec::len).unwrap_or_default());
68 assert!(data.iter().all(|r| r.len() == shape.1));
69 Self {
70 shape,
71 data,
72 _marker: PhantomData,
73 }
74 }
75
76 pub fn new_with(shape: (usize, usize), mut f: impl FnMut(usize, usize) -> R::T) -> Self {
77 let data = (0..shape.0)
78 .map(|i| (0..shape.1).map(|j| f(i, j)).collect())
79 .collect();
80 Self {
81 shape,
82 data,
83 _marker: PhantomData,
84 }
85 }
86
87 pub fn zeros(shape: (usize, usize)) -> Self {
88 Self {
89 shape,
90 data: vec![vec![R::zero(); shape.1]; shape.0],
91 _marker: PhantomData,
92 }
93 }
94
95 pub fn eye(shape: (usize, usize)) -> Self {
96 let mut data = vec![vec![R::zero(); shape.1]; shape.0];
97 for (i, d) in data.iter_mut().enumerate().take(shape.1) {
98 d[i] = R::one();
99 }
100 Self {
101 shape,
102 data,
103 _marker: PhantomData,
104 }
105 }
106
107 pub fn transpose(&self) -> Self {
108 Self::new_with((self.shape.1, self.shape.0), |i, j| self[j][i].clone())
109 }
110
111 pub fn map<S, F>(&self, mut f: F) -> Matrix<S>
112 where
113 S: SemiRing,
114 F: FnMut(&R::T) -> S::T,
115 {
116 Matrix::<S>::new_with(self.shape, |i, j| f(&self[i][j]))
117 }
118
119 pub fn add_row_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
120 self.data
121 .push((0..self.shape.1).map(|j| f(self.shape.0, j)).collect());
122 self.shape.0 += 1;
123 }
124
125 pub fn add_col_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
126 for i in 0..self.shape.0 {
127 self.data[i].push(f(i, self.shape.1));
128 }
129 self.shape.1 += 1;
130 }
131
132 pub fn pairwise_assign<F>(&mut self, other: &Self, mut f: F)
133 where
134 F: FnMut(&mut R::T, &R::T),
135 {
136 assert_eq!(self.shape, other.shape);
137 for i in 0..self.shape.0 {
138 for j in 0..self.shape.1 {
139 f(&mut self[i][j], &other[i][j]);
140 }
141 }
142 }
143}
144
145#[derive(Debug)]
146pub struct SystemOfLinearEquationsSolution<R>
147where
148 R: Field<Additive: Invertible, Multiplicative: Invertible>,
149{
150 pub particular: Vec<R::T>,
151 pub basis: Vec<Vec<R::T>>,
152}
153
154impl<R> Matrix<R>
155where
156 R: Field<T: PartialEq, Additive: Invertible, Multiplicative: Invertible>,
157{
158 pub fn row_reduction_with<F>(&mut self, normalize: bool, mut f: F)
160 where
161 F: FnMut(usize, usize, usize),
162 {
163 let (n, m) = self.shape;
164 let mut c = 0;
165 for r in 0..n {
166 loop {
167 if c >= m {
168 return;
169 }
170 if let Some(pivot) = (r..n).find(|&p| !R::is_zero(&self[p][c])) {
171 f(r, pivot, c);
172 self.data.swap(r, pivot);
173 break;
174 };
175 c += 1;
176 }
177 let d = R::inv(&self[r][c]);
178 if normalize {
179 for j in c..m {
180 R::mul_assign(&mut self[r][j], &d);
181 }
182 }
183 for i in (0..n).filter(|&i| i != r) {
184 let mut e = self[i][c].clone();
185 if !normalize {
186 R::mul_assign(&mut e, &d);
187 }
188 for j in c..m {
189 let e = R::mul(&e, &self[r][j]);
190 R::sub_assign(&mut self[i][j], &e);
191 }
192 }
193 c += 1;
194 }
195 }
196
197 pub fn row_reduction(&mut self, normalize: bool) {
198 self.row_reduction_with(normalize, |_, _, _| {});
199 }
200
201 pub fn rank(&mut self) -> usize {
202 let n = self.shape.0;
203 self.row_reduction(false);
204 (0..n)
205 .filter(|&i| !self.data[i].iter().all(|x| R::is_zero(x)))
206 .count()
207 }
208
209 pub fn determinant(&mut self) -> R::T {
210 assert_eq!(self.shape.0, self.shape.1);
211 let mut neg = false;
212 self.row_reduction_with(false, |r, p, _| neg ^= r != p);
213 let mut d = R::one();
214 if neg {
215 d = R::neg(&d);
216 }
217 for i in 0..self.shape.0 {
218 R::mul_assign(&mut d, &self[i][i]);
219 }
220 d
221 }
222
223 pub fn solve_system_of_linear_equations(
224 &self,
225 b: &[R::T],
226 ) -> Option<SystemOfLinearEquationsSolution<R>> {
227 assert_eq!(self.shape.0, b.len());
228 let (n, m) = self.shape;
229 let mut c = Matrix::<R>::zeros((n, m + 1));
230 for i in 0..n {
231 c[i][..m].clone_from_slice(&self[i]);
232 c[i][m] = b[i].clone();
233 }
234 let mut reduced = vec![!0; m + 1];
235 c.row_reduction_with(true, |r, _, c| reduced[c] = r);
236 if reduced[m] != !0 {
237 return None;
238 }
239 let mut particular = vec![R::zero(); m];
240 let mut basis = vec![];
241 for j in 0..m {
242 if reduced[j] != !0 {
243 particular[j] = c[reduced[j]][m].clone();
244 } else {
245 let mut v = vec![R::zero(); m];
246 v[j] = R::one();
247 for i in 0..m {
248 if reduced[i] != !0 {
249 R::sub_assign(&mut v[i], &c[reduced[i]][j]);
250 }
251 }
252 basis.push(v);
253 }
254 }
255 Some(SystemOfLinearEquationsSolution { particular, basis })
256 }
257
258 pub fn inverse(&self) -> Option<Matrix<R>> {
259 assert_eq!(self.shape.0, self.shape.1);
260 let n = self.shape.0;
261 let mut c = Matrix::<R>::zeros((n, n * 2));
262 for i in 0..n {
263 c[i][..n].clone_from_slice(&self[i]);
264 c[i][n + i] = R::one();
265 }
266 c.row_reduction(true);
267 if (0..n).any(|i| R::is_zero(&c[i][i])) {
268 None
269 } else {
270 Some(Self::from_vec(
271 c.data.into_iter().map(|r| r[n..].to_vec()).collect(),
272 ))
273 }
274 }
275
276 pub fn characteristic_polynomial(&mut self) -> Vec<R::T> {
277 let n = self.shape.0;
278 if n == 0 {
279 return vec![R::one()];
280 }
281 assert!(self.data.iter().all(|a| a.len() == n));
282 for j in 0..(n - 1) {
283 if let Some(x) = ((j + 1)..n).find(|&x| !R::is_zero(&self[x][j])) {
284 self.data.swap(j + 1, x);
285 self.data.iter_mut().for_each(|a| a.swap(j + 1, x));
286 let inv = R::inv(&self[j + 1][j]);
287 let mut v = vec![];
288 let src = std::mem::take(&mut self[j + 1]);
289 for a in self.data[(j + 2)..].iter_mut() {
290 let mul = R::mul(&a[j], &inv);
291 for (a, src) in a[j..].iter_mut().zip(src[j..].iter()) {
292 R::sub_assign(a, &R::mul(&mul, src));
293 }
294 v.push(mul);
295 }
296 self[j + 1] = src;
297 for a in self.data.iter_mut() {
298 let v = a[(j + 2)..]
299 .iter()
300 .zip(v.iter())
301 .fold(R::zero(), |s, a| R::add(&s, &R::mul(a.0, a.1)));
302 R::add_assign(&mut a[j + 1], &v);
303 }
304 }
305 }
306 let mut dp = vec![vec![R::one()]];
307 for i in 0..n {
308 let mut next = vec![R::zero(); i + 2];
309 for (j, dp) in dp[i].iter().enumerate() {
310 R::sub_assign(&mut next[j], &R::mul(dp, &self[i][i]));
311 R::add_assign(&mut next[j + 1], dp);
312 }
313 let mut mul = R::one();
314 for j in (0..i).rev() {
315 mul = R::mul(&mul, &self[j + 1][j]);
316 let c = R::mul(&mul, &self[j][i]);
317 for (next, dp) in next.iter_mut().zip(dp[j].iter()) {
318 R::sub_assign(next, &R::mul(&c, dp));
319 }
320 }
321 dp.push(next);
322 }
323 dp.pop().unwrap()
324 }
325}
326
327impl<R> Index<usize> for Matrix<R>
328where
329 R: SemiRing,
330{
331 type Output = Vec<R::T>;
332 fn index(&self, index: usize) -> &Self::Output {
333 &self.data[index]
334 }
335}
336
337impl<R> IndexMut<usize> for Matrix<R>
338where
339 R: SemiRing,
340{
341 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
342 &mut self.data[index]
343 }
344}
345
346impl<R> Index<(usize, usize)> for Matrix<R>
347where
348 R: SemiRing,
349{
350 type Output = R::T;
351 fn index(&self, index: (usize, usize)) -> &Self::Output {
352 &self.data[index.0][index.1]
353 }
354}
355
356impl<R> IndexMut<(usize, usize)> for Matrix<R>
357where
358 R: SemiRing,
359{
360 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
361 &mut self.data[index.0][index.1]
362 }
363}
364
365macro_rules! impl_matrix_pairwise_binop {
366 ($imp:ident, $method:ident, $imp_assign:ident, $method_assign:ident $(where [$($clauses:tt)*])?) => {
367 impl<R> $imp_assign for Matrix<R>
368 where
369 R: SemiRing,
370 $($($clauses)*)?
371 {
372 fn $method_assign(&mut self, rhs: Self) {
373 self.pairwise_assign(&rhs, |a, b| R::$method_assign(a, b));
374 }
375 }
376 impl<R> $imp_assign<&Matrix<R>> for Matrix<R>
377 where
378 R: SemiRing,
379 $($($clauses)*)?
380 {
381 fn $method_assign(&mut self, rhs: &Self) {
382 self.pairwise_assign(rhs, |a, b| R::$method_assign(a, b));
383 }
384 }
385 impl<R> $imp for Matrix<R>
386 where
387 R: SemiRing,
388 $($($clauses)*)?
389 {
390 type Output = Matrix<R>;
391 fn $method(mut self, rhs: Self) -> Self::Output {
392 self.$method_assign(rhs);
393 self
394 }
395 }
396 impl<R> $imp<&Matrix<R>> for Matrix<R>
397 where
398 R: SemiRing,
399 $($($clauses)*)?
400 {
401 type Output = Matrix<R>;
402 fn $method(mut self, rhs: &Self) -> Self::Output {
403 self.$method_assign(rhs);
404 self
405 }
406 }
407 impl<R> $imp<Matrix<R>> for &Matrix<R>
408 where
409 R: SemiRing,
410 $($($clauses)*)?
411 {
412 type Output = Matrix<R>;
413 fn $method(self, mut rhs: Matrix<R>) -> Self::Output {
414 rhs.pairwise_assign(self, |a, b| *a = R::$method(b, a));
415 rhs
416 }
417 }
418 impl<R> $imp<&Matrix<R>> for &Matrix<R>
419 where
420 R: SemiRing,
421 $($($clauses)*)?
422 {
423 type Output = Matrix<R>;
424 fn $method(self, rhs: &Matrix<R>) -> Self::Output {
425 let mut this = self.clone();
426 this.$method_assign(rhs);
427 this
428 }
429 }
430 };
431}
432
433impl_matrix_pairwise_binop!(Add, add, AddAssign, add_assign);
434impl_matrix_pairwise_binop!(Sub, sub, SubAssign, sub_assign where [R: SemiRing<Additive: Invertible>]);
435
436impl<R> Mul for Matrix<R>
437where
438 R: SemiRing,
439{
440 type Output = Matrix<R>;
441 fn mul(self, rhs: Self) -> Self::Output {
442 (&self).mul(&rhs)
443 }
444}
445impl<R> Mul<&Matrix<R>> for Matrix<R>
446where
447 R: SemiRing,
448{
449 type Output = Matrix<R>;
450 fn mul(self, rhs: &Matrix<R>) -> Self::Output {
451 (&self).mul(rhs)
452 }
453}
454impl<R> Mul<Matrix<R>> for &Matrix<R>
455where
456 R: SemiRing,
457{
458 type Output = Matrix<R>;
459 fn mul(self, rhs: Matrix<R>) -> Self::Output {
460 self.mul(&rhs)
461 }
462}
463impl<R> Mul<&Matrix<R>> for &Matrix<R>
464where
465 R: SemiRing,
466{
467 type Output = Matrix<R>;
468 fn mul(self, rhs: &Matrix<R>) -> Self::Output {
469 assert_eq!(self.shape.1, rhs.shape.0);
470 let mut res = Matrix::zeros((self.shape.0, rhs.shape.1));
471 for i in 0..self.shape.0 {
472 for k in 0..self.shape.1 {
473 for j in 0..rhs.shape.1 {
474 R::add_assign(&mut res[i][j], &R::mul(&self[i][k], &rhs[k][j]));
475 }
476 }
477 }
478 res
479 }
480}
481
482impl<R> MulAssign<&R::T> for Matrix<R>
483where
484 R: SemiRing,
485{
486 fn mul_assign(&mut self, rhs: &R::T) {
487 for i in 0..self.shape.0 {
488 for j in 0..self.shape.1 {
489 R::mul_assign(&mut self[(i, j)], rhs);
490 }
491 }
492 }
493}
494
495impl<R> Neg for Matrix<R>
496where
497 R: SemiRing<Additive: Invertible>,
498{
499 type Output = Self;
500
501 fn neg(self) -> Self::Output {
502 self.map(|x| R::neg(x))
503 }
504}
505
506impl<R> Neg for &Matrix<R>
507where
508 R: SemiRing<Additive: Invertible>,
509{
510 type Output = Matrix<R>;
511
512 fn neg(self) -> Self::Output {
513 self.map(|x| R::neg(x))
514 }
515}
516
517impl<R> Matrix<R>
518where
519 R: SemiRing,
520{
521 pub fn pow(self, mut n: usize) -> Self {
522 assert_eq!(self.shape.0, self.shape.1);
523 let mut res = Matrix::eye(self.shape);
524 let mut x = self;
525 while n > 0 {
526 if n & 1 == 1 {
527 res = &res * &x;
528 }
529 x = &x * &x;
530 n >>= 1;
531 }
532 res
533 }
534}
535
536impl<R> SerdeByteStr for Matrix<R>
537where
538 R: SemiRing<T: SerdeByteStr>,
539{
540 fn serialize(&self, buf: &mut Vec<u8>) {
541 self.data.serialize(buf);
542 }
543
544 fn deserialize<I>(iter: &mut I) -> Self
545 where
546 I: Iterator<Item = u8>,
547 {
548 Self::from_vec(Vec::deserialize(iter))
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One, Zero, mint_basic::DynMIntU32},
        rand, rand_value,
        tools::Xorshift,
    };

    type R = AddMulOperation<DynMIntU32>;

    /// Random `shape` matrix, chosen from one of three kinds:
    /// - dense uniform entries;
    /// - sparse entries with a random density;
    /// - dense, with row `i0` overwritten by `x` times row `i1` (`i0` may equal `i1`),
    ///   so rank-deficient inputs also get exercised.
    ///
    /// `shape.0` must be at least 1.
    fn random_matrix(rng: &mut Xorshift, shape: (usize, usize)) -> Matrix<R> {
        if rng.gen_bool(0.5) {
            Matrix::new_with(shape, |_, _| rng.random(..))
        } else if rng.gen_bool(0.5) {
            let r = rng.randf();
            Matrix::new_with(shape, |_, _| {
                if rng.gen_bool(r) {
                    rng.random(..)
                } else {
                    DynMIntU32::zero()
                }
            })
        } else {
            let mut mat = Matrix::new_with(shape, |_, _| rng.random(..));
            let i0 = rng.random(0..shape.0);
            let i1 = rng.random(0..shape.0);
            let x: DynMIntU32 = rng.random(..);
            for j in 0..shape.1 {
                mat[(i0, j)] = mat[(i1, j)] * x;
            }
            mat
        }
    }

    #[test]
    fn test_eye() {
        for n in 0..10 {
            for m in 0..10 {
                let result = Matrix::<R>::eye((n, m));
                let expected = Matrix::<R>::new_with((n, m), |i, j| {
                    if i == j {
                        DynMIntU32::one()
                    } else {
                        DynMIntU32::zero()
                    }
                });
                assert_eq!(result, expected);
            }
        }
    }

    #[test]
    fn test_transpose() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..20, m: 1..20, l: 1..20);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((m, l), |_, _| rng.random(..));
            assert_eq!(a.transpose().transpose(), a);
            // (A·B)^T = B^T · A^T
            assert_eq!((&a * &b).transpose(), &b.transpose() * &a.transpose());
        }
    }

    #[test]
    fn test_add() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            assert_eq!(&a + &b, a.clone() + b.clone());
            assert_eq!(a.clone() + &b, a.clone() + b.clone());
            assert_eq!(&a + b.clone(), a.clone() + b.clone());
        }
    }

    #[test]
    fn test_sub() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            assert_eq!(&a - &b, a.clone() - b.clone());
            assert_eq!(a.clone() - &b, a.clone() - b.clone());
            assert_eq!(&a - b.clone(), a.clone() - b.clone());
        }
    }

    #[test]
    fn test_neg() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            // Owned and borrowed negation agree, and -A + A = 0.
            assert_eq!(-a.clone(), -&a);
            assert_eq!(-&a + &a, Matrix::zeros((n, m)));
        }
    }

    #[test]
    fn test_mul() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30, l: 1..30);
            let a = Matrix::<R>::new_with((n, m), |_, _| rng.random(..));
            let b = Matrix::<R>::new_with((m, l), |_, _| rng.random(..));
            assert_eq!(&a * &b, a.clone() * b.clone());
            assert_eq!(a.clone() * &b, a.clone() * b.clone());
            assert_eq!(&a * b.clone(), a.clone() * b.clone());
            assert_eq!(
                &a * &b,
                Matrix::new_with((n, l), |i, j| (0..m).map(|k| a[i][k] * b[k][j]).sum())
            );
            let c = rng.random(..);
            let mut ac = a.clone();
            ac *= &c;
            assert_eq!(ac, Matrix::new_with(a.shape, |i, j| a[i][j] * c));
        }
    }

    #[test]
    fn test_pow() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n = rng.random(1..=8);
            let k = rng.random(0..20);
            let mat = Matrix::<R>::new_with((n, n), |_, _| rng.random(..));
            let mut expected = Matrix::<R>::eye((n, n));
            for _ in 0..k {
                expected = &expected * &mat;
            }
            assert_eq!(mat.pow(k), expected);
        }
    }

    #[test]
    fn test_row_reduction() {
        const Q: usize = 1000;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let m = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(m);
            let n = rng.random(2..=30);
            let mat = Matrix::<R>::new_with((n, n), |_, _| rng.random(..));
            let rank = mat.clone().rank();
            let inv = mat.inverse();
            assert_eq!(rank == n, inv.is_some());
            if let Some(inv) = inv {
                assert_eq!(&mat * &inv, Matrix::eye((n, n)));
            }
        }
    }

    #[test]
    fn test_determinant() {
        const Q: usize = 300;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=15);
            let mat = random_matrix(&mut rng, (n, n));
            let det = mat.clone().determinant();
            match mat.inverse() {
                // det(A) · det(A⁻¹) = det(I) = 1
                Some(mut inv) => {
                    let inv_det = inv.determinant();
                    assert_eq!(det * inv_det, DynMIntU32::one());
                }
                None => assert_eq!(det, DynMIntU32::zero()),
            }
        }
    }

    #[test]
    fn test_characteristic_polynomial() {
        const Q: usize = 200;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=12);
            let mat = random_matrix(&mut rng, (n, n));
            let poly = mat.clone().characteristic_polynomial();
            assert_eq!(poly.len(), n + 1);
            assert_eq!(poly[n], DynMIntU32::one());
            // Cayley–Hamilton: p(A) = Σ poly[k] · A^k = 0.
            let mut acc = Matrix::<R>::zeros((n, n));
            let mut power = Matrix::<R>::eye((n, n));
            for coef in &poly {
                let mut term = power.clone();
                term *= coef;
                acc += &term;
                power = &power * &mat;
            }
            assert_eq!(acc, Matrix::zeros((n, n)));
            // Constant term: p(0) = det(-A) = (-1)^n · det(A).
            let det = mat.clone().determinant();
            let mut check = poly[0];
            if n % 2 == 1 {
                check += det;
                assert_eq!(check, DynMIntU32::zero());
            } else {
                assert_eq!(check, det);
            }
        }
    }

    #[test]
    fn test_system_of_linear_equations() {
        const Q: usize = 1000;
        let mut rng = Xorshift::default();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=30);
            let m = rng.random(1..=30);
            let a = random_matrix(&mut rng, (n, m));
            let b = random_matrix(&mut rng, (1, n))
                .data
                .into_iter()
                .next()
                .unwrap();
            if let Some(sol) = a.solve_system_of_linear_equations(&b) {
                assert_eq!(
                    &a * Matrix::from_vec(vec![sol.particular.clone()]).transpose(),
                    Matrix::from_vec(vec![b.clone()]).transpose()
                );
                let c: Vec<DynMIntU32> = rand_value!(rng, [..; sol.basis.len()]);
                let mut x = sol.particular.clone();
                for (c, v) in c.iter().zip(sol.basis.iter()) {
                    for (x, v) in x.iter_mut().zip(v.iter()) {
                        *x += *c * *v;
                    }
                }
                assert_eq!(
                    &a * Matrix::from_vec(vec![x]).transpose(),
                    Matrix::from_vec(vec![b]).transpose()
                );
            }
        }
    }
}