competitive/math/
prime_table.rs

1use std::iter::once;
2
/// Sieve-based lookup table over `1..=max_n` for primality tests and fast factorization.
///
/// Only odd numbers are stored: slot `k` describes the number `2k + 1`. A slot holds `1` when
/// that number is prime, its smallest prime factor when it is an odd composite, and `0` for
/// the number 1.
#[derive(Clone, Debug)]
pub struct PrimeTable {
    /// Odd-only sieve; see the type-level docs for the slot encoding.
    table: Vec<u32>,
    /// Upper bound the table was built for. Kept because the odd-only layout alone cannot
    /// distinguish `max_n == 1` from `max_n == 2` (both need a single slot).
    max_n: u32,
}

impl PrimeTable {
    /// Builds the table for all numbers in `1..=max_n`.
    pub fn new(max_n: u32) -> Self {
        // Always keep slot 0 (the number 1) so `max_n == 0` does not index an empty vector.
        let len = ((max_n as usize + 1) / 2).max(1);
        let mut table = vec![1; len];
        table[0] = 0;
        let mut i: u32 = 3;
        // `i <= max_n / i` is `i * i <= max_n` without the risk of overflowing `u32`.
        while i <= max_n / i {
            if table[i as usize >> 1] == 1 {
                // Smaller multiples were already marked by smaller primes; the step of
                // `2 * i` skips the even multiples, which have no slot.
                for j in (i * i..=max_n).step_by(i as usize * 2) {
                    let slot = &mut table[j as usize >> 1];
                    // Primes are visited in increasing order, so the first one to reach
                    // `j` is its smallest prime factor; never overwrite it.
                    if *slot == 1 {
                        *slot = i;
                    }
                }
            }
            i += 2;
        }
        PrimeTable { table, max_n }
    }

    /// Returns whether `n` is prime.
    ///
    /// # Panics
    /// Panics if `n` is odd, greater than 1, and greater than `max_n`.
    pub fn is_prime(&self, n: u32) -> bool {
        n == 2 || (n % 2 == 1 && self.table[n as usize >> 1] == 1)
    }

