Skip to main content

competitive/math/
pow_prec.rs

1use super::{MInt, MIntConvert, One, prime_factors};
2
3#[derive(Debug, Clone)]
4pub struct PowPrec<M>
5where
6    M: MIntConvert<usize>,
7{
8    period: usize,
9    sqn: usize,
10    p0: Vec<MInt<M>>,
11    p1: Vec<MInt<M>>,
12}
13
14impl<M> PowPrec<M>
15where
16    M: MIntConvert<usize>,
17{
18    pub fn new(a: MInt<M>) -> Self {
19        let mut maxe = 0;
20        let period: u64 = prime_factors(M::mod_into() as u64)
21            .into_iter()
22            .map(|(p, e)| {
23                maxe = maxe.max(e);
24                p.pow(e - 1) * (p - 1)
25            })
26            .product();
27        let period = period as usize;
28        let sqn = ((period as f64).sqrt() as usize).max(maxe as usize) + 1;
29        let mut p0 = Vec::with_capacity(sqn);
30        let mut p1 = Vec::with_capacity(sqn);
31        let mut acc = MInt::<M>::one();
32        for _ in 0..sqn {
33            p0.push(acc);
34            acc *= a;
35        }
36        let b = acc;
37        acc = MInt::<M>::one();
38        for _ in 0..sqn {
39            p1.push(acc);
40            acc *= b;
41        }
42        Self {
43            period,
44            sqn,
45            p0,
46            p1,
47        }
48    }
49
50    pub fn pow(&self, n: usize) -> MInt<M> {
51        if n < self.sqn {
52            return self.p0[n];
53        }
54        let n = (n + 1 - self.sqn) % self.period;
55        let (p, q) = (n / self.sqn, n % self.sqn);
56        self.p1[p] * self.p0[q] * self.p0[self.sqn - 1]
57    }
58
59    /// gcd(a, mod) must be 1
60    pub fn powi(&self, n: isize) -> MInt<M> {
61        let n = n.rem_euclid(self.period as isize) as usize;
62        let (p, q) = (n / self.sqn, n % self.sqn);
63        self.p1[p] * self.p0[q]
64    }
65
66    /// gcd(a, mod) must be 1
67    pub fn inv(&self) -> MInt<M> {
68        self.powi(-1)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        num::{
            Unsigned,
            mint_basic::{DynMIntU32, MInt998244353},
        },
        tools::Xorshift,
    };

    /// Exhaustive check: every base of every modulus in `2..=100`, compared
    /// against direct `pow` / `inv` on the modular integer.
    #[test]
    fn test_pow_prec_small() {
        for modulus in 2..=100 {
            DynMIntU32::set_mod(modulus);
            for raw in 0..modulus {
                let base = DynMIntU32::new(raw);
                let table = PowPrec::new(base);

                for exp in 0..=modulus * 2 {
                    assert_eq!(table.pow(exp as _), base.pow(exp as _));
                }

                // Signed exponents and inversion are only meaningful for units.
                if modulus.gcd(base.inner()) != 1 {
                    continue;
                }
                let limit = modulus as isize * 2;
                for exp in -limit..=limit {
                    let expected = if exp < 0 {
                        base.inv().pow((-exp) as _)
                    } else {
                        base.pow(exp as _)
                    };
                    assert_eq!(table.powi(exp), expected);
                }
                assert_eq!(table.inv(), base.inv());
            }
        }
    }

    /// Randomized check modulo 998244353 with exponents below 2e9, covering
    /// `pow`, non-negative and negative `powi`, and `inv`.
    #[test]
    fn test_pow_prec_large() {
        let mut rng = Xorshift::default();
        for _ in 0..10 {
            let raw = rng.random(1..MInt998244353::get_mod());
            let base = MInt998244353::new(raw);
            let table = PowPrec::new(base);
            for _ in 0..100 {
                let exp = rng.random(0..2_000_000_000);
                assert_eq!(table.pow(exp as _), base.pow(exp as _));
                assert_eq!(table.powi(exp as _), base.pow(exp as _));
                assert_eq!(table.powi(-(exp as isize)), base.inv().pow(exp as _));
            }
            assert_eq!(table.inv(), base.inv());
        }
    }
}