// competitive/math/factorial.rs

1use super::{MInt, MIntConvert, One, Zero};
2
3#[derive(Clone, Debug)]
4pub struct MemorizedFactorial<M>
5where
6    M: MIntConvert<usize>,
7{
8    pub fact: Vec<MInt<M>>,
9    pub inv_fact: Vec<MInt<M>>,
10}
11
12impl<M> MemorizedFactorial<M>
13where
14    M: MIntConvert<usize>,
15{
16    pub fn new(max_n: usize) -> Self {
17        let mut fact = vec![MInt::one(); max_n + 1];
18        let mut inv_fact = vec![MInt::one(); max_n + 1];
19        for i in 2..=max_n {
20            fact[i] = fact[i - 1] * MInt::from(i);
21        }
22        inv_fact[max_n] = fact[max_n].inv();
23        for i in (3..=max_n).rev() {
24            inv_fact[i - 1] = inv_fact[i] * MInt::from(i);
25        }
26        Self { fact, inv_fact }
27    }
28
29    pub fn combination(&self, n: usize, r: usize) -> MInt<M> {
30        debug_assert!(n < self.fact.len());
31        if r <= n {
32            self.fact[n] * self.inv_fact[r] * self.inv_fact[n - r]
33        } else {
34            MInt::zero()
35        }
36    }
37
38    pub fn permutation(&self, n: usize, r: usize) -> MInt<M> {
39        debug_assert!(n < self.fact.len());
40        if r <= n {
41            self.fact[n] * self.inv_fact[n - r]
42        } else {
43            MInt::zero()
44        }
45    }
46
47    pub fn homogeneous_product(&self, n: usize, r: usize) -> MInt<M> {
48        debug_assert!(n + r < self.fact.len() + 1);
49        if n == 0 && r == 0 {
50            MInt::one()
51        } else {
52            self.combination(n + r - 1, r)
53        }
54    }
55
56    pub fn inv(&self, n: usize) -> MInt<M> {
57        debug_assert!(n < self.fact.len());
58        debug_assert!(n > 0);
59        self.inv_fact[n] * self.fact[n - 1]
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::mint_basic::MInt1000000007;

    type M = MInt1000000007;

    #[test]
    fn test_factorials() {
        let fact = MemorizedFactorial::new(100);
        for i in 0..101 {
            assert_eq!(fact.fact[i] * fact.inv_fact[i], M::new(1));
        }
        for i in 1..101 {
            assert_eq!(fact.inv(i), M::new(i as u32).inv());
        }
        assert_eq!(fact.combination(10, 0), M::new(1));
        assert_eq!(fact.combination(10, 1), M::new(10));
        assert_eq!(fact.combination(10, 5), M::new(252));
        assert_eq!(fact.combination(10, 6), M::new(210));
        assert_eq!(fact.combination(10, 10), M::new(1));
        assert_eq!(fact.combination(10, 11), M::new(0));

        assert_eq!(fact.permutation(10, 0), M::new(1));
        assert_eq!(fact.permutation(10, 1), M::new(10));
        assert_eq!(fact.permutation(10, 5), M::new(30240));
        assert_eq!(fact.permutation(10, 6), M::new(151_200));
        assert_eq!(fact.permutation(10, 10), M::new(3_628_800));
        assert_eq!(fact.permutation(10, 11), M::new(0));
    }

    /// Covers the special case `H(0, 0) = 1`, the zero case `H(0, r > 0) = 0`,
    /// and the identity `H(n, r) = C(n + r - 1, r)` for `n >= 1`.
    #[test]
    fn test_homogeneous_product() {
        let fact = MemorizedFactorial::new(100);
        assert_eq!(fact.homogeneous_product(0, 0), M::new(1));
        assert_eq!(fact.homogeneous_product(0, 3), M::new(0));
        assert_eq!(fact.homogeneous_product(5, 0), M::new(1));
        assert_eq!(fact.homogeneous_product(3, 2), M::new(6));
        assert_eq!(fact.homogeneous_product(4, 3), M::new(20));
        for n in 1..=10 {
            for r in 0..=10 {
                assert_eq!(
                    fact.homogeneous_product(n, r),
                    fact.combination(n + r - 1, r)
                );
            }
        }
    }

    /// Tables with `max_n < 3` skip one or both initialisation loops in `new`.
    /// Check that they are still consistent.
    #[test]
    fn test_small_tables() {
        for max_n in 0..3 {
            let fact = MemorizedFactorial::new(max_n);
            assert_eq!(fact.combination(max_n, max_n), M::new(1));
            assert_eq!(fact.fact.len(), max_n + 1);
            assert_eq!(fact.inv_fact.len(), max_n + 1);
            for i in 0..=max_n {
                assert_eq!(fact.fact[i] * fact.inv_fact[i], M::new(1));
            }
        }
    }
}