competitive/math/
arbitrary_mod_binomial.rs1use super::{BarrettReduction, Unsigned, prime_factors, solve_simultaneous_linear_congruence};
2
3fn pow64(x: u64, mut y: u64, br: &BarrettReduction<u128>) -> u64 {
4 let mut x = x as u128;
5 let mut z: u128 = 1;
6 while y > 0 {
7 if y & 1 == 1 {
8 z = br.rem(z * x);
9 }
10 x = br.rem(x * x);
11 y >>= 1;
12 }
13 z as u64
14}
15
16fn pow32(x: u32, mut y: u64, br: &BarrettReduction<u64>) -> u32 {
17 let mut x = x as u64;
18 let mut z: u64 = 1;
19 while y > 0 {
20 if y & 1 == 1 {
21 z = br.rem(z * x);
22 }
23 x = br.rem(x * x);
24 y >>= 1;
25 }
26 z as u32
27}
28
29#[derive(Debug)]
30struct PrimePowerBinomial {
31 p: u64,
32 e: u32,
33 m: u64,
34 size: usize,
35 fact: Vec<u64>,
36 inv_fact: Vec<u64>,
37 delta: u64,
38 bp: BarrettReduction<u64>,
39 bm: BarrettReduction<u64>,
40 bm128: BarrettReduction<u128>,
41}
42
43impl PrimePowerBinomial {
44 fn new(p: u64, e: u32, max_n: u64) -> Self {
45 let m = p.checked_pow(e).expect("prime power overflow");
46 let bp = BarrettReduction::new(p);
47 let bm = BarrettReduction::new(m);
48 let bm128 = BarrettReduction::new(m as u128);
49 let size = max_n.min(m - 1);
50 assert!(size < usize::MAX as u64);
51 let size = size as usize;
52 let mut fact = vec![1u64; size + 1];
53 let mut inv_fact = vec![1u64; size + 1];
54 if m < 1 << 31 {
55 for i in 2..=size {
56 fact[i] = if bp.rem(i as u64) == 0 {
57 fact[i - 1]
58 } else {
59 bm.rem(fact[i - 1] * i as u64)
60 };
61 }
62 inv_fact[size] = fact[size].mod_inv(m);
63 for i in (3..=size).rev() {
64 inv_fact[i - 1] = if bp.rem(i as u64) == 0 {
65 inv_fact[i]
66 } else {
67 bm.rem(inv_fact[i] * i as u64)
68 };
69 }
70 } else {
71 for i in 2..=size {
72 fact[i] = if bp.rem(i as u64) == 0 {
73 fact[i - 1]
74 } else {
75 bm128.rem(fact[i - 1] as u128 * i as u128) as u64
76 };
77 }
78 inv_fact[size] = fact[size].mod_inv(m);
79 for i in (3..=size).rev() {
80 inv_fact[i - 1] = if bp.rem(i as u64) == 0 {
81 inv_fact[i]
82 } else {
83 bm128.rem(inv_fact[i] as u128 * i as u128) as u64
84 };
85 }
86 }
87 let delta = if p == 2 && e >= 3 { 1 } else { m - 1 };
88 Self {
89 p,
90 e,
91 m,
92 size,
93 fact,
94 inv_fact,
95 delta,
96 bp,
97 bm,
98 bm128,
99 }
100 }
101
102 fn combination(&self, mut n: u64, mut k: u64) -> u64 {
103 if k > n {
104 return 0;
105 }
106 assert!(self.size as u64 >= n.min(self.m - 1));
107 if self.m < 1 << 31 {
108 let mut res = 1u64;
109 if self.e == 1 {
110 while n > 0 {
111 let (nn, n0) = self.bp.div_rem(n);
112 let (nk, k0) = self.bp.div_rem(k);
113 if n0 < k0 {
114 return 0;
115 }
116 res = self.bm.rem(res * self.fact[n0 as usize]);
117 res = self.bm.rem(res * self.inv_fact[k0 as usize]);
118 res = self.bm.rem(res * self.inv_fact[(n0 - k0) as usize]);
119 n = nn;
120 k = nk;
121 }
122 } else {
123 let mut r = n - k;
124 let mut e0 = 0;
125 let mut eq = 0;
126 let mut i = 0;
127 while n > 0 {
128 res = self.bm.rem(res * self.fact[self.bm.rem(n) as usize]);
129 res = self.bm.rem(res * self.inv_fact[self.bm.rem(k) as usize]);
130 res = self.bm.rem(res * self.inv_fact[self.bm.rem(r) as usize]);
131 n = self.bp.div(n);
132 k = self.bp.div(k);
133 r = self.bp.div(r);
134 let eps = n - k - r;
135 e0 += eps;
136 if e0 >= self.e as u64 {
137 return 0;
138 }
139 i += 1;
140 if i >= self.e {
141 eq ^= eps & 1;
142 }
143 }
144 if eq == 1 {
145 res = self.bm.rem(res * self.delta);
146 }
147 res = self
148 .bm
149 .rem(res * pow32(self.p as _, e0 as _, &self.bm) as u64);
150 }
151 res
152 } else {
153 let mut res = 1u128;
154 if self.e == 1 {
155 while n > 0 {
156 let (nn, n0) = self.bp.div_rem(n);
157 let (nk, k0) = self.bp.div_rem(k);
158 if n0 < k0 {
159 return 0;
160 }
161 res = self.bm128.rem(res * self.fact[n0 as usize] as u128);
162 res = self.bm128.rem(res * self.inv_fact[k0 as usize] as u128);
163 res = self
164 .bm128
165 .rem(res * self.inv_fact[(n0 - k0) as usize] as u128);
166 n = nn;
167 k = nk;
168 }
169 } else {
170 let mut r = n - k;
171 let mut e0 = 0;
172 let mut eq = 0;
173 let mut i = 0;
174 while n > 0 {
175 res = self
176 .bm128
177 .rem(res * self.fact[self.bm.rem(n) as usize] as u128);
178 res = self
179 .bm128
180 .rem(res * self.inv_fact[self.bm.rem(k) as usize] as u128);
181 res = self
182 .bm128
183 .rem(res * self.inv_fact[self.bm.rem(r) as usize] as u128);
184 n = self.bp.div(n);
185 k = self.bp.div(k);
186 r = self.bp.div(r);
187 let eps = n - k - r;
188 e0 += eps;
189 if e0 >= self.e as u64 {
190 return 0;
191 }
192 i += 1;
193 if i >= self.e {
194 eq ^= eps & 1;
195 }
196 }
197 if eq == 1 {
198 res = self.bm128.rem(res * self.delta as u128);
199 }
200 res = self.bm128.rem(res * pow64(self.p, e0, &self.bm128) as u128);
201 }
202 res as u64
203 }
204 }
205}
206
207#[derive(Debug)]
208pub struct ArbitraryModBinomial {
209 ppbs: Vec<PrimePowerBinomial>,
210}
211
212impl ArbitraryModBinomial {
213 pub fn new(modulus: u64, max_n: u64) -> Self {
214 assert_ne!(modulus, 0);
215 let ppbs = prime_factors(modulus)
216 .into_iter()
217 .map(|(p, e)| PrimePowerBinomial::new(p, e, max_n))
218 .collect();
219 Self { ppbs }
220 }
221
222 pub fn combination(&self, n: u64, k: u64) -> u64 {
223 solve_simultaneous_linear_congruence(
224 self.ppbs
225 .iter()
226 .map(|ppb| (1u64, ppb.combination(n, k), ppb.m)),
227 )
228 .unwrap()
229 .0
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        math::MemorizedFactorial,
        num::mint_basic::{MInt998244353, Modulo998244353},
        tools::Xorshift,
    };

    /// Largest `n` (and `k`) checked exhaustively against Pascal's triangle.
    const MAX_N: usize = 100;

    /// Builds Pascal's triangle modulo `m` for rows `0..=MAX_N` as a reference oracle.
    fn pascal_mod(m: u64) -> Vec<Vec<u64>> {
        let mut dp = vec![vec![0u64; MAX_N + 1]; MAX_N + 1];
        dp[0][0] = 1 % m;
        for n in 1..=MAX_N {
            dp[n][0] = 1 % m;
            for k in 1..=n {
                dp[n][k] = dp[n - 1][k - 1].mod_add(dp[n - 1][k], m);
            }
        }
        dp
    }

    /// Asserts that `C(n, k) mod m` matches the oracle for all `0 <= n, k <= MAX_N`.
    fn check_against_pascal(m: u64) {
        let binom = ArbitraryModBinomial::new(m, MAX_N as u64);
        let dp = pascal_mod(m);
        for n in 0..=MAX_N {
            for k in 0..=MAX_N {
                assert_eq!(
                    binom.combination(n as u64, k as u64),
                    dp[n][k],
                    "m = {}, n = {}, k = {}",
                    m,
                    n,
                    k
                );
            }
        }
    }

    #[test]
    fn test_arbitrary_mod_binomial_small_mod() {
        for m in 1..=200 {
            check_against_pascal(m);
        }
    }

    #[test]
    fn test_arbitrary_mod_binomial_large_mod() {
        let mut rng = Xorshift::default();
        for i in 1..=200u64 {
            // The first two moduli are 2^31 and 2^31 - 1, one on each side of
            // the 2^31 switch between the u64 and u128 arithmetic paths.
            let m = if i <= 2 {
                (1 << 31) + 1 - i
            } else {
                rng.random(1..=1_000_000_000_000u64)
            };
            check_against_pascal(m);
        }
    }

    #[test]
    fn test_arbitrary_mod_binomial_prime_mod() {
        let mut rng = Xorshift::default();
        let binom = ArbitraryModBinomial::new(MInt998244353::get_mod() as _, 1_000_000);
        let fact = MemorizedFactorial::<Modulo998244353>::new(1_000_000);
        for _ in 0..100_000 {
            let n = rng.random(0..=1_000_000);
            let k = rng.random(0..=n);
            assert_eq!(
                binom.combination(n, k),
                fact.combination(n as _, k as _).inner() as u64
            );
        }
    }
}