competitive/math/
arbitrary_mod_binomial.rs

1use super::{BarrettReduction, Unsigned, prime_factors, solve_simultaneous_linear_congruence};
2
3fn pow64(x: u64, mut y: u64, br: &BarrettReduction<u128>) -> u64 {
4    let mut x = x as u128;
5    let mut z: u128 = 1;
6    while y > 0 {
7        if y & 1 == 1 {
8            z = br.rem(z * x);
9        }
10        x = br.rem(x * x);
11        y >>= 1;
12    }
13    z as u64
14}
15
16fn pow32(x: u32, mut y: u64, br: &BarrettReduction<u64>) -> u32 {
17    let mut x = x as u64;
18    let mut z: u64 = 1;
19    while y > 0 {
20        if y & 1 == 1 {
21            z = br.rem(z * x);
22        }
23        x = br.rem(x * x);
24        y >>= 1;
25    }
26    z as u32
27}
28
29#[derive(Debug)]
30struct PrimePowerBinomial {
31    p: u64,
32    e: u32,
33    m: u64,
34    size: usize,
35    fact: Vec<u64>,
36    inv_fact: Vec<u64>,
37    delta: u64,
38    bp: BarrettReduction<u64>,
39    bm: BarrettReduction<u64>,
40    bm128: BarrettReduction<u128>,
41}
42
43impl PrimePowerBinomial {
44    fn new(p: u64, e: u32, max_n: u64) -> Self {
45        let m = p.checked_pow(e).expect("prime power overflow");
46        let bp = BarrettReduction::new(p);
47        let bm = BarrettReduction::new(m);
48        let bm128 = BarrettReduction::new(m as u128);
49        let size = max_n.min(m - 1);
50        assert!(size < usize::MAX as u64);
51        let size = size as usize;
52        let mut fact = vec![1u64; size + 1];
53        let mut inv_fact = vec![1u64; size + 1];
54        if m < 1 << 31 {
55            for i in 2..=size {
56                fact[i] = if bp.rem(i as u64) == 0 {
57                    fact[i - 1]
58                } else {
59                    bm.rem(fact[i - 1] * i as u64)
60                };
61            }
62            inv_fact[size] = fact[size].mod_inv(m);
63            for i in (3..=size).rev() {
64                inv_fact[i - 1] = if bp.rem(i as u64) == 0 {
65                    inv_fact[i]
66                } else {
67                    bm.rem(inv_fact[i] * i as u64)
68                };
69            }
70        } else {
71            for i in 2..=size {
72                fact[i] = if bp.rem(i as u64) == 0 {
73                    fact[i - 1]
74                } else {
75                    bm128.rem(fact[i - 1] as u128 * i as u128) as u64
76                };
77            }
78            inv_fact[size] = fact[size].mod_inv(m);
79            for i in (3..=size).rev() {
80                inv_fact[i - 1] = if bp.rem(i as u64) == 0 {
81                    inv_fact[i]
82                } else {
83                    bm128.rem(inv_fact[i] as u128 * i as u128) as u64
84                };
85            }
86        }
87        let delta = if p == 2 && e >= 3 { 1 } else { m - 1 };
88        Self {
89            p,
90            e,
91            m,
92            size,
93            fact,
94            inv_fact,
95            delta,
96            bp,
97            bm,
98            bm128,
99        }
100    }
101
102    fn combination(&self, mut n: u64, mut k: u64) -> u64 {
103        if k > n {
104            return 0;
105        }
106        assert!(self.size as u64 >= n.min(self.m - 1));
107        if self.m < 1 << 31 {
108            let mut res = 1u64;
109            if self.e == 1 {
110                while n > 0 {
111                    let (nn, n0) = self.bp.div_rem(n);
112                    let (nk, k0) = self.bp.div_rem(k);
113                    if n0 < k0 {
114                        return 0;
115                    }
116                    res = self.bm.rem(res * self.fact[n0 as usize]);
117                    res = self.bm.rem(res * self.inv_fact[k0 as usize]);
118                    res = self.bm.rem(res * self.inv_fact[(n0 - k0) as usize]);
119                    n = nn;
120                    k = nk;
121                }
122            } else {
123                let mut r = n - k;
124                let mut e0 = 0;
125                let mut eq = 0;
126                let mut i = 0;
127                while n > 0 {
128                    res = self.bm.rem(res * self.fact[self.bm.rem(n) as usize]);
129                    res = self.bm.rem(res * self.inv_fact[self.bm.rem(k) as usize]);
130                    res = self.bm.rem(res * self.inv_fact[self.bm.rem(r) as usize]);
131                    n = self.bp.div(n);
132                    k = self.bp.div(k);
133                    r = self.bp.div(r);
134                    let eps = n - k - r;
135                    e0 += eps;
136                    if e0 >= self.e as u64 {
137                        return 0;
138                    }
139                    i += 1;
140                    if i >= self.e {
141                        eq ^= eps & 1;
142                    }
143                }
144                if eq == 1 {
145                    res = self.bm.rem(res * self.delta);
146                }
147                res = self
148                    .bm
149                    .rem(res * pow32(self.p as _, e0 as _, &self.bm) as u64);
150            }
151            res
152        } else {
153            let mut res = 1u128;
154            if self.e == 1 {
155                while n > 0 {
156                    let (nn, n0) = self.bp.div_rem(n);
157                    let (nk, k0) = self.bp.div_rem(k);
158                    if n0 < k0 {
159                        return 0;
160                    }
161                    res = self.bm128.rem(res * self.fact[n0 as usize] as u128);
162                    res = self.bm128.rem(res * self.inv_fact[k0 as usize] as u128);
163                    res = self
164                        .bm128
165                        .rem(res * self.inv_fact[(n0 - k0) as usize] as u128);
166                    n = nn;
167                    k = nk;
168                }
169            } else {
170                let mut r = n - k;
171                let mut e0 = 0;
172                let mut eq = 0;
173                let mut i = 0;
174                while n > 0 {
175                    res = self
176                        .bm128
177                        .rem(res * self.fact[self.bm.rem(n) as usize] as u128);
178                    res = self
179                        .bm128
180                        .rem(res * self.inv_fact[self.bm.rem(k) as usize] as u128);
181                    res = self
182                        .bm128
183                        .rem(res * self.inv_fact[self.bm.rem(r) as usize] as u128);
184                    n = self.bp.div(n);
185                    k = self.bp.div(k);
186                    r = self.bp.div(r);
187                    let eps = n - k - r;
188                    e0 += eps;
189                    if e0 >= self.e as u64 {
190                        return 0;
191                    }
192                    i += 1;
193                    if i >= self.e {
194                        eq ^= eps & 1;
195                    }
196                }
197                if eq == 1 {
198                    res = self.bm128.rem(res * self.delta as u128);
199                }
200                res = self.bm128.rem(res * pow64(self.p, e0, &self.bm128) as u128);
201            }
202            res as u64
203        }
204    }
205}
206
207#[derive(Debug)]
208pub struct ArbitraryModBinomial {
209    ppbs: Vec<PrimePowerBinomial>,
210}
211
212impl ArbitraryModBinomial {
213    pub fn new(modulus: u64, max_n: u64) -> Self {
214        assert_ne!(modulus, 0);
215        let ppbs = prime_factors(modulus)
216            .into_iter()
217            .map(|(p, e)| PrimePowerBinomial::new(p, e, max_n))
218            .collect();
219        Self { ppbs }
220    }
221
222    pub fn combination(&self, n: u64, k: u64) -> u64 {
223        solve_simultaneous_linear_congruence(
224            self.ppbs
225                .iter()
226                .map(|ppb| (1u64, ppb.combination(n, k), ppb.m)),
227        )
228        .unwrap()
229        .0
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        math::MemorizedFactorial,
        num::mint_basic::{MInt998244353, Modulo998244353},
        tools::Xorshift,
    };

    /// Pascal's triangle modulo `m` for rows `0..=100`, used as the reference oracle.
    fn pascal_mod(m: u64) -> Vec<Vec<u64>> {
        let mut table = vec![vec![0u64; 101]; 101];
        for n in 0..=100 {
            table[n][0] = 1 % m;
            for k in 1..=n {
                table[n][k] = table[n - 1][k - 1].mod_add(table[n - 1][k], m);
            }
        }
        table
    }

    /// Checks every `C(n, k)` with `0 <= n, k <= 100` against Pascal's triangle modulo `m`.
    fn assert_matches_pascal(m: u64) {
        let binom = ArbitraryModBinomial::new(m, 100);
        let table = pascal_mod(m);
        for (n, row) in table.iter().enumerate() {
            for (k, &expected) in row.iter().enumerate() {
                assert_eq!(binom.combination(n as u64, k as u64), expected);
            }
        }
    }

    #[test]
    fn test_arbitrary_mod_binomial_small_mod() {
        for m in 1..=200 {
            assert_matches_pascal(m);
        }
    }

    #[test]
    fn test_arbitrary_mod_binomial_large_mod() {
        let mut rng = Xorshift::default();
        for i in 1..=200 {
            // The first two moduli (2^31 and 2^31 - 1) straddle the `m < 1 << 31` threshold
            // that selects the u64 or u128 arithmetic path.
            let m = if i <= 2 {
                (1 << 31) + 1 - i
            } else {
                rng.random(1..=1_000_000_000_000u64)
            };
            assert_matches_pascal(m);
        }
    }

    #[test]
    fn test_arbitrary_mod_binomial_prime_mod() {
        let mut rng = Xorshift::default();
        let binom = ArbitraryModBinomial::new(MInt998244353::get_mod() as _, 1_000_000);
        let fact = MemorizedFactorial::<Modulo998244353>::new(1_000_000);
        for _ in 0..100_000 {
            let n = rng.random(0..=1_000_000);
            let k = rng.random(0..=n);
            let actual = binom.combination(n, k);
            let expected = fact.combination(n as _, k as _).inner() as u64;
            assert_eq!(actual, expected);
        }
    }
}