competitive/math/formal_power_series/
formal_power_series_impls.rs1use super::*;
2use std::{
3 cmp::Reverse,
4 collections::BinaryHeap,
5 iter::repeat_with,
6 iter::{FromIterator, once},
7 marker::PhantomData,
8 ops::{Index, IndexMut},
9 slice::{Iter, IterMut},
10};
11
12impl<T, C> FormalPowerSeries<T, C> {
13 pub fn from_vec(data: Vec<T>) -> Self {
14 Self {
15 data,
16 _marker: PhantomData,
17 }
18 }
19 pub fn length(&self) -> usize {
20 self.data.len()
21 }
22 pub fn truncate(&mut self, deg: usize) {
23 self.data.truncate(deg)
24 }
25 pub fn iter(&self) -> Iter<'_, T> {
26 self.data.iter()
27 }
28 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
29 self.data.iter_mut()
30 }
31}
32
33impl<T, C> Clone for FormalPowerSeries<T, C>
34where
35 T: Clone,
36{
37 fn clone(&self) -> Self {
38 Self::from_vec(self.data.clone())
39 }
40}
41impl<T, C> PartialEq for FormalPowerSeries<T, C>
42where
43 T: PartialEq,
44{
45 fn eq(&self, other: &Self) -> bool {
46 self.data.eq(&other.data)
47 }
48}
49impl<T, C> Eq for FormalPowerSeries<T, C> where T: PartialEq {}
50
51impl<T, C> FormalPowerSeries<T, C>
52where
53 T: Zero,
54{
55 pub fn zeros(deg: usize) -> Self {
56 repeat_with(T::zero).take(deg).collect()
57 }
58 pub fn resize(&mut self, deg: usize) {
59 self.data.resize_with(deg, Zero::zero)
60 }
61 pub fn resized(mut self, deg: usize) -> Self {
62 self.resize(deg);
63 self
64 }
65 pub fn reversed(mut self) -> Self {
66 self.data.reverse();
67 self
68 }
69}
70
71impl<T, C> FormalPowerSeries<T, C>
72where
73 T: Zero + PartialEq,
74{
75 pub fn trim_tail_zeros(&mut self) {
76 let mut len = self.length();
77 while len > 0 {
78 if self.data[len - 1].is_zero() {
79 len -= 1;
80 } else {
81 break;
82 }
83 }
84 self.truncate(len);
85 }
86}
87
88impl<T, C> Zero for FormalPowerSeries<T, C>
89where
90 T: PartialEq,
91{
92 fn zero() -> Self {
93 Self::from_vec(Vec::new())
94 }
95}
96impl<T, C> One for FormalPowerSeries<T, C>
97where
98 T: PartialEq + One,
99{
100 fn one() -> Self {
101 Self::from(T::one())
102 }
103}
104
105impl<T, C> IntoIterator for FormalPowerSeries<T, C> {
106 type Item = T;
107 type IntoIter = std::vec::IntoIter<T>;
108 fn into_iter(self) -> Self::IntoIter {
109 self.data.into_iter()
110 }
111}
112impl<'a, T, C> IntoIterator for &'a FormalPowerSeries<T, C> {
113 type Item = &'a T;
114 type IntoIter = Iter<'a, T>;
115 fn into_iter(self) -> Self::IntoIter {
116 self.data.iter()
117 }
118}
119impl<'a, T, C> IntoIterator for &'a mut FormalPowerSeries<T, C> {
120 type Item = &'a mut T;
121 type IntoIter = IterMut<'a, T>;
122 fn into_iter(self) -> Self::IntoIter {
123 self.data.iter_mut()
124 }
125}
126
127impl<T, C> FromIterator<T> for FormalPowerSeries<T, C> {
128 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
129 Self::from_vec(iter.into_iter().collect())
130 }
131}
132
133impl<T, C> Index<usize> for FormalPowerSeries<T, C> {
134 type Output = T;
135 fn index(&self, index: usize) -> &Self::Output {
136 &self.data[index]
137 }
138}
139impl<T, C> IndexMut<usize> for FormalPowerSeries<T, C> {
140 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
141 &mut self.data[index]
142 }
143}
144
145impl<T, C> From<T> for FormalPowerSeries<T, C> {
146 fn from(x: T) -> Self {
147 once(x).collect()
148 }
149}
150impl<T, C> From<Vec<T>> for FormalPowerSeries<T, C> {
151 fn from(data: Vec<T>) -> Self {
152 Self::from_vec(data)
153 }
154}
155
156impl<T, C> FormalPowerSeries<T, C>
157where
158 T: FormalPowerSeriesCoefficient,
159{
160 pub fn prefix_ref(&self, deg: usize) -> Self {
161 if deg < self.length() {
162 Self::from_vec(self.data[..deg].to_vec())
163 } else {
164 self.clone()
165 }
166 }
167 pub fn prefix(mut self, deg: usize) -> Self {
168 self.data.truncate(deg);
169 self
170 }
171 pub fn even(mut self) -> Self {
172 let mut keep = false;
173 self.data.retain(|_| {
174 keep = !keep;
175 keep
176 });
177 self
178 }
179 pub fn odd(mut self) -> Self {
180 let mut keep = true;
181 self.data.retain(|_| {
182 keep = !keep;
183 keep
184 });
185 self
186 }
187 pub fn diff(mut self) -> Self {
188 let mut c = T::one();
189 for x in self.iter_mut().skip(1) {
190 *x *= &c;
191 c += T::one();
192 }
193 if self.length() > 0 {
194 self.data.remove(0);
195 }
196 self
197 }
198 pub fn integral(mut self) -> Self {
199 let n = self.length();
200 self.data.insert(0, Zero::zero());
201 let mut fact = Vec::with_capacity(n + 1);
202 let mut c = T::one();
203 fact.push(c.clone());
204 for _ in 1..n {
205 fact.push(fact.last().cloned().unwrap() * c.clone());
206 c += T::one();
207 }
208 let mut invf = T::one() / (fact.last().cloned().unwrap() * c.clone());
209 for x in self.iter_mut().skip(1).rev() {
210 *x *= invf.clone() * fact.pop().unwrap();
211 invf *= c.clone();
212 c -= T::one();
213 }
214 self
215 }
216 pub fn parity_inversion(mut self) -> Self {
217 self.iter_mut()
218 .skip(1)
219 .step_by(2)
220 .for_each(|x| *x = -x.clone());
221 self
222 }
223 pub fn eval(&self, x: T) -> T {
224 let mut base = T::one();
225 let mut res = T::zero();
226 for a in self.iter() {
227 res += base.clone() * a.clone();
228 base *= x.clone();
229 }
230 res
231 }
232}
233
234impl<T, C> FormalPowerSeries<T, C>
235where
236 T: FormalPowerSeriesCoefficient,
237 C: ConvolveSteps<T = Vec<T>>,
238{
239 pub fn inv(&self, deg: usize) -> Self {
240 debug_assert!(!self[0].is_zero());
241 let mut f = Self::from(T::one() / self[0].clone());
242 let mut i = 1;
243 while i < deg {
244 let g = self.prefix_ref((i * 2).min(deg));
245 let h = f.clone();
246 let mut g = C::transform(g.data, 2 * i);
247 let h = C::transform(h.data, 2 * i);
248 C::multiply(&mut g, &h);
249 let mut g = Self::from_vec(C::inverse_transform(g, 2 * i));
250 g >>= i;
251 let mut g = C::transform(g.data, 2 * i);
252 C::multiply(&mut g, &h);
253 let g = Self::from_vec(C::inverse_transform(g, 2 * i));
254 f.data.extend((-g).into_iter().take(i));
255 i *= 2;
256 }
257 f.truncate(deg);
258 f
259 }
260 pub fn exp(&self, deg: usize) -> Self {
261 debug_assert!(self[0].is_zero());
262 let mut f = Self::one();
263 let mut i = 1;
264 while i < deg {
265 let mut g = -f.log(i * 2);
266 g[0] += T::one();
267 for (g, x) in g.iter_mut().zip(self.iter().take(i * 2)) {
268 *g += x.clone();
269 }
270 f = (f * g).prefix(i * 2);
271 i *= 2;
272 }
273 f.prefix(deg)
274 }
275 pub fn log(&self, deg: usize) -> Self {
276 (self.inv(deg) * self.clone().diff()).integral().prefix(deg)
277 }
278 pub fn pow(&self, rhs: usize, deg: usize) -> Self {
279 if rhs == 0 {
280 return Self::from_vec(
281 once(T::one())
282 .chain(repeat_with(T::zero))
283 .take(deg)
284 .collect(),
285 );
286 }
287 if let Some(k) = self.iter().position(|x| !x.is_zero()) {
288 if k >= (deg + rhs - 1) / rhs {
289 Self::zeros(deg)
290 } else {
291 let mut x0 = self[k].clone();
292 let rev = T::one() / x0.clone();
293 let x = {
294 let mut x = T::one();
295 let mut y = rhs;
296 while y > 0 {
297 if y & 1 == 1 {
298 x *= x0.clone();
299 }
300 x0 *= x0.clone();
301 y >>= 1;
302 }
303 x
304 };
305 let mut f = (self.clone() * &rev) >> k;
306 f = (f.log(deg) * &T::from(rhs)).exp(deg) * &x;
307 f.truncate(deg - k * rhs);
308 f <<= k * rhs;
309 f
310 }
311 } else {
312 Self::zeros(deg)
313 }
314 }
315}
316
317impl<T, C> FormalPowerSeries<T, C>
318where
319 T: FormalPowerSeriesCoefficientSqrt,
320 C: ConvolveSteps<T = Vec<T>>,
321{
322 pub fn sqrt(&self, deg: usize) -> Option<Self> {
323 if self[0].is_zero() {
324 if let Some(k) = self.iter().position(|x| !x.is_zero()) {
325 if k % 2 != 0 {
326 return None;
327 } else if deg > k / 2 {
328 return Some((self >> k).sqrt(deg - k / 2)? << (k / 2));
329 }
330 }
331 } else {
332 let inv2 = T::one() / (T::one() + T::one());
333 let mut f = Self::from(self[0].sqrt_coefficient()?);
334 let mut i = 1;
335 while i < deg {
336 f = (&f + &(self.prefix_ref(i * 2) * f.inv(i * 2))).prefix(i * 2) * &inv2;
337 i *= 2;
338 }
339 f.truncate(deg);
340 return Some(f);
341 }
342 Some(Self::zeros(deg))
343 }
344}
345
346impl<T, C> FormalPowerSeries<T, C>
347where
348 T: FormalPowerSeriesCoefficient,
349 C: ConvolveSteps<T = Vec<T>>,
350{
351 pub fn count_subset_sum<F>(&self, deg: usize, mut inverse: F) -> Self
352 where
353 F: FnMut(usize) -> T,
354 {
355 let n = self.length();
356 let mut f = Self::zeros(n);
357 for i in 1..n {
358 if !self[i].is_zero() {
359 for (j, d) in (0..n).step_by(i).enumerate().skip(1) {
360 if j & 1 != 0 {
361 f[d] += self[i].clone() * &inverse(j);
362 } else {
363 f[d] -= self[i].clone() * &inverse(j);
364 }
365 }
366 }
367 }
368 f.exp(deg)
369 }
370 pub fn count_multiset_sum<F>(&self, deg: usize, mut inverse: F) -> Self
371 where
372 F: FnMut(usize) -> T,
373 {
374 let n = self.length();
375 let mut f = Self::zeros(n);
376 for i in 1..n {
377 if !self[i].is_zero() {
378 for (j, d) in (0..n).step_by(i).enumerate().skip(1) {
379 f[d] += self[i].clone() * &inverse(j);
380 }
381 }
382 }
383 f.exp(deg)
384 }
385 pub fn bostan_mori(self, rhs: Self, mut n: usize) -> T {
387 let mut p = self;
388 let mut q = rhs;
389 while n > 0 {
390 let mq = q.clone().parity_inversion();
391 let u = p * mq.clone();
392 p = if n % 2 == 0 { u.even() } else { u.odd() };
393 q = (q * mq).even();
394 n /= 2;
395 }
396 p[0].clone() / q[0].clone()
397 }
398 pub fn bostan_mori_msb(self, n: usize) -> Self {
400 let d = self.length() - 1;
401 if n == 0 {
402 return (Self::one() << (d - 1)) / self[0].clone();
403 }
404 let q = self;
405 let mq = q.clone().parity_inversion();
406 let w = (q * &mq).even().bostan_mori_msb(n / 2);
407 let mut s = Self::zeros(w.length() * 2 - (n % 2));
408 for (i, x) in w.iter().enumerate() {
409 s[i * 2 + (1 - n % 2)] = x.clone();
410 }
411 let len = 2 * d + 1;
412 let ts = C::transform(s.prefix(len).data, len);
413 mq.reversed().middle_product(&ts, len).prefix(d + 1)
414 }
415 pub fn pow_mod(self, n: usize) -> Self {
417 let d = self.length() - 1;
418 let q = self.reversed();
419 let u = q.clone().bostan_mori_msb(n);
420 let mut f = (u * q).prefix(d).reversed();
421 f.trim_tail_zeros();
422 f
423 }
424 fn middle_product(self, other: &C::F, deg: usize) -> Self {
425 let n = self.length();
426 let mut s = C::transform(self.reversed().data, deg);
427 C::multiply(&mut s, other);
428 Self::from_vec((C::inverse_transform(s, deg))[n - 1..].to_vec())
429 }
430 pub fn multipoint_evaluation(self, points: &[T]) -> Vec<T> {
431 let n = points.len();
432 if n <= 32 {
433 return points.iter().map(|p| self.eval(p.clone())).collect();
434 }
435 let mut subproduct_tree = Vec::with_capacity(n * 2);
436 subproduct_tree.resize_with(n, Zero::zero);
437 for x in points {
438 subproduct_tree.push(Self::from_vec(vec![-x.clone(), T::one()]));
439 }
440 for i in (1..n).rev() {
441 subproduct_tree[i] = &subproduct_tree[i * 2] * &subproduct_tree[i * 2 + 1];
442 }
443 let mut uptree_t = Vec::with_capacity(n * 2);
444 uptree_t.resize_with(1, Zero::zero);
445 subproduct_tree.reverse();
446 subproduct_tree.pop();
447 let m = self.length();
448 let v = subproduct_tree.pop().unwrap().reversed().resized(m);
449 let s = C::transform(self.data, m * 2);
450 uptree_t.push(v.inv(m).middle_product(&s, m * 2).resized(n).reversed());
451 for i in 1..n {
452 let subl = subproduct_tree.pop().unwrap();
453 let subr = subproduct_tree.pop().unwrap();
454 let (dl, dr) = (subl.length(), subr.length());
455 let len = dl.max(dr) + uptree_t[i].length();
456 let s = C::transform(uptree_t[i].data.to_vec(), len);
457 uptree_t.push(subr.middle_product(&s, len).prefix(dl));
458 uptree_t.push(subl.middle_product(&s, len).prefix(dr));
459 }
460 uptree_t[n..]
461 .iter()
462 .map(|u| u.data.first().cloned().unwrap_or_else(Zero::zero))
463 .collect()
464 }
465 pub fn product_all<I>(iter: I, deg: usize) -> Self
466 where
467 I: IntoIterator<Item = Self>,
468 {
469 let mut heap: BinaryHeap<_> = iter
470 .into_iter()
471 .map(|f| PartialIgnoredOrd(Reverse(f.length()), f))
472 .collect();
473 while let Some(PartialIgnoredOrd(_, x)) = heap.pop() {
474 if let Some(PartialIgnoredOrd(_, y)) = heap.pop() {
475 let z = (x * y).prefix(deg);
476 heap.push(PartialIgnoredOrd(Reverse(z.length()), z));
477 } else {
478 return x;
479 }
480 }
481 Self::one()
482 }
483 pub fn sum_all_rational<I>(iter: I, deg: usize) -> (Self, Self)
484 where
485 I: IntoIterator<Item = (Self, Self)>,
486 {
487 let mut heap: BinaryHeap<_> = iter
488 .into_iter()
489 .map(|(f, g)| PartialIgnoredOrd(Reverse(f.length().max(g.length())), (f, g)))
490 .collect();
491 while let Some(PartialIgnoredOrd(_, (xa, xb))) = heap.pop() {
492 if let Some(PartialIgnoredOrd(_, (ya, yb))) = heap.pop() {
493 let zb = (&xb * &yb).prefix(deg);
494 let za = (xa * yb + ya * xb).prefix(deg);
495 heap.push(PartialIgnoredOrd(
496 Reverse(za.length().max(zb.length())),
497 (za, zb),
498 ));
499 } else {
500 return (xa, xb);
501 }
502 }
503 (Self::zero(), Self::one())
504 }
505 pub fn kth_term_of_linearly_recurrence(self, a: Vec<T>, k: usize) -> T {
506 if let Some(x) = a.get(k) {
507 return x.clone();
508 }
509 let p = (Self::from_vec(a).prefix(self.length() - 1) * &self).prefix(self.length() - 1);
510 p.bostan_mori(self, k)
511 }
512 pub fn kth_term(a: Vec<T>, k: usize) -> T {
513 if let Some(x) = a.get(k) {
514 return x.clone();
515 }
516 Self::from_vec(berlekamp_massey(&a)).kth_term_of_linearly_recurrence(a, k)
517 }
518 pub fn linear_sum_of_exp<I, F>(iter: I, deg: usize, mut inv_fact: F) -> Self
520 where
521 I: IntoIterator<Item = (T, T)>,
522 F: FnMut(usize) -> T,
523 {
524 let (p, q) = Self::sum_all_rational(
525 iter.into_iter()
526 .map(|(a, b)| (Self::from_vec(vec![a]), Self::from_vec(vec![T::one(), -b]))),
527 deg,
528 );
529 let mut f = (p * q.inv(deg)).prefix(deg);
530 for i in 0..f.length() {
531 f[i] *= inv_fact(i);
532 }
533 f
534 }
535}
536
537impl<M, C> FormalPowerSeries<MInt<M>, C>
538where
539 M: MIntConvert<usize>,
540 C: ConvolveSteps<T = Vec<MInt<M>>>,
541{
542 pub fn taylor_shift(mut self, a: MInt<M>, f: &MemorizedFactorial<M>) -> Self {
544 let n = self.length();
545 for i in 0..n {
546 self.data[i] *= f.fact[i];
547 }
548 self.data.reverse();
549 let mut b = a;
550 let mut g = Self::from_vec(f.inv_fact[..n].to_vec());
551 for i in 1..n {
552 g[i] *= b;
553 b *= a;
554 }
555 self *= g;
556 self.truncate(n);
557 self.data.reverse();
558 for i in 0..n {
559 self.data[i] *= f.inv_fact[i];
560 }
561 self
562 }
563}
564
565#[cfg(test)]
566mod tests {
567 use super::*;
568 use crate::{
569 rand,
570 tools::{RandomSpec, Xorshift},
571 };
572
573 struct D;
574 impl RandomSpec<MInt998244353> for D {
575 fn rand(&self, rng: &mut Xorshift) -> MInt998244353 {
576 MInt998244353::new_unchecked(rng.random(..MInt998244353::get_mod()))
577 }
578 }
579
580 #[test]
581 fn test_bostan_mori_msb() {
582 let mut rng = Xorshift::default();
583 for _ in 0..100 {
584 rand!(rng, n: 2..20, t: 0usize..=1, k: 0..[10, 1_000_000_000][t]);
585 let f = Fps998244353::from_vec((0..n - 1).map(|_| rng.random(D)).collect());
586 let g = Fps998244353::from_vec((0..n).map(|_| rng.random(D)).collect());
587 let expected = f.clone().bostan_mori(g.clone(), k);
588 let result = (f * g.bostan_mori_msb(k))[n - 2];
589 assert_eq!(result, expected);
590 }
591 }
592
593 #[test]
594 fn test_pow_mod() {
595 let mut rng = Xorshift::default();
596 for _ in 0..100 {
597 rand!(rng, n: 2..20, t: 0usize..=1, k: 0..[10, 1_000_000_000][t]);
598 let f = Fps998244353::from_vec((0..n).map(|_| rng.random(D)).collect());
599 let mut expected = Fps998244353::one();
600 {
601 let mut p = Fps998244353::one() << 1;
602 let mut k = k;
603 while k > 0 {
604 if k & 1 == 1 {
605 expected = (expected * &p) % &f;
606 }
607 p = (&p * &p) % &f;
608 k >>= 1;
609 }
610 }
611
612 let result = f.pow_mod(k);
613 assert_eq!(result, expected);
614 }
615 }
616}