competitive/math/
lagrange_interpolation.rs

1use super::{MInt, MIntBase, MIntConvert, MemorizedFactorial, One, Zero};
2
3pub fn lagrange_interpolation<M>(x: &[MInt<M>], y: &[MInt<M>], t: MInt<M>) -> MInt<M>
4where
5    M: MIntBase,
6{
7    let n = x.len();
8    debug_assert!(n == y.len());
9    x.iter().position(|&x| x == t).map_or_else(
10        || {
11            (0..n)
12                .map(|i| {
13                    y[i] * (0..n)
14                        .filter(|&j| j != i)
15                        .map(|j| (t - x[j]) / (x[i] - x[j]))
16                        .product::<MInt<M>>()
17                })
18                .sum()
19        },
20        |i| y[i],
21    )
22}
23
24impl<M> MemorizedFactorial<M>
25where
26    M: MIntConvert<usize>,
27{
28    /// Lagrange interpolation with (i, f(i)) (0 <= i <= n)
29    pub fn lagrange_interpolation<F>(&self, n: usize, f: F, t: MInt<M>) -> MInt<M>
30    where
31        F: Fn(MInt<M>) -> MInt<M>,
32    {
33        debug_assert!(0 < n && n < M::mod_into() + 1);
34        if usize::from(t) <= n {
35            return f(t);
36        }
37        let mut left = vec![MInt::one(); n + 1];
38        for i in 0..n {
39            left[i + 1] = left[i] * (t - MInt::from(i));
40        }
41        let (mut res, mut right) = (MInt::zero(), MInt::one());
42        for i in (0..=n).rev() {
43            res += f(MInt::from(i)) * left[i] * right * self.inv_fact[i] * self.inv_fact[n - i];
44            right *= MInt::from(i) - t;
45        }
46        res
47    }
48}
49
50pub fn lagrange_interpolation_polynomial<M>(x: &[MInt<M>], y: &[MInt<M>]) -> Vec<MInt<M>>
51where
52    M: MIntBase,
53{
54    let n = x.len() - 1;
55    let mut dp = vec![MInt::zero(); n + 2];
56    let mut ndp = vec![MInt::zero(); n + 2];
57    dp[0] = -x[0];
58    dp[1] = MInt::one();
59    for x in x.iter().skip(1) {
60        for j in 0..=n + 1 {
61            ndp[j] = -dp[j] * x + if j >= 1 { dp[j - 1] } else { MInt::zero() };
62        }
63        std::mem::swap(&mut dp, &mut ndp);
64    }
65    let mut res = vec![MInt::zero(); n + 1];
66    for i in 0..=n {
67        let t = y[i]
68            / (0..=n)
69                .map(|j| if i != j { x[i] - x[j] } else { MInt::one() })
70                .product::<MInt<M>>();
71        if t.is_zero() {
72            continue;
73        } else if x[i].is_zero() {
74            for j in 0..=n {
75                res[j] += dp[j + 1] * t;
76            }
77        } else {
78            let xinv = x[i].inv();
79            let mut pre = MInt::zero();
80            for j in 0..=n {
81                let d = -(dp[j] - pre) * xinv;
82                res[j] += d * t;
83                pre = d;
84            }
85        }
86    }
87    res
88}