// competitive/math/black_box_matrix.rs

1use super::{
2    AddMulOperation, ConvolveSteps, FormalPowerSeries, MInt, MIntBase, MIntConvert, Matrix, One,
3    SemiRing, Xorshift, Zero, berlekamp_massey,
4};
5use std::{
6    fmt::{self, Debug},
7    marker::PhantomData,
8};
9
10pub trait BlackBoxMatrix<R>
11where
12    R: SemiRing,
13{
14    fn apply(&self, v: &[R::T]) -> Vec<R::T>;
15
16    fn shape(&self) -> (usize, usize);
17}
18
19impl<R> BlackBoxMatrix<R> for Matrix<R>
20where
21    R: SemiRing,
22{
23    fn apply(&self, v: &[R::T]) -> Vec<R::T> {
24        assert_eq!(self.shape.1, v.len());
25        let mut res = vec![R::zero(); self.shape.0];
26        for i in 0..self.shape.0 {
27            for j in 0..self.shape.1 {
28                R::add_assign(&mut res[i], &R::mul(&self[(i, j)], &v[j]));
29            }
30        }
31        res
32    }
33
34    fn shape(&self) -> (usize, usize) {
35        self.shape
36    }
37}
38
39pub struct SparseMatrix<R>
40where
41    R: SemiRing,
42{
43    shape: (usize, usize),
44    nonzero: Vec<(usize, usize, R::T)>,
45}
46
47impl<R> Debug for SparseMatrix<R>
48where
49    R: SemiRing,
50    R::T: Debug,
51{
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("SparseMatrix")
54            .field("shape", &self.shape)
55            .field("nonzero", &self.nonzero)
56            .finish()
57    }
58}
59
60impl<R> Clone for SparseMatrix<R>
61where
62    R: SemiRing,
63{
64    fn clone(&self) -> Self {
65        Self {
66            shape: self.shape,
67            nonzero: self.nonzero.clone(),
68        }
69    }
70}
71
72impl<R> SparseMatrix<R>
73where
74    R: SemiRing,
75{
76    pub fn new(shape: (usize, usize)) -> Self {
77        Self {
78            shape,
79            nonzero: vec![],
80        }
81    }
82    pub fn new_with<F>(shape: (usize, usize), f: F) -> Self
83    where
84        R::T: PartialEq,
85        F: Fn(usize, usize) -> R::T,
86    {
87        let mut nonzero = vec![];
88        for i in 0..shape.0 {
89            for j in 0..shape.1 {
90                let v = f(i, j);
91                if !R::is_zero(&v) {
92                    nonzero.push((i, j, v));
93                }
94            }
95        }
96        Self { shape, nonzero }
97    }
98    pub fn from_nonzero(shape: (usize, usize), nonzero: Vec<(usize, usize, R::T)>) -> Self {
99        Self { shape, nonzero }
100    }
101}
102
103impl<R> From<Matrix<R>> for SparseMatrix<R>
104where
105    R: SemiRing,
106    R::T: PartialEq,
107{
108    fn from(mat: Matrix<R>) -> Self {
109        let mut nonzero = vec![];
110        for i in 0..mat.shape.0 {
111            for j in 0..mat.shape.1 {
112                let v = mat[(i, j)].clone();
113                if !R::is_zero(&v) {
114                    nonzero.push((i, j, v));
115                }
116            }
117        }
118        Self {
119            shape: mat.shape,
120            nonzero,
121        }
122    }
123}
124
125impl<R> From<SparseMatrix<R>> for Matrix<R>
126where
127    R: SemiRing,
128{
129    fn from(smat: SparseMatrix<R>) -> Self {
130        let mut mat = Matrix::zeros(smat.shape);
131        for &(i, j, ref v) in &smat.nonzero {
132            R::add_assign(&mut mat[(i, j)], v);
133        }
134        mat
135    }
136}
137
138impl<R> BlackBoxMatrix<R> for SparseMatrix<R>
139where
140    R: SemiRing,
141{
142    fn apply(&self, v: &[R::T]) -> Vec<R::T> {
143        assert_eq!(self.shape.1, v.len());
144        let mut res = vec![R::zero(); self.shape.0];
145        for &(i, j, ref val) in &self.nonzero {
146            R::add_assign(&mut res[i], &R::mul(val, &v[j]));
147        }
148        res
149    }
150
151    fn shape(&self) -> (usize, usize) {
152        self.shape
153    }
154}
155
/// Adapter that turns a closure computing `A * v` into a black-box matrix of a fixed shape.
///
/// `R` only selects the semiring; it is tracked through `PhantomData<fn() -> R>`, so the
/// struct never stores an `R`.
pub struct BlackBoxMatrixImpl<R, F> {
    /// `(rows, cols)`.
    shape: (usize, usize),
    /// Closure performing the matrix–vector product.
    apply_fn: F,
    _marker: PhantomData<fn() -> R>,
}
161
162impl<R, F> Debug for BlackBoxMatrixImpl<R, F>
163where
164    F: Debug,
165{
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        f.debug_struct("BlackBoxMatrixImpl")
168            .field("shape", &self.shape)
169            .field("apply_fn", &self.apply_fn)
170            .finish()
171    }
172}
173
174impl<R, F> Clone for BlackBoxMatrixImpl<R, F>
175where
176    F: Clone,
177{
178    fn clone(&self) -> Self {
179        Self {
180            shape: self.shape,
181            apply_fn: self.apply_fn.clone(),
182            _marker: PhantomData,
183        }
184    }
185}
186
187impl<R, F> BlackBoxMatrixImpl<R, F> {
188    pub fn new(shape: (usize, usize), apply_fn: F) -> Self {
189        Self {
190            shape,
191            apply_fn,
192            _marker: PhantomData,
193        }
194    }
195}
196
197impl<R, F> BlackBoxMatrix<R> for BlackBoxMatrixImpl<R, F>
198where
199    R: SemiRing,
200    F: Fn(&[R::T]) -> Vec<R::T>,
201{
202    fn apply(&self, v: &[R::T]) -> Vec<R::T> {
203        assert_eq!(self.shape.1, v.len());
204        (self.apply_fn)(v)
205    }
206
207    fn shape(&self) -> (usize, usize) {
208        self.shape
209    }
210}
211
212pub trait BlackBoxMIntMatrix<M>: BlackBoxMatrix<AddMulOperation<MInt<M>>>
213where
214    M: MIntBase,
215{
216    fn minimal_polynomial(&self) -> Vec<MInt<M>>
217    where
218        M: MIntConvert<u64>,
219    {
220        assert_eq!(self.shape().0, self.shape().1);
221        let n = self.shape().0;
222        let mut rng = Xorshift::new();
223        let b: Vec<MInt<M>> = (0..n).map(|_| MInt::from(rng.rand64())).collect();
224        let u: Vec<MInt<M>> = (0..n).map(|_| MInt::from(rng.rand64())).collect();
225        let a: Vec<MInt<M>> = (0..2 * n)
226            .scan(b, |b, _| {
227                let a = b.iter().zip(&u).fold(MInt::zero(), |s, (x, y)| s + x * y);
228                *b = self.apply(b);
229                Some(a)
230            })
231            .collect();
232        let mut p = berlekamp_massey(&a);
233        p.reverse();
234        p
235    }
236
237    fn apply_pow<C>(&self, mut b: Vec<MInt<M>>, k: usize) -> Vec<MInt<M>>
238    where
239        M: MIntConvert<usize> + MIntConvert<u64>,
240        C: ConvolveSteps<T = Vec<MInt<M>>>,
241    {
242        assert_eq!(self.shape().0, self.shape().1);
243        assert_eq!(self.shape().1, b.len());
244        let n = self.shape().0;
245        let p = self.minimal_polynomial();
246        let f = FormalPowerSeries::<MInt<M>, C>::from_vec(p).pow_mod(k);
247        let mut res = vec![MInt::zero(); n];
248        for f in f {
249            for j in 0..n {
250                res[j] += f * b[j];
251            }
252            b = self.apply(&b);
253        }
254        res
255    }
256
257    fn black_box_determinant(&self) -> MInt<M>
258    where
259        M: MIntConvert<u64>,
260    {
261        assert_eq!(self.shape().0, self.shape().1);
262        let n = self.shape().0;
263        let mut rng = Xorshift::new();
264        let d: Vec<MInt<M>> = (0..n).map(|_| MInt::from(rng.rand64())).collect();
265        let det_d = d.iter().fold(MInt::one(), |s, x| s * x);
266        let ad = BlackBoxMatrixImpl::<AddMulOperation<MInt<M>>, _>::new(
267            self.shape(),
268            |v: &[MInt<M>]| {
269                let mut w = self.apply(v);
270                for (w, d) in w.iter_mut().zip(&d) {
271                    *w *= d;
272                }
273                w
274            },
275        );
276        let p = ad.minimal_polynomial();
277        let det_ad = if n % 2 == 0 { p[0] } else { -p[0] };
278        det_ad / det_d
279    }
280
281    fn black_box_linear_equation(&self, mut b: Vec<MInt<M>>) -> Option<Vec<MInt<M>>>
282    where
283        M: MIntConvert<u64>,
284    {
285        assert_eq!(self.shape().0, self.shape().1);
286        assert_eq!(self.shape().1, b.len());
287        let n = self.shape().0;
288        let p = self.minimal_polynomial();
289        if p.is_empty() || p[0].is_zero() {
290            return None;
291        }
292        let p0_inv = p[0].inv();
293        let mut x = vec![MInt::zero(); n];
294        for p in p.into_iter().skip(1) {
295            let p = -p * p0_inv;
296            for i in 0..n {
297                x[i] += p * b[i];
298            }
299            b = self.apply(&b);
300        }
301        Some(x)
302    }
303}
304
305impl<M, B> BlackBoxMIntMatrix<M> for B
306where
307    M: MIntBase,
308    B: BlackBoxMatrix<AddMulOperation<MInt<M>>>,
309{
310}
311
// Randomized tests. Each one compares a black-box result with a dense `Matrix` reference
// on 100 random cases generated from `Xorshift::default()`. That generator presumably uses
// a fixed seed, which would make the cases reproducible (confirm).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{math::Convolve998244353, num::montgomery::MInt998244353, rand, tools::RandomSpec};

    // Random-element distribution for MInt998244353 (presumably uniform over `0..mod`).
    struct D;
    impl RandomSpec<MInt998244353> for D {
        fn rand(&self, rng: &mut Xorshift) -> MInt998244353 {
            MInt998244353::new_unchecked(rng.random(..MInt998244353::get_mod()))
        }
    }

    // Random matrix of the given shape, drawn from one of three families:
    // - ~1/2: every entry random;
    // - ~1/4: each entry is non-zero with a random probability `r` (random sparsity);
    // - ~1/4: every entry random, then row `i0` is overwritten by `x * row i1`. This makes
    //   the rows linearly dependent whenever `i0 != i1`, which exercises singular cases.
    fn random_matrix(
        rng: &mut Xorshift,
        shape: (usize, usize),
    ) -> Matrix<AddMulOperation<MInt998244353>> {
        if rng.gen_bool(0.5) {
            Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| rng.random(D))
        } else if rng.gen_bool(0.5) {
            let r = rng.randf();
            Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| {
                if rng.gen_bool(r) {
                    rng.random(D)
                } else {
                    MInt998244353::zero()
                }
            })
        } else {
            let mut mat = Matrix::<AddMulOperation<_>>::new_with(shape, |_, _| rng.random(D));
            let i0 = rng.random(0..shape.0);
            let i1 = rng.random(0..shape.0);
            let x = rng.random(D);
            for j in 0..shape.1 {
                mat[(i0, j)] = mat[(i1, j)] * x;
            }
            mat
        }
    }

    // Dense and sparse `apply` must agree on rectangular matrices.
    #[test]
    fn test_apply() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let mat = random_matrix(&mut rng, (n, m));
            let smat = SparseMatrix::from(mat.clone());
            let v: Vec<_> = (0..m).map(|_| rng.random(D)).collect();
            let av = mat.apply(&v);
            let asv = smat.apply(&v);
            assert_eq!(av, asv);
        }
    }

    // The returned polynomial has degree <= n and annihilates A: sum_k p[k] * A^k == 0.
    #[test]
    fn test_minimal_polynomial() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30);
            let a = random_matrix(&mut rng, (n, n));
            let p = a.minimal_polynomial();
            assert!(p.len() <= n + 1);
            let mut res = Matrix::<AddMulOperation<MInt998244353>>::zeros((n, n));
            let mut pow = Matrix::<AddMulOperation<MInt998244353>>::eye((n, n));
            for p in p {
                for i in 0..n {
                    for j in 0..n {
                        res[(i, j)] += p * pow[(i, j)];
                    }
                }
                pow = &pow * &a;
            }
            assert_eq!(res, Matrix::<AddMulOperation<MInt998244353>>::zeros((n, n)));
        }
    }

    // A^k b via the black box must match dense exponentiation, for k up to ~1e9.
    #[test]
    fn test_apply_pow() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, k: 0..1_000_000_000);
            let a = random_matrix(&mut rng, (n, n));
            let b: Vec<_> = (0..n).map(|_| rng.random(D)).collect();
            let expected = a.clone().pow(k).apply(&b);
            let result = a.apply_pow::<Convolve998244353>(b, k);
            assert_eq!(result, expected);
        }
    }

    // The black-box determinant must match the dense determinant (which needs `&mut`).
    #[test]
    fn test_black_box_determinant() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30);
            let mut a = random_matrix(&mut rng, (n, n));
            let result = a.black_box_determinant();
            let expected = a.determinant();
            assert_eq!(result, expected);
        }
    }

    // Compared with the dense solver's particular solution. For singular matrices the black
    // box returns `None`, which agrees with the dense solver only when the random `b` makes
    // the system unsolvable. That is presumably the overwhelmingly likely case here.
    #[test]
    fn test_black_box_linear_equation() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30);
            let a = random_matrix(&mut rng, (n, n));
            let b: Vec<_> = (0..n).map(|_| rng.random(D)).collect();
            let expected = a
                .solve_system_of_linear_equations(&b)
                .map(|sol| sol.particular);
            let result = a.black_box_linear_equation(b);
            assert_eq!(result, expected);
        }
    }
}