competitive/math/
lagrange_interpolation.rs1use super::{MInt, MIntBase, MIntConvert, MemorizedFactorial, One, Zero};
2
3pub fn lagrange_interpolation<M>(x: &[MInt<M>], y: &[MInt<M>], t: MInt<M>) -> MInt<M>
4where
5 M: MIntBase,
6{
7 let n = x.len();
8 debug_assert!(n == y.len());
9 x.iter().position(|&x| x == t).map_or_else(
10 || {
11 (0..n)
12 .map(|i| {
13 y[i] * (0..n)
14 .filter(|&j| j != i)
15 .map(|j| (t - x[j]) / (x[i] - x[j]))
16 .product::<MInt<M>>()
17 })
18 .sum()
19 },
20 |i| y[i],
21 )
22}
23
24impl<M> MemorizedFactorial<M>
25where
26 M: MIntConvert<usize>,
27{
28 pub fn lagrange_interpolation<F>(&self, n: usize, f: F, t: MInt<M>) -> MInt<M>
30 where
31 F: Fn(MInt<M>) -> MInt<M>,
32 {
33 debug_assert!(0 < n && n < M::mod_into() + 1);
34 if usize::from(t) <= n {
35 return f(t);
36 }
37 let mut left = vec![MInt::one(); n + 1];
38 for i in 0..n {
39 left[i + 1] = left[i] * (t - MInt::from(i));
40 }
41 let (mut res, mut right) = (MInt::zero(), MInt::one());
42 for i in (0..=n).rev() {
43 res += f(MInt::from(i)) * left[i] * right * self.inv_fact[i] * self.inv_fact[n - i];
44 right *= MInt::from(i) - t;
45 }
46 res
47 }
48}
49
50pub fn lagrange_interpolation_polynomial<M>(x: &[MInt<M>], y: &[MInt<M>]) -> Vec<MInt<M>>
51where
52 M: MIntBase,
53{
54 let n = x.len() - 1;
55 let mut dp = vec![MInt::zero(); n + 2];
56 let mut ndp = vec![MInt::zero(); n + 2];
57 dp[0] = -x[0];
58 dp[1] = MInt::one();
59 for x in x.iter().skip(1) {
60 for j in 0..=n + 1 {
61 ndp[j] = -dp[j] * x + if j >= 1 { dp[j - 1] } else { MInt::zero() };
62 }
63 std::mem::swap(&mut dp, &mut ndp);
64 }
65 let mut res = vec![MInt::zero(); n + 1];
66 for i in 0..=n {
67 let t = y[i]
68 / (0..=n)
69 .map(|j| if i != j { x[i] - x[j] } else { MInt::one() })
70 .product::<MInt<M>>();
71 if t.is_zero() {
72 continue;
73 } else if x[i].is_zero() {
74 for j in 0..=n {
75 res[j] += dp[j + 1] * t;
76 }
77 } else {
78 let xinv = x[i].inv();
79 let mut pre = MInt::zero();
80 for j in 0..=n {
81 let d = -(dp[j] - pre) * xinv;
82 res[j] += d * t;
83 pre = d;
84 }
85 }
86 }
87 res
88}