1use super::{Field, Invertible, Ring, SemiRing, SerdeByteStr};
2use std::{
3 fmt::{self, Debug},
4 marker::PhantomData,
5 ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign},
6};
7
8pub struct Matrix<R>
9where
10 R: SemiRing,
11{
12 pub shape: (usize, usize),
13 pub data: Vec<Vec<R::T>>,
14 _marker: PhantomData<fn() -> R>,
15}
16
17impl<R> Debug for Matrix<R>
18where
19 R: SemiRing,
20 R::T: Debug,
21{
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 f.debug_struct("Matrix")
24 .field("shape", &self.shape)
25 .field("data", &self.data)
26 .field("_marker", &self._marker)
27 .finish()
28 }
29}
30
31impl<R> Clone for Matrix<R>
32where
33 R: SemiRing,
34{
35 fn clone(&self) -> Self {
36 Self {
37 shape: self.shape,
38 data: self.data.clone(),
39 _marker: self._marker,
40 }
41 }
42}
43
44impl<R> PartialEq for Matrix<R>
45where
46 R: SemiRing,
47 R::T: PartialEq,
48{
49 fn eq(&self, other: &Self) -> bool {
50 self.shape == other.shape && self.data == other.data
51 }
52}
53
54impl<R> Eq for Matrix<R>
55where
56 R: SemiRing,
57 R::T: Eq,
58{
59}
60
61impl<R> Matrix<R>
62where
63 R: SemiRing,
64{
65 pub fn new(shape: (usize, usize), z: R::T) -> Self {
66 Self {
67 shape,
68 data: vec![vec![z; shape.1]; shape.0],
69 _marker: PhantomData,
70 }
71 }
72
73 pub fn from_vec(data: Vec<Vec<R::T>>) -> Self {
74 let shape = (data.len(), data.first().map(Vec::len).unwrap_or_default());
75 assert!(data.iter().all(|r| r.len() == shape.1));
76 Self {
77 shape,
78 data,
79 _marker: PhantomData,
80 }
81 }
82
83 pub fn new_with(shape: (usize, usize), mut f: impl FnMut(usize, usize) -> R::T) -> Self {
84 let data = (0..shape.0)
85 .map(|i| (0..shape.1).map(|j| f(i, j)).collect())
86 .collect();
87 Self {
88 shape,
89 data,
90 _marker: PhantomData,
91 }
92 }
93
94 pub fn zeros(shape: (usize, usize)) -> Self {
95 Self {
96 shape,
97 data: vec![vec![R::zero(); shape.1]; shape.0],
98 _marker: PhantomData,
99 }
100 }
101
102 pub fn eye(shape: (usize, usize)) -> Self {
103 let mut data = vec![vec![R::zero(); shape.1]; shape.0];
104 for (i, d) in data.iter_mut().enumerate().take(shape.1) {
105 d[i] = R::one();
106 }
107 Self {
108 shape,
109 data,
110 _marker: PhantomData,
111 }
112 }
113
114 pub fn transpose(&self) -> Self {
115 Self::new_with((self.shape.1, self.shape.0), |i, j| self[j][i].clone())
116 }
117
118 pub fn map<S, F>(&self, mut f: F) -> Matrix<S>
119 where
120 S: SemiRing,
121 F: FnMut(&R::T) -> S::T,
122 {
123 Matrix::<S>::new_with(self.shape, |i, j| f(&self[i][j]))
124 }
125
126 pub fn add_row_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
127 self.data
128 .push((0..self.shape.1).map(|j| f(self.shape.0, j)).collect());
129 self.shape.0 += 1;
130 }
131
132 pub fn add_col_with(&mut self, mut f: impl FnMut(usize, usize) -> R::T) {
133 for i in 0..self.shape.0 {
134 self.data[i].push(f(i, self.shape.1));
135 }
136 self.shape.1 += 1;
137 }
138
139 pub fn pairwise_assign<F>(&mut self, other: &Self, mut f: F)
140 where
141 F: FnMut(&mut R::T, &R::T),
142 {
143 assert_eq!(self.shape, other.shape);
144 for i in 0..self.shape.0 {
145 for j in 0..self.shape.1 {
146 f(&mut self[i][j], &other[i][j]);
147 }
148 }
149 }
150}
151
152#[derive(Debug)]
153pub struct SystemOfLinearEquationsSolution<R>
154where
155 R: Field,
156 R::Additive: Invertible,
157 R::Multiplicative: Invertible,
158{
159 pub particular: Vec<R::T>,
160 pub basis: Vec<Vec<R::T>>,
161}
162
163impl<R> Matrix<R>
164where
165 R: Field,
166 R::Additive: Invertible,
167 R::Multiplicative: Invertible,
168 R::T: PartialEq,
169{
170 pub fn row_reduction_with<F>(&mut self, normalize: bool, mut f: F)
172 where
173 F: FnMut(usize, usize, usize),
174 {
175 let (n, m) = self.shape;
176 let mut c = 0;
177 for r in 0..n {
178 loop {
179 if c >= m {
180 return;
181 }
182 if let Some(pivot) = (r..n).find(|&p| !R::is_zero(&self[p][c])) {
183 f(r, pivot, c);
184 self.data.swap(r, pivot);
185 break;
186 };
187 c += 1;
188 }
189 let d = R::inv(&self[r][c]);
190 if normalize {
191 for j in c..m {
192 R::mul_assign(&mut self[r][j], &d);
193 }
194 }
195 for i in (0..n).filter(|&i| i != r) {
196 let mut e = self[i][c].clone();
197 if !normalize {
198 R::mul_assign(&mut e, &d);
199 }
200 for j in c..m {
201 let e = R::mul(&e, &self[r][j]);
202 R::sub_assign(&mut self[i][j], &e);
203 }
204 }
205 c += 1;
206 }
207 }
208
209 pub fn row_reduction(&mut self, normalize: bool) {
210 self.row_reduction_with(normalize, |_, _, _| {});
211 }
212
213 pub fn rank(&mut self) -> usize {
214 let n = self.shape.0;
215 self.row_reduction(false);
216 (0..n)
217 .filter(|&i| !self.data[i].iter().all(|x| R::is_zero(x)))
218 .count()
219 }
220
221 pub fn determinant(&mut self) -> R::T {
222 assert_eq!(self.shape.0, self.shape.1);
223 let mut neg = false;
224 self.row_reduction_with(false, |r, p, _| neg ^= r != p);
225 let mut d = R::one();
226 if neg {
227 d = R::neg(&d);
228 }
229 for i in 0..self.shape.0 {
230 R::mul_assign(&mut d, &self[i][i]);
231 }
232 d
233 }
234
235 pub fn solve_system_of_linear_equations(
236 &self,
237 b: &[R::T],
238 ) -> Option<SystemOfLinearEquationsSolution<R>> {
239 assert_eq!(self.shape.0, b.len());
240 let (n, m) = self.shape;
241 let mut c = Matrix::<R>::zeros((n, m + 1));
242 for i in 0..n {
243 c[i][..m].clone_from_slice(&self[i]);
244 c[i][m] = b[i].clone();
245 }
246 let mut reduced = vec![!0; m + 1];
247 c.row_reduction_with(true, |r, _, c| reduced[c] = r);
248 if reduced[m] != !0 {
249 return None;
250 }
251 let mut particular = vec![R::zero(); m];
252 let mut basis = vec![];
253 for j in 0..m {
254 if reduced[j] != !0 {
255 particular[j] = c[reduced[j]][m].clone();
256 } else {
257 let mut v = vec![R::zero(); m];
258 v[j] = R::one();
259 for i in 0..m {
260 if reduced[i] != !0 {
261 R::sub_assign(&mut v[i], &c[reduced[i]][j]);
262 }
263 }
264 basis.push(v);
265 }
266 }
267 Some(SystemOfLinearEquationsSolution { particular, basis })
268 }
269
270 pub fn inverse(&self) -> Option<Matrix<R>> {
271 assert_eq!(self.shape.0, self.shape.1);
272 let n = self.shape.0;
273 let mut c = Matrix::<R>::zeros((n, n * 2));
274 for i in 0..n {
275 c[i][..n].clone_from_slice(&self[i]);
276 c[i][n + i] = R::one();
277 }
278 c.row_reduction(true);
279 if (0..n).any(|i| R::is_zero(&c[i][i])) {
280 None
281 } else {
282 Some(Self::from_vec(
283 c.data.into_iter().map(|r| r[n..].to_vec()).collect(),
284 ))
285 }
286 }
287
288 pub fn characteristic_polynomial(&mut self) -> Vec<R::T> {
289 let n = self.shape.0;
290 if n == 0 {
291 return vec![R::one()];
292 }
293 assert!(self.data.iter().all(|a| a.len() == n));
294 for j in 0..(n - 1) {
295 if let Some(x) = ((j + 1)..n).find(|&x| !R::is_zero(&self[x][j])) {
296 self.data.swap(j + 1, x);
297 self.data.iter_mut().for_each(|a| a.swap(j + 1, x));
298 let inv = R::inv(&self[j + 1][j]);
299 let mut v = vec![];
300 let src = std::mem::take(&mut self[j + 1]);
301 for a in self.data[(j + 2)..].iter_mut() {
302 let mul = R::mul(&a[j], &inv);
303 for (a, src) in a[j..].iter_mut().zip(src[j..].iter()) {
304 R::sub_assign(a, &R::mul(&mul, src));
305 }
306 v.push(mul);
307 }
308 self[j + 1] = src;
309 for a in self.data.iter_mut() {
310 let v = a[(j + 2)..]
311 .iter()
312 .zip(v.iter())
313 .fold(R::zero(), |s, a| R::add(&s, &R::mul(a.0, a.1)));
314 R::add_assign(&mut a[j + 1], &v);
315 }
316 }
317 }
318 let mut dp = vec![vec![R::one()]];
319 for i in 0..n {
320 let mut next = vec![R::zero(); i + 2];
321 for (j, dp) in dp[i].iter().enumerate() {
322 R::sub_assign(&mut next[j], &R::mul(dp, &self[i][i]));
323 R::add_assign(&mut next[j + 1], dp);
324 }
325 let mut mul = R::one();
326 for j in (0..i).rev() {
327 mul = R::mul(&mul, &self[j + 1][j]);
328 let c = R::mul(&mul, &self[j][i]);
329 for (next, dp) in next.iter_mut().zip(dp[j].iter()) {
330 R::sub_assign(next, &R::mul(&c, dp));
331 }
332 }
333 dp.push(next);
334 }
335 dp.pop().unwrap()
336 }
337}
338
339impl<R> Index<usize> for Matrix<R>
340where
341 R: SemiRing,
342{
343 type Output = Vec<R::T>;
344 fn index(&self, index: usize) -> &Self::Output {
345 &self.data[index]
346 }
347}
348
349impl<R> IndexMut<usize> for Matrix<R>
350where
351 R: SemiRing,
352{
353 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
354 &mut self.data[index]
355 }
356}
357
358impl<R> Index<(usize, usize)> for Matrix<R>
359where
360 R: SemiRing,
361{
362 type Output = R::T;
363 fn index(&self, index: (usize, usize)) -> &Self::Output {
364 &self.data[index.0][index.1]
365 }
366}
367
368impl<R> IndexMut<(usize, usize)> for Matrix<R>
369where
370 R: SemiRing,
371{
372 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
373 &mut self.data[index.0][index.1]
374 }
375}
376
// Generates element-wise operator impls for `Matrix<R>`: the assigning trait
// `$OpAssign` (owned and borrowed right-hand side) and the binary trait `$Op`
// for every owned/borrowed combination of operands. All of them go through
// `Matrix::pairwise_assign`, which asserts equal shapes, using the scalar
// operation `R::$op` / `R::$op_assign`. The optional `where [...]` list adds
// extra bounds to every generated impl.
macro_rules! impl_matrix_pairwise_binop {
    ($Op:ident, $op:ident, $OpAssign:ident, $op_assign:ident $(where [$($extra:tt)*])?) => {
        // `lhs op= &rhs`: the primitive in-place form.
        impl<R> $OpAssign<&Matrix<R>> for Matrix<R>
        where
            R: SemiRing,
            $($($extra)*)?
        {
            fn $op_assign(&mut self, rhs: &Self) {
                self.pairwise_assign(rhs, |x, y| R::$op_assign(x, y));
            }
        }
        // `lhs op= rhs`: borrow the owned right operand.
        impl<R> $OpAssign for Matrix<R>
        where
            R: SemiRing,
            $($($extra)*)?
        {
            fn $op_assign(&mut self, rhs: Self) {
                self.$op_assign(&rhs);
            }
        }
        // `lhs op &rhs`: reuse the left operand's storage.
        impl<R> $Op<&Matrix<R>> for Matrix<R>
        where
            R: SemiRing,
            $($($extra)*)?
        {
            type Output = Matrix<R>;
            fn $op(self, rhs: &Self) -> Self::Output {
                let mut lhs = self;
                lhs.$op_assign(rhs);
                lhs
            }
        }
        // `lhs op rhs`: both owned; the left operand's storage is reused.
        impl<R> $Op for Matrix<R>
        where
            R: SemiRing,
            $($($extra)*)?
        {
            type Output = Matrix<R>;
            fn $op(self, rhs: Self) -> Self::Output {
                self.$op(&rhs)
            }
        }
        // `&lhs op rhs`: reuse the right operand's storage but keep the
        // operand order `lhs op rhs` for each entry.
        impl<R> $Op<Matrix<R>> for &Matrix<R>
        where
            R: SemiRing,
            $($($extra)*)?
        {
            type Output = Matrix<R>;
            fn $op(self, rhs: Matrix<R>) -> Self::Output {
                let mut out = rhs;
                out.pairwise_assign(self, |r, l| *r = R::$op(l, r));
                out
            }
        }
        // `&lhs op &rhs`: both borrowed, so clone the left operand first.
        impl<R> $Op<&Matrix<R>> for &Matrix<R>
        where
            R: SemiRing,
            $($($extra)*)?
        {
            type Output = Matrix<R>;
            fn $op(self, rhs: &Matrix<R>) -> Self::Output {
                let mut out = Matrix::clone(self);
                out.$op_assign(rhs);
                out
            }
        }
    };
}
444
445impl_matrix_pairwise_binop!(Add, add, AddAssign, add_assign);
446impl_matrix_pairwise_binop!(Sub, sub, SubAssign, sub_assign where [R::Additive: Invertible]);
447
448impl<R> Mul for Matrix<R>
449where
450 R: SemiRing,
451{
452 type Output = Matrix<R>;
453 fn mul(self, rhs: Self) -> Self::Output {
454 (&self).mul(&rhs)
455 }
456}
457impl<R> Mul<&Matrix<R>> for Matrix<R>
458where
459 R: SemiRing,
460{
461 type Output = Matrix<R>;
462 fn mul(self, rhs: &Matrix<R>) -> Self::Output {
463 (&self).mul(rhs)
464 }
465}
466impl<R> Mul<Matrix<R>> for &Matrix<R>
467where
468 R: SemiRing,
469{
470 type Output = Matrix<R>;
471 fn mul(self, rhs: Matrix<R>) -> Self::Output {
472 self.mul(&rhs)
473 }
474}
475impl<R> Mul<&Matrix<R>> for &Matrix<R>
476where
477 R: SemiRing,
478{
479 type Output = Matrix<R>;
480 fn mul(self, rhs: &Matrix<R>) -> Self::Output {
481 assert_eq!(self.shape.1, rhs.shape.0);
482 let mut res = Matrix::zeros((self.shape.0, rhs.shape.1));
483 for i in 0..self.shape.0 {
484 for k in 0..self.shape.1 {
485 for j in 0..rhs.shape.1 {
486 R::add_assign(&mut res[i][j], &R::mul(&self[i][k], &rhs[k][j]));
487 }
488 }
489 }
490 res
491 }
492}
493
494impl<R> MulAssign<&R::T> for Matrix<R>
495where
496 R: SemiRing,
497{
498 fn mul_assign(&mut self, rhs: &R::T) {
499 for i in 0..self.shape.0 {
500 for j in 0..self.shape.1 {
501 R::mul_assign(&mut self[(i, j)], rhs);
502 }
503 }
504 }
505}
506
507impl<R> Matrix<R>
508where
509 R: SemiRing,
510{
511 pub fn pow(self, mut n: usize) -> Self {
512 assert_eq!(self.shape.0, self.shape.1);
513 let mut res = Matrix::eye(self.shape);
514 let mut x = self;
515 while n > 0 {
516 if n & 1 == 1 {
517 res = &res * &x;
518 }
519 x = &x * &x;
520 n >>= 1;
521 }
522 res
523 }
524}
525
526impl<R> SerdeByteStr for Matrix<R>
527where
528 R: SemiRing,
529 R::T: SerdeByteStr,
530{
531 fn serialize(&self, buf: &mut Vec<u8>) {
532 self.data.serialize(buf);
533 }
534
535 fn deserialize<I>(iter: &mut I) -> Self
536 where
537 I: Iterator<Item = u8>,
538 {
539 Self::from_vec(Vec::deserialize(iter))
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One, Zero, mint_basic::DynMIntU32},
        rand, rand_value,
        tools::{RandomSpec, Xorshift},
    };

    // Random spec: a `DynMIntU32` drawn from `..DynMIntU32::get_mod()`, i.e.
    // below the currently configured modulus.
    struct D;
    impl RandomSpec<DynMIntU32> for D {
        fn rand(&self, rng: &mut Xorshift) -> DynMIntU32 {
            DynMIntU32::new_unchecked(rng.random(..DynMIntU32::get_mod()))
        }
    }

    // Random `shape` matrix from one of three families, chosen by `gen_bool(0.5)`:
    // 1. every entry drawn from `D`;
    // 2. sparse: an entry is drawn from `D` when `rng.gen_bool(r)` succeeds,
    //    otherwise it is zero, with `r` itself random per matrix (exercises
    //    the pivot search);
    // 3. dense, then row `i0` is overwritten by row `i1` times `x`; when
    //    `i0 != i1` this makes the rows linearly dependent.
    // NOTE(review): case 3 draws from `0..shape.0`; callers pass `shape.0 >= 1`.
    fn random_matrix(
        rng: &mut Xorshift,
        shape: (usize, usize),
    ) -> Matrix<AddMulOperation<DynMIntU32>> {
        if rng.gen_bool(0.5) {
            Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| rng.random(D))
        } else if rng.gen_bool(0.5) {
            let r = rng.randf();
            Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| {
                if rng.gen_bool(r) {
                    rng.random(D)
                } else {
                    DynMIntU32::zero()
                }
            })
        } else {
            let mut mat = Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| rng.random(D));
            let i0 = rng.random(0..shape.0);
            let i1 = rng.random(0..shape.0);
            let x = rng.random(D);
            for j in 0..shape.1 {
                mat[(i0, j)] = mat[(i1, j)] * x;
            }
            mat
        }
    }

    // `eye` puts ones exactly where `i == j`, including for non-square and empty shapes.
    #[test]
    fn test_eye() {
        for n in 0..10 {
            for m in 0..10 {
                let result = Matrix::<AddMulOperation<DynMIntU32>>::eye((n, m));
                let expected = Matrix::new_with((n, m), |i, j| {
                    if i == j {
                        DynMIntU32::one()
                    } else {
                        DynMIntU32::zero()
                    }
                });
                assert_eq!(result, expected);
            }
        }
    }

    // All owned/borrowed operand combinations of `+` agree.
    #[test]
    fn test_add() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            let b = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            assert_eq!(&a + &b, a.clone() + b.clone());
            assert_eq!(a.clone() + &b, a.clone() + b.clone());
            assert_eq!(&a + b.clone(), a.clone() + b.clone());
        }
    }

    // All owned/borrowed operand combinations of `-` agree. In particular,
    // `&a - b` must keep the operand order even though it reuses `b`'s storage.
    #[test]
    fn test_sub() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            let b = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            assert_eq!(&a - &b, a.clone() - b.clone());
            assert_eq!(a.clone() - &b, a.clone() - b.clone());
            assert_eq!(&a - b.clone(), a.clone() - b.clone());
        }
    }

    // The product agrees across operand forms and with the naive sum
    // definition, and `*= &c` scales every entry by `c`.
    #[test]
    fn test_mul() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30, l: 1..30);
            let a = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; m]; n]));
            let b = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; l]; m]));
            assert_eq!(&a * &b, a.clone() * b.clone());
            assert_eq!(a.clone() * &b, a.clone() * b.clone());
            assert_eq!(&a * b.clone(), a.clone() * b.clone());
            assert_eq!(
                &a * &b,
                Matrix::new_with((n, l), |i, j| (0..m).map(|k| a[i][k] * b[k][j]).sum())
            );
            let c = rand_value!(rng, D);
            let mut ac = a.clone();
            ac *= &c;
            assert_eq!(ac, Matrix::new_with(a.shape, |i, j| a[i][j] * c));
        }
    }

    // Full rank <=> invertible, and `mat * inv == I`. Here `m` is the modulus.
    // The small primes 2 and 3 are presumably included to make singular random
    // matrices common.
    // NOTE(review): if `set_mod` is process-global rather than thread-local, it
    // can race with other tests running concurrently. Confirm.
    #[test]
    fn test_row_reduction() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let m = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(m);
            let n = rng.random(2..=30);
            let mat = Matrix::<AddMulOperation<_>>::from_vec(rand_value!(rng, [[D; n]; n]));
            let rank = mat.clone().rank();
            let inv = mat.inverse();
            assert_eq!(rank == n, inv.is_some());
            if let Some(inv) = inv {
                assert_eq!(&mat * &inv, Matrix::<AddMulOperation<_>>::eye((n, n)));
            }
        }
    }

    // For solvable systems: the particular solution satisfies `a x = b`, and so
    // does the particular solution plus a random combination of basis vectors.
    // Inconsistent systems (`None`) are not cross-checked here.
    #[test]
    fn test_system_of_linear_equations() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();
        let ps = [2, 3, 1_000_000_007];
        for _ in 0..Q {
            let p = ps[rng.random(..ps.len())];
            DynMIntU32::set_mod(p);
            let n = rng.random(1..=30);
            let m = rng.random(1..=30);
            let a = random_matrix(&mut rng, (n, m));
            // A random right-hand side of length `n`, taken as the only row of a `1 x n` matrix.
            let b = random_matrix(&mut rng, (1, n))
                .data
                .into_iter()
                .next()
                .unwrap();
            if let Some(sol) = a.solve_system_of_linear_equations(&b) {
                assert_eq!(
                    &a * Matrix::from_vec(vec![sol.particular.clone()]).transpose(),
                    Matrix::from_vec(vec![b.clone()]).transpose()
                );
                let c = rand_value!(rng, [D; sol.basis.len()]);
                let mut x = sol.particular.clone();
                for (c, v) in c.iter().zip(sol.basis.iter()) {
                    for (x, v) in x.iter_mut().zip(v.iter()) {
                        *x += *c * *v;
                    }
                }
                assert_eq!(
                    &a * Matrix::from_vec(vec![x]).transpose(),
                    Matrix::from_vec(vec![b]).transpose()
                );
            }
        }
    }
}