// competitive/math/factorial.rs
use super::{MInt, MIntConvert, One, Zero};

3#[derive(Clone, Debug)]
4pub struct MemorizedFactorial<M>
5where
6 M: MIntConvert<usize>,
7{
8 pub fact: Vec<MInt<M>>,
9 pub inv_fact: Vec<MInt<M>>,
10}
12impl<M> MemorizedFactorial<M>
13where
14 M: MIntConvert<usize>,
15{
16 pub fn new(max_n: usize) -> Self {
17 let mut fact = vec![MInt::one(); max_n + 1];
18 let mut inv_fact = vec![MInt::one(); max_n + 1];
19 for i in 2..=max_n {
20 fact[i] = fact[i - 1] * MInt::from(i);
21 }
22 inv_fact[max_n] = fact[max_n].inv();
23 for i in (3..=max_n).rev() {
24 inv_fact[i - 1] = inv_fact[i] * MInt::from(i);
25 }
26 Self { fact, inv_fact }
27 }
28
29 pub fn combination(&self, n: usize, r: usize) -> MInt<M> {
30 debug_assert!(n < self.fact.len());
31 if r <= n {
32 self.fact[n] * self.inv_fact[r] * self.inv_fact[n - r]
33 } else {
34 MInt::zero()
35 }
36 }
37
38 pub fn permutation(&self, n: usize, r: usize) -> MInt<M> {
39 debug_assert!(n < self.fact.len());
40 if r <= n {
41 self.fact[n] * self.inv_fact[n - r]
42 } else {
43 MInt::zero()
44 }
45 }
46
47 pub fn homogeneous_product(&self, n: usize, r: usize) -> MInt<M> {
48 debug_assert!(n + r < self.fact.len() + 1);
49 if n == 0 && r == 0 {
50 MInt::one()
51 } else {
52 self.combination(n + r - 1, r)
53 }
54 }
55
56 pub fn inv(&self, n: usize) -> MInt<M> {
57 debug_assert!(n < self.fact.len());
58 debug_assert!(n > 0);
59 self.inv_fact[n] * self.fact[n - 1]
60 }
61}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::mint_basic::MInt1000000007;

    type M = MInt1000000007;

    /// Table consistency plus spot checks of `inv`, `combination` and `permutation`.
    #[test]
    fn test_factorials() {
        let fact = MemorizedFactorial::new(100);
        for i in 0..101 {
            assert_eq!(fact.fact[i] * fact.inv_fact[i], M::new(1));
        }
        assert_eq!(fact.fact[10], M::new(3_628_800));
        for i in 1..101 {
            assert_eq!(fact.inv(i), M::new(i as u32).inv());
        }
        assert_eq!(fact.combination(10, 0), M::new(1));
        assert_eq!(fact.combination(10, 1), M::new(10));
        assert_eq!(fact.combination(10, 5), M::new(252));
        assert_eq!(fact.combination(10, 6), M::new(210));
        assert_eq!(fact.combination(10, 10), M::new(1));
        assert_eq!(fact.combination(10, 11), M::new(0));

        assert_eq!(fact.permutation(10, 0), M::new(1));
        assert_eq!(fact.permutation(10, 1), M::new(10));
        assert_eq!(fact.permutation(10, 5), M::new(30240));
        assert_eq!(fact.permutation(10, 6), M::new(151_200));
        assert_eq!(fact.permutation(10, 10), M::new(3_628_800));
        assert_eq!(fact.permutation(10, 11), M::new(0));
    }

    /// `H(n, r) = C(n + r - 1, r)`, including the special-cased boundaries.
    #[test]
    fn test_homogeneous_product() {
        let fact = MemorizedFactorial::new(20);
        // H(0, 0) is special-cased to one; H(0, r > 0) has no kinds to choose from.
        assert_eq!(fact.homogeneous_product(0, 0), M::new(1));
        assert_eq!(fact.homogeneous_product(0, 3), M::new(0));
        // Choosing nothing from any positive number of kinds: exactly one way.
        assert_eq!(fact.homogeneous_product(5, 0), M::new(1));
        // C(4, 2) and C(14, 5).
        assert_eq!(fact.homogeneous_product(3, 2), M::new(6));
        assert_eq!(fact.homogeneous_product(10, 5), M::new(2002));
    }

    /// Degenerate tables: `new(0)`, `new(1)` and `new(2)` skip the construction loops
    /// entirely, or nearly so.
    #[test]
    fn test_small_tables() {
        for max_n in 0..3 {
            let fact = MemorizedFactorial::new(max_n);
            assert_eq!(fact.fact.len(), max_n + 1);
            assert_eq!(fact.inv_fact.len(), max_n + 1);
            assert_eq!(fact.fact[max_n] * fact.inv_fact[max_n], M::new(1));
            assert_eq!(fact.combination(max_n, 0), M::new(1));
            assert_eq!(fact.combination(max_n, max_n), M::new(1));
        }
    }
}