// competitive/math/factorial.rs

1use super::prime_factors;
2use crate::num::{MInt, MIntConvert, One, Zero};
3
4#[codesnip::entry("factorial", include("MIntBase"))]
5#[derive(Clone, Debug)]
6pub struct MemorizedFactorial<M>
7where
8    M: MIntConvert<usize>,
9{
10    pub fact: Vec<MInt<M>>,
11    pub inv_fact: Vec<MInt<M>>,
12}
13#[codesnip::entry("factorial")]
14impl<M> MemorizedFactorial<M>
15where
16    M: MIntConvert<usize>,
17{
18    pub fn new(max_n: usize) -> Self {
19        let mut fact = vec![MInt::one(); max_n + 1];
20        let mut inv_fact = vec![MInt::one(); max_n + 1];
21        for i in 2..=max_n {
22            fact[i] = fact[i - 1] * MInt::from(i);
23        }
24        inv_fact[max_n] = fact[max_n].inv();
25        for i in (3..=max_n).rev() {
26            inv_fact[i - 1] = inv_fact[i] * MInt::from(i);
27        }
28        Self { fact, inv_fact }
29    }
30    #[inline]
31    pub fn combination(&self, n: usize, r: usize) -> MInt<M> {
32        debug_assert!(n < self.fact.len());
33        if r <= n {
34            self.fact[n] * self.inv_fact[r] * self.inv_fact[n - r]
35        } else {
36            MInt::zero()
37        }
38    }
39    #[inline]
40    pub fn permutation(&self, n: usize, r: usize) -> MInt<M> {
41        debug_assert!(n < self.fact.len());
42        if r <= n {
43            self.fact[n] * self.inv_fact[n - r]
44        } else {
45            MInt::zero()
46        }
47    }
48    #[inline]
49    pub fn homogeneous_product(&self, n: usize, r: usize) -> MInt<M> {
50        debug_assert!(n + r < self.fact.len() + 1);
51        if n == 0 && r == 0 {
52            MInt::one()
53        } else {
54            self.combination(n + r - 1, r)
55        }
56    }
57    #[inline]
58    pub fn inv(&self, n: usize) -> MInt<M> {
59        debug_assert!(n < self.fact.len());
60        debug_assert!(n > 0);
61        self.inv_fact[n] * self.fact[n - 1]
62    }
63}
64
65#[codesnip::entry("SmallModMemorizedFactorial", include("MIntBase", "prime_factors"))]
66#[derive(Clone, Debug)]
67pub struct SmallModMemorizedFactorial<M>
68where
69    M: MIntConvert<usize>,
70{
71    p: u32,
72    c: u32,
73    fact: Vec<MInt<M>>,
74    inv_fact: Vec<MInt<M>>,
75    pow: Vec<MInt<M>>,
76}
77#[codesnip::entry("SmallModMemorizedFactorial")]
78impl<M> Default for SmallModMemorizedFactorial<M>
79where
80    M: MIntConvert<usize>,
81{
82    fn default() -> Self {
83        let m = M::mod_into();
84        let pf = prime_factors(m as _);
85        assert!(pf.len() <= 1);
86        let p = pf[0].0 as u32;
87        let c = pf[0].1;
88        let mut fact = vec![MInt::one(); m];
89        let mut inv_fact = vec![MInt::one(); m];
90        let mut pow = vec![MInt::one(); c as usize];
91        for i in 2..m {
92            fact[i] = fact[i - 1]
93                * if i as u32 % p != 0 {
94                    MInt::from(i)
95                } else {
96                    MInt::one()
97                };
98        }
99        inv_fact[m - 1] = fact[m - 1].inv();
100        for i in (3..m).rev() {
101            inv_fact[i - 1] = inv_fact[i]
102                * if i as u32 % p != 0 {
103                    MInt::from(i)
104                } else {
105                    MInt::one()
106                };
107        }
108        for i in 1..c as usize {
109            pow[i] = pow[i - 1] * MInt::from(p as usize);
110        }
111        Self {
112            p,
113            c,
114            fact,
115            inv_fact,
116            pow,
117        }
118    }
119}
120#[codesnip::entry("SmallModMemorizedFactorial")]
121impl<M> SmallModMemorizedFactorial<M>
122where
123    M: MIntConvert<usize>,
124{
125    pub fn new() -> Self {
126        Default::default()
127    }
128    /// n! = a * p^e, c==1
129    pub fn factorial(&self, n: usize) -> (MInt<M>, usize) {
130        let p = self.p as usize;
131        if n == 0 {
132            (MInt::<M>::one(), 0)
133        } else {
134            let e = n / p;
135            let res = self.factorial(e);
136            if e % 2 == 0 {
137                (res.0 * self.fact[n % p], res.1 + e)
138            } else {
139                (res.0 * -self.fact[n % p], res.1 + e)
140            }
141        }
142    }
143    pub fn combination(&self, mut n: usize, mut r: usize) -> MInt<M> {
144        if r > n {
145            return MInt::<M>::zero();
146        }
147        if self.p == 2 && self.c == 1 {
148            return MInt::from(((!n & r) == 0) as usize);
149        }
150        let mut k = n - r;
151        let m = M::mod_into();
152        let p = self.p as usize;
153        let cnte = |mut x: usize| {
154            let mut e = 0usize;
155            while x > 0 {
156                e += x;
157                x /= p;
158            }
159            e
160        };
161        let e0 = cnte(n / p) - cnte(r / p) - cnte(k / p);
162        if e0 >= self.c as usize {
163            return MInt::<M>::zero();
164        }
165        let mut res = self.pow[e0];
166        if (self.p > 2 && self.c >= 2 || self.c == 2)
167            && (cnte(n / m) - cnte(r / m) - cnte(k / m)) % 2 == 1
168        {
169            res = -res;
170        }
171        while n > 0 {
172            res *= self.fact[n % m] * self.inv_fact[r % m] * self.inv_fact[k % m];
173            n /= p;
174            r /= p;
175            k /= p;
176        }
177        res
178    }
179}
180
181#[codesnip::entry("PowPrec", include("MIntBase"))]
182#[derive(Debug, Clone)]
183pub struct PowPrec<M>
184where
185    M: MIntConvert<usize>,
186{
187    sqn: usize,
188    p0: Vec<MInt<M>>,
189    p1: Vec<MInt<M>>,
190}
191#[codesnip::entry("PowPrec")]
192impl<M> PowPrec<M>
193where
194    M: MIntConvert<usize>,
195{
196    pub fn new(a: MInt<M>) -> Self {
197        let sqn = (M::mod_into() as f64).sqrt() as usize + 1;
198        let mut p0 = Vec::with_capacity(sqn);
199        let mut p1 = Vec::with_capacity(sqn);
200        let mut acc = MInt::<M>::one();
201        for _ in 0..sqn {
202            p0.push(acc);
203            acc *= a;
204        }
205        let b = acc;
206        acc = MInt::<M>::one();
207        for _ in 0..sqn {
208            p1.push(acc);
209            acc *= b;
210        }
211        Self { sqn, p0, p1 }
212    }
213    pub fn pow(&self, n: usize) -> MInt<M> {
214        let n = n % (M::mod_into() - 1);
215        let (p, q) = (n / self.sqn, n % self.sqn);
216        self.p1[p] * self.p0[q]
217    }
218    pub fn powi(&self, n: isize) -> MInt<M> {
219        let n = n.rem_euclid(M::mod_into() as isize - 1) as usize;
220        let (p, q) = (n / self.sqn, n % self.sqn);
221        self.p1[p] * self.p0[q]
222    }
223    pub fn inv(&self) -> MInt<M> {
224        self.powi(-1)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorials() {
        use crate::num::mint_basic::MInt1000000007;
        let fact = MemorizedFactorial::new(100);
        type M = MInt1000000007;
        for i in 0..101 {
            assert_eq!(fact.fact[i] * fact.inv_fact[i], M::new(1));
        }
        for i in 1..101 {
            assert_eq!(fact.inv(i), M::new(i as u32).inv());
        }
        assert_eq!(fact.combination(10, 0), M::new(1));
        assert_eq!(fact.combination(10, 1), M::new(10));
        assert_eq!(fact.combination(10, 5), M::new(252));
        assert_eq!(fact.combination(10, 6), M::new(210));
        assert_eq!(fact.combination(10, 10), M::new(1));
        assert_eq!(fact.combination(10, 11), M::new(0));

        assert_eq!(fact.permutation(10, 0), M::new(1));
        assert_eq!(fact.permutation(10, 1), M::new(10));
        assert_eq!(fact.permutation(10, 5), M::new(30240));
        assert_eq!(fact.permutation(10, 6), M::new(151_200));
        assert_eq!(fact.permutation(10, 10), M::new(3_628_800));
        assert_eq!(fact.permutation(10, 11), M::new(0));
    }

    #[test]
    fn test_small_factorials() {
        use crate::num::mint_basic::DynModuloU32;
        use crate::tools::Xorshift;
        let mut rng = Xorshift::new();
        const Q: usize = 100_000;
        DynModuloU32::set_mod(2);
        let fact = SmallModMemorizedFactorial::<DynModuloU32>::new();
        for _ in 0..Q {
            let n = rng.random(1..=1_000_000_000_000_000_000);
            let k = rng.random(0..=n);
            let x = fact.factorial(n).1 - fact.factorial(k).1 - fact.factorial(n - k).1;
            assert_eq!(x == 0, (n & k) == k);
            let x = fact.combination(n, k);
            assert_eq!(x.is_one(), (n & k) == k);
        }

        // Prime-power moduli exercise the generalized-Lucas branch of
        // `combination`, which the mod-2 check above never reaches. They are
        // kept in this test so the `set_mod` calls stay sequential, in case the
        // dynamic modulus is process-global state.
        for &m in &[3u32, 4, 8, 9, 16, 25, 27] {
            DynModuloU32::set_mod(m);
            let fact = SmallModMemorizedFactorial::<DynModuloU32>::new();
            // Row `n` of Pascal's triangle, reduced modulo `m`.
            let mut row: Vec<MInt<DynModuloU32>> = vec![MInt::one()];
            for n in 0..64usize {
                for (r, &expected) in row.iter().enumerate() {
                    assert_eq!(fact.combination(n, r), expected, "C({}, {}) mod {}", n, r, m);
                }
                assert_eq!(fact.combination(n, n + 1), MInt::zero());
                let mut next = vec![MInt::one(); n + 2];
                for r in 1..=n {
                    next[r] = row[r - 1] + row[r];
                }
                row = next;
            }
        }
    }

    #[test]
    fn test_pow_prec() {
        use crate::num::mint_basic::MInt1000000007;
        type M = MInt1000000007;
        let base = M::new(3);
        let pp = PowPrec::new(base);
        let mut expected = M::new(1);
        for e in 0..1000usize {
            assert_eq!(pp.pow(e), expected);
            assert_eq!(pp.powi(e as isize), expected);
            expected *= base;
        }
        // Fermat's little theorem: 3^(p - 1) == 1 for p = 1_000_000_007.
        assert_eq!(pp.pow(1_000_000_006), M::new(1));
        assert_eq!(pp.inv(), base.inv());
        assert_eq!(pp.powi(-2), (base * base).inv());
    }
}