competitive/math/
black_box_matrix.rs

1use super::{
2    AddMulOperation, ConvolveSteps, FormalPowerSeries, MInt, MIntBase, MIntConvert, Matrix, One,
3    SemiRing, Xorshift, Zero, berlekamp_massey,
4};
5use std::{
6    fmt::{self, Debug},
7    marker::PhantomData,
8};
9
10pub trait BlackBoxMatrix<R>
11where
12    R: SemiRing,
13{
14    fn apply(&self, v: &[R::T]) -> Vec<R::T>;
15
16    fn shape(&self) -> (usize, usize);
17}
18
19impl<R> BlackBoxMatrix<R> for Matrix<R>
20where
21    R: SemiRing,
22{
23    fn apply(&self, v: &[R::T]) -> Vec<R::T> {
24        assert_eq!(self.shape.1, v.len());
25        let mut res = vec![R::zero(); self.shape.0];
26        for i in 0..self.shape.0 {
27            for j in 0..self.shape.1 {
28                R::add_assign(&mut res[i], &R::mul(&self[(i, j)], &v[j]));
29            }
30        }
31        res
32    }
33
34    fn shape(&self) -> (usize, usize) {
35        self.shape
36    }
37}
38
39pub struct SparseMatrix<R>
40where
41    R: SemiRing,
42{
43    shape: (usize, usize),
44    nonzero: Vec<(usize, usize, R::T)>,
45}
46
47impl<R> Debug for SparseMatrix<R>
48where
49    R: SemiRing<T: Debug>,
50{
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("SparseMatrix")
53            .field("shape", &self.shape)
54            .field("nonzero", &self.nonzero)
55            .finish()
56    }
57}
58
59impl<R> Clone for SparseMatrix<R>
60where
61    R: SemiRing,
62{
63    fn clone(&self) -> Self {
64        Self {
65            shape: self.shape,
66            nonzero: self.nonzero.clone(),
67        }
68    }
69}
70
71impl<R> SparseMatrix<R>
72where
73    R: SemiRing,
74{
75    pub fn new(shape: (usize, usize)) -> Self {
76        Self {
77            shape,
78            nonzero: vec![],
79        }
80    }
81    pub fn new_with<F>(shape: (usize, usize), f: F) -> Self
82    where
83        R: SemiRing<T: PartialEq>,
84        F: Fn(usize, usize) -> R::T,
85    {
86        let mut nonzero = vec![];
87        for i in 0..shape.0 {
88            for j in 0..shape.1 {
89                let v = f(i, j);
90                if !R::is_zero(&v) {
91                    nonzero.push((i, j, v));
92                }
93            }
94        }
95        Self { shape, nonzero }
96    }
97    pub fn from_nonzero(shape: (usize, usize), nonzero: Vec<(usize, usize, R::T)>) -> Self {
98        Self { shape, nonzero }
99    }
100}
101
102impl<R> From<Matrix<R>> for SparseMatrix<R>
103where
104    R: SemiRing<T: PartialEq>,
105{
106    fn from(mat: Matrix<R>) -> Self {
107        let mut nonzero = vec![];
108        for i in 0..mat.shape.0 {
109            for j in 0..mat.shape.1 {
110                let v = mat[(i, j)].clone();
111                if !R::is_zero(&v) {
112                    nonzero.push((i, j, v));
113                }
114            }
115        }
116        Self {
117            shape: mat.shape,
118            nonzero,
119        }
120    }
121}
122
123impl<R> From<SparseMatrix<R>> for Matrix<R>
124where
125    R: SemiRing,
126{
127    fn from(smat: SparseMatrix<R>) -> Self {
128        let mut mat = Matrix::zeros(smat.shape);
129        for &(i, j, ref v) in &smat.nonzero {
130            R::add_assign(&mut mat[(i, j)], v);
131        }
132        mat
133    }
134}
135
136impl<R> BlackBoxMatrix<R> for SparseMatrix<R>
137where
138    R: SemiRing,
139{
140    fn apply(&self, v: &[R::T]) -> Vec<R::T> {
141        assert_eq!(self.shape.1, v.len());
142        let mut res = vec![R::zero(); self.shape.0];
143        for &(i, j, ref val) in &self.nonzero {
144            R::add_assign(&mut res[i], &R::mul(val, &v[j]));
145        }
146        res
147    }
148
149    fn shape(&self) -> (usize, usize) {
150        self.shape
151    }
152}
153
/// A black-box matrix defined by a user-supplied callable.
///
/// `R` names the semiring the callable works over. No value of type `R` is
/// stored; the `PhantomData` field only records the parameter.
pub struct BlackBoxMatrixImpl<R, F> {
    /// Dimensions reported by `shape()`.
    shape: (usize, usize),
    /// Callable invoked by `apply`.
    apply_fn: F,
    /// Zero-sized marker that makes `R` a used type parameter.
    _marker: PhantomData<fn() -> R>,
}
159
160impl<R, F> Debug for BlackBoxMatrixImpl<R, F>
161where
162    F: Debug,
163{
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.debug_struct("BlackBoxMatrixImpl")
166            .field("shape", &self.shape)
167            .field("apply_fn", &self.apply_fn)
168            .finish()
169    }
170}
171
172impl<R, F> Clone for BlackBoxMatrixImpl<R, F>
173where
174    F: Clone,
175{
176    fn clone(&self) -> Self {
177        Self {
178            shape: self.shape,
179            apply_fn: self.apply_fn.clone(),
180            _marker: PhantomData,
181        }
182    }
183}
184
185impl<R, F> BlackBoxMatrixImpl<R, F> {
186    pub fn new(shape: (usize, usize), apply_fn: F) -> Self {
187        Self {
188            shape,
189            apply_fn,
190            _marker: PhantomData,
191        }
192    }
193}
194
195impl<R, F> BlackBoxMatrix<R> for BlackBoxMatrixImpl<R, F>
196where
197    R: SemiRing,
198    F: Fn(&[R::T]) -> Vec<R::T>,
199{
200    fn apply(&self, v: &[R::T]) -> Vec<R::T> {
201        assert_eq!(self.shape.1, v.len());
202        (self.apply_fn)(v)
203    }
204
205    fn shape(&self) -> (usize, usize) {
206        self.shape
207    }
208}
209
210pub trait BlackBoxMIntMatrix<M>: BlackBoxMatrix<AddMulOperation<MInt<M>>>
211where
212    M: MIntBase,
213{
214    fn minimal_polynomial(&self) -> Vec<MInt<M>>
215    where
216        M: MIntConvert<u64>,
217    {
218        assert_eq!(self.shape().0, self.shape().1);
219        let n = self.shape().0;
220        let mut rng = Xorshift::new();
221        let b: Vec<MInt<M>> = (0..n).map(|_| MInt::from(rng.rand64())).collect();
222        let u: Vec<MInt<M>> = (0..n).map(|_| MInt::from(rng.rand64())).collect();
223        let a: Vec<MInt<M>> = (0..2 * n)
224            .scan(b, |b, _| {
225                let a = b.iter().zip(&u).fold(MInt::zero(), |s, (x, y)| s + x * y);
226                *b = self.apply(b);
227                Some(a)
228            })
229            .collect();
230        let mut p = berlekamp_massey(&a);
231        p.reverse();
232        p
233    }
234
235    fn apply_pow<C>(&self, mut b: Vec<MInt<M>>, k: usize) -> Vec<MInt<M>>
236    where
237        M: MIntConvert<usize> + MIntConvert<u64>,
238        C: ConvolveSteps<T = Vec<MInt<M>>>,
239    {
240        assert_eq!(self.shape().0, self.shape().1);
241        assert_eq!(self.shape().1, b.len());
242        let n = self.shape().0;
243        let p = self.minimal_polynomial();
244        let f = FormalPowerSeries::<MInt<M>, C>::from_vec(p).pow_mod(k);
245        let mut res = vec![MInt::zero(); n];
246        for f in f {
247            for j in 0..n {
248                res[j] += f * b[j];
249            }
250            b = self.apply(&b);
251        }
252        res
253    }
254
255    fn black_box_determinant(&self) -> MInt<M>
256    where
257        M: MIntConvert<u64>,
258    {
259        assert_eq!(self.shape().0, self.shape().1);
260        let n = self.shape().0;
261        let mut rng = Xorshift::new();
262        let d: Vec<MInt<M>> = (0..n).map(|_| MInt::from(rng.rand64())).collect();
263        let det_d = d.iter().fold(MInt::one(), |s, x| s * x);
264        let ad = BlackBoxMatrixImpl::<AddMulOperation<MInt<M>>, _>::new(
265            self.shape(),
266            |v: &[MInt<M>]| {
267                let mut w = self.apply(v);
268                for (w, d) in w.iter_mut().zip(&d) {
269                    *w *= d;
270                }
271                w
272            },
273        );
274        let p = ad.minimal_polynomial();
275        let det_ad = if n % 2 == 0 { p[0] } else { -p[0] };
276        det_ad / det_d
277    }
278
279    fn black_box_linear_equation(&self, mut b: Vec<MInt<M>>) -> Option<Vec<MInt<M>>>
280    where
281        M: MIntConvert<u64>,
282    {
283        assert_eq!(self.shape().0, self.shape().1);
284        assert_eq!(self.shape().1, b.len());
285        let n = self.shape().0;
286        let p = self.minimal_polynomial();
287        if p.is_empty() || p[0].is_zero() {
288            return None;
289        }
290        let p0_inv = p[0].inv();
291        let mut x = vec![MInt::zero(); n];
292        for p in p.into_iter().skip(1) {
293            let p = -p * p0_inv;
294            for i in 0..n {
295                x[i] += p * b[i];
296            }
297            b = self.apply(&b);
298        }
299        Some(x)
300    }
301}
302
303impl<M, B> BlackBoxMIntMatrix<M> for B
304where
305    M: MIntBase,
306    B: BlackBoxMatrix<AddMulOperation<MInt<M>>>,
307{
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{math::Convolve998244353, num::montgomery::MInt998244353, rand};

    type R = AddMulOperation<MInt998244353>;

    /// Generates one of three kinds of random matrix: fully random, random
    /// with a random fraction of entries forced to zero, or fully random with
    /// one row overwritten by a multiple of another row.
    fn random_matrix(rng: &mut Xorshift, shape: (usize, usize)) -> Matrix<R> {
        if rng.gen_bool(0.5) {
            return Matrix::<R>::new_with(shape, |_, _| rng.random(..));
        }
        if rng.gen_bool(0.5) {
            let density = rng.randf();
            return Matrix::<R>::new_with(shape, |_, _| {
                if rng.gen_bool(density) {
                    rng.random(..)
                } else {
                    MInt998244353::zero()
                }
            });
        }
        let mut mat = Matrix::<R>::new_with(shape, |_, _| rng.random(..));
        let dst = rng.random(0..shape.0);
        let src = rng.random(0..shape.0);
        let factor: MInt998244353 = rng.random(..);
        for j in 0..shape.1 {
            mat[(dst, j)] = mat[(src, j)] * factor;
        }
        mat
    }

    /// Dense and sparse `apply` must agree on the same matrix and vector.
    #[test]
    fn test_apply() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, m: 1..30);
            let dense = random_matrix(&mut rng, (n, m));
            let sparse = SparseMatrix::from(dense.clone());
            let v: Vec<_> = (0..m).map(|_| rng.random(..)).collect();
            let dense_out = dense.apply(&v);
            let sparse_out = sparse.apply(&v);
            assert_eq!(dense_out, sparse_out);
        }
    }

    /// The returned polynomial has at most `n + 1` coefficients and
    /// annihilates the matrix when evaluated at it.
    #[test]
    fn test_minimal_polynomial() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30);
            let a = random_matrix(&mut rng, (n, n));
            let poly = a.minimal_polynomial();
            assert!(poly.len() <= n + 1);
            let mut acc = Matrix::<R>::zeros((n, n));
            let mut power = Matrix::<R>::eye((n, n));
            for coef in poly {
                for i in 0..n {
                    for j in 0..n {
                        acc[(i, j)] += coef * power[(i, j)];
                    }
                }
                power = &power * &a;
            }
            assert_eq!(acc, Matrix::<R>::zeros((n, n)));
        }
    }

    /// `apply_pow` must match applying the explicitly powered matrix.
    #[test]
    fn test_apply_pow() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30, k: 0..1_000_000_000);
            let a = random_matrix(&mut rng, (n, n));
            let b: Vec<_> = (0..n).map(|_| rng.random(..)).collect();
            let expected = a.clone().pow(k).apply(&b);
            let actual = a.apply_pow::<Convolve998244353>(b, k);
            assert_eq!(actual, expected);
        }
    }

    /// The black-box determinant must match the dense determinant.
    #[test]
    fn test_black_box_determinant() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30);
            let mut a = random_matrix(&mut rng, (n, n));
            let actual = a.black_box_determinant();
            let expected = a.determinant();
            assert_eq!(actual, expected);
        }
    }

    /// The black-box solver must match the particular solution of the dense
    /// solver, including the no-solution case.
    #[test]
    fn test_black_box_linear_equation() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..30);
            let a = random_matrix(&mut rng, (n, n));
            let b: Vec<_> = (0..n).map(|_| rng.random(..)).collect();
            let expected = a
                .solve_system_of_linear_equations(&b)
                .map(|sol| sol.particular);
            let actual = a.black_box_linear_equation(b);
            assert_eq!(actual, expected);
        }
    }
}
416}