// competitive/math/formal_power_series/formal_power_series_impls.rs

1use super::*;
2use std::{
3    cmp::Reverse,
4    collections::BinaryHeap,
5    iter::repeat_with,
6    iter::{FromIterator, once},
7    marker::PhantomData,
8    ops::{Index, IndexMut},
9    slice::{Iter, IterMut},
10};
11
12impl<T, C> FormalPowerSeries<T, C> {
13    pub fn from_vec(data: Vec<T>) -> Self {
14        Self {
15            data,
16            _marker: PhantomData,
17        }
18    }
19    pub fn length(&self) -> usize {
20        self.data.len()
21    }
22    pub fn truncate(&mut self, deg: usize) {
23        self.data.truncate(deg)
24    }
25    pub fn iter(&self) -> Iter<'_, T> {
26        self.data.iter()
27    }
28    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
29        self.data.iter_mut()
30    }
31}
32
33impl<T, C> Clone for FormalPowerSeries<T, C>
34where
35    T: Clone,
36{
37    fn clone(&self) -> Self {
38        Self::from_vec(self.data.clone())
39    }
40}
41impl<T, C> PartialEq for FormalPowerSeries<T, C>
42where
43    T: PartialEq,
44{
45    fn eq(&self, other: &Self) -> bool {
46        self.data.eq(&other.data)
47    }
48}
49impl<T, C> Eq for FormalPowerSeries<T, C> where T: PartialEq {}
50
51impl<T, C> FormalPowerSeries<T, C>
52where
53    T: Zero,
54{
55    pub fn zeros(deg: usize) -> Self {
56        repeat_with(T::zero).take(deg).collect()
57    }
58    pub fn resize(&mut self, deg: usize) {
59        self.data.resize_with(deg, Zero::zero)
60    }
61    pub fn resized(mut self, deg: usize) -> Self {
62        self.resize(deg);
63        self
64    }
65    pub fn reversed(mut self) -> Self {
66        self.data.reverse();
67        self
68    }
69}
70
71impl<T, C> FormalPowerSeries<T, C>
72where
73    T: Zero + PartialEq,
74{
75    pub fn trim_tail_zeros(&mut self) {
76        let mut len = self.length();
77        while len > 0 {
78            if self.data[len - 1].is_zero() {
79                len -= 1;
80            } else {
81                break;
82            }
83        }
84        self.truncate(len);
85    }
86}
87
88impl<T, C> Zero for FormalPowerSeries<T, C>
89where
90    T: PartialEq,
91{
92    fn zero() -> Self {
93        Self::from_vec(Vec::new())
94    }
95}
96impl<T, C> One for FormalPowerSeries<T, C>
97where
98    T: PartialEq + One,
99{
100    fn one() -> Self {
101        Self::from(T::one())
102    }
103}
104
105impl<T, C> IntoIterator for FormalPowerSeries<T, C> {
106    type Item = T;
107    type IntoIter = std::vec::IntoIter<T>;
108    fn into_iter(self) -> Self::IntoIter {
109        self.data.into_iter()
110    }
111}
112impl<'a, T, C> IntoIterator for &'a FormalPowerSeries<T, C> {
113    type Item = &'a T;
114    type IntoIter = Iter<'a, T>;
115    fn into_iter(self) -> Self::IntoIter {
116        self.data.iter()
117    }
118}
119impl<'a, T, C> IntoIterator for &'a mut FormalPowerSeries<T, C> {
120    type Item = &'a mut T;
121    type IntoIter = IterMut<'a, T>;
122    fn into_iter(self) -> Self::IntoIter {
123        self.data.iter_mut()
124    }
125}
126
127impl<T, C> FromIterator<T> for FormalPowerSeries<T, C> {
128    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
129        Self::from_vec(iter.into_iter().collect())
130    }
131}
132
133impl<T, C> Index<usize> for FormalPowerSeries<T, C> {
134    type Output = T;
135    fn index(&self, index: usize) -> &Self::Output {
136        &self.data[index]
137    }
138}
139impl<T, C> IndexMut<usize> for FormalPowerSeries<T, C> {
140    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
141        &mut self.data[index]
142    }
143}
144
145impl<T, C> From<T> for FormalPowerSeries<T, C> {
146    fn from(x: T) -> Self {
147        once(x).collect()
148    }
149}
150impl<T, C> From<Vec<T>> for FormalPowerSeries<T, C> {
151    fn from(data: Vec<T>) -> Self {
152        Self::from_vec(data)
153    }
154}
155
156impl<T, C> FormalPowerSeries<T, C>
157where
158    T: FormalPowerSeriesCoefficient,
159{
160    pub fn prefix_ref(&self, deg: usize) -> Self {
161        if deg < self.length() {
162            Self::from_vec(self.data[..deg].to_vec())
163        } else {
164            self.clone()
165        }
166    }
167    pub fn prefix(mut self, deg: usize) -> Self {
168        self.data.truncate(deg);
169        self
170    }
171    pub fn even(mut self) -> Self {
172        let mut keep = false;
173        self.data.retain(|_| {
174            keep = !keep;
175            keep
176        });
177        self
178    }
179    pub fn odd(mut self) -> Self {
180        let mut keep = true;
181        self.data.retain(|_| {
182            keep = !keep;
183            keep
184        });
185        self
186    }
187    pub fn diff(mut self) -> Self {
188        let mut c = T::one();
189        for x in self.iter_mut().skip(1) {
190            *x *= &c;
191            c += T::one();
192        }
193        if self.length() > 0 {
194            self.data.remove(0);
195        }
196        self
197    }
198    pub fn integral(mut self) -> Self {
199        let n = self.length();
200        self.data.insert(0, Zero::zero());
201        let mut fact = Vec::with_capacity(n + 1);
202        let mut c = T::one();
203        fact.push(c.clone());
204        for _ in 1..n {
205            fact.push(fact.last().cloned().unwrap() * c.clone());
206            c += T::one();
207        }
208        let mut invf = T::one() / (fact.last().cloned().unwrap() * c.clone());
209        for x in self.iter_mut().skip(1).rev() {
210            *x *= invf.clone() * fact.pop().unwrap();
211            invf *= c.clone();
212            c -= T::one();
213        }
214        self
215    }
216    pub fn parity_inversion(mut self) -> Self {
217        self.iter_mut()
218            .skip(1)
219            .step_by(2)
220            .for_each(|x| *x = -x.clone());
221        self
222    }
223    pub fn eval(&self, x: T) -> T {
224        let mut base = T::one();
225        let mut res = T::zero();
226        for a in self.iter() {
227            res += base.clone() * a.clone();
228            base *= x.clone();
229        }
230        res
231    }
232}
233
234impl<T, C> FormalPowerSeries<T, C>
235where
236    T: FormalPowerSeriesCoefficient,
237    C: ConvolveSteps<T = Vec<T>>,
238{
239    pub fn inv(&self, deg: usize) -> Self {
240        debug_assert!(!self[0].is_zero());
241        let mut f = Self::from(T::one() / self[0].clone());
242        let mut i = 1;
243        while i < deg {
244            let g = self.prefix_ref((i * 2).min(deg));
245            let h = f.clone();
246            let mut g = C::transform(g.data, 2 * i);
247            let h = C::transform(h.data, 2 * i);
248            C::multiply(&mut g, &h);
249            let mut g = Self::from_vec(C::inverse_transform(g, 2 * i));
250            g >>= i;
251            let mut g = C::transform(g.data, 2 * i);
252            C::multiply(&mut g, &h);
253            let g = Self::from_vec(C::inverse_transform(g, 2 * i));
254            f.data.extend((-g).into_iter().take(i));
255            i *= 2;
256        }
257        f.truncate(deg);
258        f
259    }
260    pub fn exp(&self, deg: usize) -> Self {
261        debug_assert!(self[0].is_zero());
262        let mut f = Self::one();
263        let mut i = 1;
264        while i < deg {
265            let mut g = -f.log(i * 2);
266            g[0] += T::one();
267            for (g, x) in g.iter_mut().zip(self.iter().take(i * 2)) {
268                *g += x.clone();
269            }
270            f = (f * g).prefix(i * 2);
271            i *= 2;
272        }
273        f.prefix(deg)
274    }
275    pub fn log(&self, deg: usize) -> Self {
276        (self.inv(deg) * self.clone().diff()).integral().prefix(deg)
277    }
278    pub fn pow(&self, rhs: usize, deg: usize) -> Self {
279        if rhs == 0 {
280            return Self::from_vec(
281                once(T::one())
282                    .chain(repeat_with(T::zero))
283                    .take(deg)
284                    .collect(),
285            );
286        }
287        if let Some(k) = self.iter().position(|x| !x.is_zero()) {
288            if k >= (deg + rhs - 1) / rhs {
289                Self::zeros(deg)
290            } else {
291                let mut x0 = self[k].clone();
292                let rev = T::one() / x0.clone();
293                let x = {
294                    let mut x = T::one();
295                    let mut y = rhs;
296                    while y > 0 {
297                        if y & 1 == 1 {
298                            x *= x0.clone();
299                        }
300                        x0 *= x0.clone();
301                        y >>= 1;
302                    }
303                    x
304                };
305                let mut f = (self.clone() * &rev) >> k;
306                f = (f.log(deg) * &T::from(rhs)).exp(deg) * &x;
307                f.truncate(deg - k * rhs);
308                f <<= k * rhs;
309                f
310            }
311        } else {
312            Self::zeros(deg)
313        }
314    }
315}
316
317impl<T, C> FormalPowerSeries<T, C>
318where
319    T: FormalPowerSeriesCoefficientSqrt,
320    C: ConvolveSteps<T = Vec<T>>,
321{
322    pub fn sqrt(&self, deg: usize) -> Option<Self> {
323        if self[0].is_zero() {
324            if let Some(k) = self.iter().position(|x| !x.is_zero()) {
325                if k % 2 != 0 {
326                    return None;
327                } else if deg > k / 2 {
328                    return Some((self >> k).sqrt(deg - k / 2)? << (k / 2));
329                }
330            }
331        } else {
332            let inv2 = T::one() / (T::one() + T::one());
333            let mut f = Self::from(self[0].sqrt_coefficient()?);
334            let mut i = 1;
335            while i < deg {
336                f = (&f + &(self.prefix_ref(i * 2) * f.inv(i * 2))).prefix(i * 2) * &inv2;
337                i *= 2;
338            }
339            f.truncate(deg);
340            return Some(f);
341        }
342        Some(Self::zeros(deg))
343    }
344}
345
346impl<T, C> FormalPowerSeries<T, C>
347where
348    T: FormalPowerSeriesCoefficient,
349    C: ConvolveSteps<T = Vec<T>>,
350{
351    pub fn count_subset_sum<F>(&self, deg: usize, mut inverse: F) -> Self
352    where
353        F: FnMut(usize) -> T,
354    {
355        let n = self.length();
356        let mut f = Self::zeros(n);
357        for i in 1..n {
358            if !self[i].is_zero() {
359                for (j, d) in (0..n).step_by(i).enumerate().skip(1) {
360                    if j & 1 != 0 {
361                        f[d] += self[i].clone() * &inverse(j);
362                    } else {
363                        f[d] -= self[i].clone() * &inverse(j);
364                    }
365                }
366            }
367        }
368        f.exp(deg)
369    }
370    pub fn count_multiset_sum<F>(&self, deg: usize, mut inverse: F) -> Self
371    where
372        F: FnMut(usize) -> T,
373    {
374        let n = self.length();
375        let mut f = Self::zeros(n);
376        for i in 1..n {
377            if !self[i].is_zero() {
378                for (j, d) in (0..n).step_by(i).enumerate().skip(1) {
379                    f[d] += self[i].clone() * &inverse(j);
380                }
381            }
382        }
383        f.exp(deg)
384    }
385    /// [x^n] P(x) / Q(x)
386    pub fn bostan_mori(self, rhs: Self, mut n: usize) -> T {
387        let mut p = self;
388        let mut q = rhs;
389        while n > 0 {
390            let mq = q.clone().parity_inversion();
391            let u = p * mq.clone();
392            p = if n % 2 == 0 { u.even() } else { u.odd() };
393            q = (q * mq).even();
394            n /= 2;
395        }
396        p[0].clone() / q[0].clone()
397    }
398    /// return F(x) where [x^n] P(x) / Q(x) = [x^d-1] P(x) F(x)
399    pub fn bostan_mori_msb(self, n: usize) -> Self {
400        let d = self.length() - 1;
401        if n == 0 {
402            return (Self::one() << (d - 1)) / self[0].clone();
403        }
404        let q = self;
405        let mq = q.clone().parity_inversion();
406        let w = (q * &mq).even().bostan_mori_msb(n / 2);
407        let mut s = Self::zeros(w.length() * 2 - (n % 2));
408        for (i, x) in w.iter().enumerate() {
409            s[i * 2 + (1 - n % 2)] = x.clone();
410        }
411        let len = 2 * d + 1;
412        let ts = C::transform(s.prefix(len).data, len);
413        mq.reversed().middle_product(&ts, len).prefix(d + 1)
414    }
415    /// x^n mod self
416    pub fn pow_mod(self, n: usize) -> Self {
417        let d = self.length() - 1;
418        let q = self.reversed();
419        let u = q.clone().bostan_mori_msb(n);
420        let mut f = (u * q).prefix(d).reversed();
421        f.trim_tail_zeros();
422        f
423    }
424    fn middle_product(self, other: &C::F, deg: usize) -> Self {
425        let n = self.length();
426        let mut s = C::transform(self.reversed().data, deg);
427        C::multiply(&mut s, other);
428        Self::from_vec((C::inverse_transform(s, deg))[n - 1..].to_vec())
429    }
430    pub fn multipoint_evaluation(self, points: &[T]) -> Vec<T> {
431        let n = points.len();
432        if n <= 32 {
433            return points.iter().map(|p| self.eval(p.clone())).collect();
434        }
435        let mut subproduct_tree = Vec::with_capacity(n * 2);
436        subproduct_tree.resize_with(n, Zero::zero);
437        for x in points {
438            subproduct_tree.push(Self::from_vec(vec![-x.clone(), T::one()]));
439        }
440        for i in (1..n).rev() {
441            subproduct_tree[i] = &subproduct_tree[i * 2] * &subproduct_tree[i * 2 + 1];
442        }
443        let mut uptree_t = Vec::with_capacity(n * 2);
444        uptree_t.resize_with(1, Zero::zero);
445        subproduct_tree.reverse();
446        subproduct_tree.pop();
447        let m = self.length();
448        let v = subproduct_tree.pop().unwrap().reversed().resized(m);
449        let s = C::transform(self.data, m * 2);
450        uptree_t.push(v.inv(m).middle_product(&s, m * 2).resized(n).reversed());
451        for i in 1..n {
452            let subl = subproduct_tree.pop().unwrap();
453            let subr = subproduct_tree.pop().unwrap();
454            let (dl, dr) = (subl.length(), subr.length());
455            let len = dl.max(dr) + uptree_t[i].length();
456            let s = C::transform(uptree_t[i].data.to_vec(), len);
457            uptree_t.push(subr.middle_product(&s, len).prefix(dl));
458            uptree_t.push(subl.middle_product(&s, len).prefix(dr));
459        }
460        uptree_t[n..]
461            .iter()
462            .map(|u| u.data.first().cloned().unwrap_or_else(Zero::zero))
463            .collect()
464    }
465    pub fn product_all<I>(iter: I, deg: usize) -> Self
466    where
467        I: IntoIterator<Item = Self>,
468    {
469        let mut heap: BinaryHeap<_> = iter
470            .into_iter()
471            .map(|f| PartialIgnoredOrd(Reverse(f.length()), f))
472            .collect();
473        while let Some(PartialIgnoredOrd(_, x)) = heap.pop() {
474            if let Some(PartialIgnoredOrd(_, y)) = heap.pop() {
475                let z = (x * y).prefix(deg);
476                heap.push(PartialIgnoredOrd(Reverse(z.length()), z));
477            } else {
478                return x;
479            }
480        }
481        Self::one()
482    }
483    pub fn sum_all_rational<I>(iter: I, deg: usize) -> (Self, Self)
484    where
485        I: IntoIterator<Item = (Self, Self)>,
486    {
487        let mut heap: BinaryHeap<_> = iter
488            .into_iter()
489            .map(|(f, g)| PartialIgnoredOrd(Reverse(f.length().max(g.length())), (f, g)))
490            .collect();
491        while let Some(PartialIgnoredOrd(_, (xa, xb))) = heap.pop() {
492            if let Some(PartialIgnoredOrd(_, (ya, yb))) = heap.pop() {
493                let zb = (&xb * &yb).prefix(deg);
494                let za = (xa * yb + ya * xb).prefix(deg);
495                heap.push(PartialIgnoredOrd(
496                    Reverse(za.length().max(zb.length())),
497                    (za, zb),
498                ));
499            } else {
500                return (xa, xb);
501            }
502        }
503        (Self::zero(), Self::one())
504    }
505    pub fn kth_term_of_linearly_recurrence(self, a: Vec<T>, k: usize) -> T {
506        if let Some(x) = a.get(k) {
507            return x.clone();
508        }
509        let p = (Self::from_vec(a).prefix(self.length() - 1) * &self).prefix(self.length() - 1);
510        p.bostan_mori(self, k)
511    }
512    pub fn kth_term(a: Vec<T>, k: usize) -> T {
513        if let Some(x) = a.get(k) {
514            return x.clone();
515        }
516        Self::from_vec(berlekamp_massey(&a)).kth_term_of_linearly_recurrence(a, k)
517    }
518    /// sum_i a_i exp(b_i x)
519    pub fn linear_sum_of_exp<I, F>(iter: I, deg: usize, mut inv_fact: F) -> Self
520    where
521        I: IntoIterator<Item = (T, T)>,
522        F: FnMut(usize) -> T,
523    {
524        let (p, q) = Self::sum_all_rational(
525            iter.into_iter()
526                .map(|(a, b)| (Self::from_vec(vec![a]), Self::from_vec(vec![T::one(), -b]))),
527            deg,
528        );
529        let mut f = (p * q.inv(deg)).prefix(deg);
530        for i in 0..f.length() {
531            f[i] *= inv_fact(i);
532        }
533        f
534    }
535}
536
537impl<M, C> FormalPowerSeries<MInt<M>, C>
538where
539    M: MIntConvert<usize>,
540    C: ConvolveSteps<T = Vec<MInt<M>>>,
541{
542    /// f(x) <- f(x + a)
543    pub fn taylor_shift(mut self, a: MInt<M>, f: &MemorizedFactorial<M>) -> Self {
544        let n = self.length();
545        for i in 0..n {
546            self.data[i] *= f.fact[i];
547        }
548        self.data.reverse();
549        let mut b = a;
550        let mut g = Self::from_vec(f.inv_fact[..n].to_vec());
551        for i in 1..n {
552            g[i] *= b;
553            b *= a;
554        }
555        self *= g;
556        self.truncate(n);
557        self.data.reverse();
558        for i in 0..n {
559            self.data[i] *= f.inv_fact[i];
560        }
561        self
562    }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        rand,
        tools::{RandomSpec, Xorshift},
    };

    /// Uniformly random residue modulo 998244353.
    struct D;
    impl RandomSpec<MInt998244353> for D {
        fn rand(&self, rng: &mut Xorshift) -> MInt998244353 {
            MInt998244353::new_unchecked(rng.random(..MInt998244353::get_mod()))
        }
    }

    /// Series with `len` random coefficients.
    fn random_fps(rng: &mut Xorshift, len: usize) -> Fps998244353 {
        Fps998244353::from_vec((0..len).map(|_| rng.random(D)).collect())
    }

    /// Reference `x^k mod f` by square-and-multiply with polynomial remainders.
    fn naive_pow_mod(f: &Fps998244353, mut k: usize) -> Fps998244353 {
        let mut acc = Fps998244353::one();
        let mut base = Fps998244353::one() << 1;
        while k > 0 {
            if k & 1 == 1 {
                acc = (acc * &base) % f;
            }
            base = (&base * &base) % f;
            k >>= 1;
        }
        acc
    }

    #[test]
    fn test_bostan_mori_msb() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 2..20, t: 0usize..=1, k: 0..[10, 1_000_000_000][t]);
            let numer = random_fps(&mut rng, n - 1);
            let denom = random_fps(&mut rng, n);
            let expected = numer.clone().bostan_mori(denom.clone(), k);
            // [x^n] P / Q == [x^(d-1)] P * F with d = deg Q = n - 1.
            let result = (numer * denom.bostan_mori_msb(k))[n - 2];
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_pow_mod() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 2..20, t: 0usize..=1, k: 0..[10, 1_000_000_000][t]);
            let modulus = random_fps(&mut rng, n);
            let expected = naive_pow_mod(&modulus, k);
            let result = modulus.pow_mod(k);
            assert_eq!(result, expected);
        }
    }
}