    /// Iterates over all primes `<= max_n` in increasing order.
    pub fn primes(&self) -> impl Iterator<Item = u32> + '_ {
        // 2 is the only even prime and has no slot, so it is emitted separately, and only
        // when it actually lies within the table's range.
        let has_two = self.max_n >= 2;
        once(2).filter(move |_| has_two).chain(
            self.table
                .iter()
                .enumerate()
                .filter(|&(_, &v)| v == 1)
                .map(|(k, _)| k as u32 * 2 + 1),
        )
    }

    /// Factorizes `n`, calling `f(p, e)` once for every distinct prime `p` dividing `n`,
    /// where `e` is the exponent of `p` in `n`. Calls happen in increasing order of `p`;
    /// `n == 1` produces no calls.
    ///
    /// # Panics
    /// Panics if `n == 0`, or if `n` with its factors of 2 removed is greater than
    /// `max_n` (and not 1).
    pub fn trial_division<F>(&self, mut n: u32, mut f: F)
    where
        F: FnMut(u32, u32),
    {
        // 0 has no prime factorization; without this guard `n >>= 32` below would overflow.
        assert!(n != 0, "PrimeTable::trial_division: cannot factorize 0");
        let k = n.trailing_zeros();
        if k > 0 {
            f(2, k);
            n >>= k;
        }
        // `n` is odd from here on; its slot is > 1 exactly while `n` is composite, and then
        // holds the smallest remaining prime factor.
        while self.table[n as usize >> 1] > 1 {
            let p = self.table[n as usize >> 1];
            let mut cnt = 1;
            n /= p;
            while self.table[n as usize >> 1] == p {
                n /= p;
                cnt += 1;
            }
            // When only one factor of `p` is left, `n == p` is prime (slot 1), so the loop
            // above stops before removing it.
            if n == p {
                cnt += 1;
                n /= p;
            }
            f(p, cnt);
        }
        // What remains is 1 or a prime larger than every factor reported so far.
        if n > 1 {
            f(n, 1);
        }
    }

    /// Returns the prime factorization of `n` as `(prime, exponent)` pairs sorted by prime.
    ///
    /// # Panics
    /// Under the same conditions as [`PrimeTable::trial_division`].
    pub fn prime_factors(&self, n: u32) -> Vec<(u32, u32)> {
        let mut factors = Vec::new();
        self.trial_division(n, |p, e| factors.push((p, e)));
        factors
    }

    /// Returns the number of positive divisors of `n`, i.e. the product of `e + 1` over the
    /// exponents `e` of its prime factorization.
    ///
    /// # Panics
    /// Under the same conditions as [`PrimeTable::trial_division`].
    pub fn count_divisors(&self, n: u32) -> u32 {
        let mut count = 1;
        self.trial_division(n, |_, e| count *= e + 1);
        count
    }

    /// Returns all positive divisors of `n` in increasing order.
    ///
    /// # Panics
    /// Under the same conditions as [`PrimeTable::trial_division`].
    pub fn divisors(&self, n: u32) -> Vec<u32> {
        let mut d = vec![1u32];
        self.trial_division(n, |p, e| {
            // Multiply every divisor found so far by p^1, ..., p^e.
            let base = d.len();
            d.reserve(base * e as usize);
            let mut pk = 1;
            for _ in 0..e {
                pk *= p;
                for i in 0..base {
                    d.push(d[i] * pk);
                }
            }
        });
        d.sort_unstable();
        d
    }
}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Reference implementation: all positive divisors of `n`, found by trial division up to
    /// `sqrt(n)`, returned in increasing order.
    ///
    /// The loop bound `i <= n / i` is the integer form of `i * i <= n`. It stays exact and
    /// overflow-free for every `u32`, unlike an `f32` square root, which is only exact below 2^24.
    fn naive_divisors(n: u32) -> Vec<u32> {
        let mut res = Vec::new();
        let mut i = 1;
        while i <= n / i {
            if n % i == 0 {
                res.push(i);
                // Avoid pushing the square root of a perfect square twice.
                if i != n / i {
                    res.push(n / i);
                }
            }
            i += 1;
        }
        res.sort_unstable();
        res
    }

    #[test]
    fn test_prime_table() {
        const N: u32 = 100_000;
        let primes = PrimeTable::new(N);
        assert!(!primes.is_prime(N));
        assert!(primes.is_prime(99991));

        assert_eq!(primes.prime_factors(99991), vec![(99991, 1)]);
        assert_eq!(primes.prime_factors(2016), vec![(2, 5), (3, 2), (7, 1)]);
        for i in 1..=N {
            let factors = primes.prime_factors(i);
            // Multiplying the prime powers back together must reproduce `i`.
            assert_eq!(i, factors.iter().map(|&(p, c)| p.pow(c)).product::<u32>());
            // The divisor count is the product of (exponent + 1).
            assert_eq!(
                factors.iter().map(|&(_, c)| c + 1).product::<u32>(),
                primes.count_divisors(i)
            );
        }
    }

    #[test]
    fn test_primes_match_naive() {
        const N: u32 = 1000;
        let pt = PrimeTable::new(N);
        // A number is prime exactly when it has two divisors.
        let expected: Vec<u32> = (1..=N).filter(|&n| naive_divisors(n).len() == 2).collect();
        assert_eq!(pt.primes().collect::<Vec<_>>(), expected);
        for n in 1..=N {
            assert_eq!(pt.is_prime(n), naive_divisors(n).len() == 2);
        }
    }

    #[test]
    fn test_divisors() {
        let mut rng = Xorshift::default();
        let pt = PrimeTable::new(200001);
        for n in (1..1000).chain(rng.random_iter(1..=200000).take(100)) {
            assert_eq!(pt.divisors(n), naive_divisors(n));
        }
    }
}