competitive/math/
prime_list.rs

use std::{
    cell::{RefCell, UnsafeCell},
    mem::replace,
    slice::Iter,
};
2
/// List of primes, sieved on demand up to an inclusive bound `max_n`.
#[derive(Clone, Debug)]
pub struct PrimeList {
    /// Every prime `<= max_n`, in increasing order.
    primes: Vec<u64>,
    /// Inclusive upper bound of the range that `primes` covers.
    max_n: u64,
}

impl Default for PrimeList {
    /// An empty list whose bound starts at 1 (no number `<= 1` is prime).
    fn default() -> Self {
        Self {
            max_n: 1,
            primes: Vec::new(),
        }
    }
}
17
18impl PrimeList {
19    pub fn new(max_n: u64) -> Self {
20        let mut self_: Self = Default::default();
21        self_.reserve(max_n);
22        self_
23    }
24    pub fn primes(&self) -> &[u64] {
25        self.primes.as_slice()
26    }
27    pub fn primes_lte(&self, n: u64) -> &[u64] {
28        assert!(n <= self.max_n, "expected `n={} <= {}`", n, self.max_n);
29        let i = self.primes.partition_point(|&p| p <= n);
30        &self.primes[..i]
31    }
32    pub fn is_prime(&self, n: u64) -> bool {
33        assert!(n <= self.max_n, "expected `n={} <= {}`", n, self.max_n);
34        self.primes.binary_search(&n).is_ok()
35    }
36    pub fn trial_division(&self, n: u64) -> PrimeListTrialDivision<'_> {
37        let bound = self.max_n.saturating_mul(self.max_n);
38        assert!(n <= bound, "expected `n={} <= {}`", n, bound);
39        PrimeListTrialDivision {
40            primes: self.primes.iter(),
41            n,
42        }
43    }
44    pub fn prime_factors(&self, n: u64) -> Vec<(u64, u32)> {
45        self.trial_division(n).collect()
46    }
47    pub fn count_divisors(&self, n: u64) -> u64 {
48        let mut divisor_cnt = 1u64;
49        for (_, cnt) in self.trial_division(n) {
50            divisor_cnt *= cnt as u64 + 1;
51        }
52        divisor_cnt
53    }
54    pub fn divisors(&self, n: u64) -> Vec<u64> {
55        let mut d = vec![1u64];
56        for (p, c) in self.trial_division(n) {
57            let k = d.len();
58            let mut acc = 1;
59            for _ in 0..c {
60                acc *= p;
61                for i in 0..k {
62                    d.push(d[i] * acc);
63                }
64            }
65        }
66        d.sort_unstable();
67        d
68    }
69    /// list primes less than or equal to `max_n` by segmented sieve
70    pub fn reserve(&mut self, max_n: u64) {
71        if max_n <= self.max_n || max_n < 2 {
72            return;
73        }
74
75        if self.primes.is_empty() {
76            self.primes.push(2);
77            self.max_n = 2;
78        }
79        if max_n == 2 {
80            return;
81        }
82
83        let max_n = (max_n + 1) / 2 * 2; // odd
84        let sqrt_n = ((max_n as f64).sqrt() as usize + 1) / 2 * 2; // even
85        let mut table = Vec::with_capacity(sqrt_n >> 1);
86        if self.max_n < sqrt_n as u64 {
87            let start = (self.max_n as usize + 1) | 1; // odd
88            let end = sqrt_n + 1;
89            let sqrt_end = (sqrt_n as f64).sqrt() as usize;
90            let plen = self.primes[1..]
91                .binary_search(&(sqrt_end as u64 + 1))
92                .unwrap_or_else(|x| x);
93            table.resize(end / 2 - start / 2, false);
94            for &p in self.primes.iter().skip(1).take(plen) {
95                let y = p.max((start as u64 + p - 1) / (2 * p) * 2 + 1) * p / 2;
96                (y as usize - start / 2..end / 2 - start / 2)
97                    .step_by(p as usize)
98                    .for_each(|i| table[i] = true);
99            }
100            for i in 0..=(sqrt_end / 2).saturating_sub(start / 2) {
101                if !table[i] {
102                    let p = (i + start / 2) * 2 + 1;
103                    for j in (p * p / 2 - start / 2..sqrt_n / 2 - start / 2).step_by(p) {
104                        table[j] = true;
105                    }
106                }
107            }
108            self.primes
109                .extend(table.iter().cloned().enumerate().filter_map(|(i, b)| {
110                    if !b {
111                        Some((i + start / 2) as u64 * 2 + 1)
112                    } else {
113                        None
114                    }
115                }));
116            self.max_n = sqrt_n as u64;
117        }
118
119        let sqrt_n = sqrt_n as u64;
120        for start in (self.max_n + 1..=max_n).step_by(sqrt_n as usize) {
121            let end = (start + sqrt_n).min(max_n + 1);
122            let sqrt_end = (end as f64).sqrt() as u64;
123            let length = end - start;
124            let plen = self.primes[1..]
125                .binary_search(&(sqrt_end + 1))
126                .unwrap_or_else(|x| x);
127            table.clear();
128            table.resize(length as usize / 2, false);
129            for &p in self.primes.iter().skip(1).take(plen) {
130                let y = p.max((start + p - 1) / (2 * p) * 2 + 1) * p / 2;
131                ((y - start / 2) as usize..length as usize / 2)
132                    .step_by(p as usize)
133                    .for_each(|i| table[i] = true);
134            }
135            self.primes
136                .extend(table.iter().cloned().enumerate().filter_map(|(i, b)| {
137                    if !b {
138                        Some((i as u64 + start / 2) * 2 + 1)
139                    } else {
140                        None
141                    }
142                }));
143        }
144        self.max_n = max_n;
145    }
146}
147
/// Iterator over the prime factorization of a number, created by
/// [`PrimeList::trial_division`].
///
/// Yields `(prime, exponent)` pairs in increasing order of prime.
#[derive(Debug, Clone)]
pub struct PrimeListTrialDivision<'p> {
    /// Primes not tried yet, in increasing order.
    primes: Iter<'p, u64>,
    /// Cofactor left after dividing out every factor yielded so far.
    n: u64,
}

impl Iterator for PrimeListTrialDivision<'_> {
    type Item = (u64, u32);

    fn next(&mut self) -> Option<Self::Item> {
        if self.n <= 1 {
            return None;
        }
        for &p in self.primes.by_ref() {
            // Same as `p * p > n`, but cannot overflow once the list holds primes above 2^32.
            if p > self.n / p {
                break;
            }
            if self.n % p == 0 {
                let mut cnt = 0u32;
                while self.n % p == 0 {
                    self.n /= p;
                    cnt += 1;
                }
                return Some((p, cnt));
            }
        }
        // Every prime up to sqrt(n) has been tried, or all sieved primes have been tried
        // (with `n <= max_n^2`, as `PrimeList::trial_division` checks). Either way, a
        // remaining cofactor above 1 is prime.
        if self.n > 1 {
            Some((replace(&mut self.n, 1), 1))
        } else {
            None
        }
    }
}
181
182pub fn with_prime_list<F>(max_n: u64, f: F)
183where
184    F: FnOnce(&PrimeList),
185{
186    thread_local!(static PRIME_LIST: UnsafeCell<PrimeList> = Default::default());
187    PRIME_LIST.with(|cell| {
188        unsafe {
189            let pl = &mut *cell.get();
190            pl.reserve(max_n);
191            f(pl);
192        };
193    });
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::prime_factors;
    use crate::tools::Xorshift;

    /// Reference implementation: primes `<= n` by a plain odd-only sieve.
    fn primes(n: usize) -> Vec<usize> {
        if n < 2 {
            return vec![];
        }
        let mut res = vec![2];
        let sqrt_n = (n as f64).sqrt() as usize | 1;
        // `sieve[k]` stands for the odd number `2 * k + 3`.
        let mut sieve = vec![true; n / 2];
        for i in (3..=sqrt_n).step_by(2) {
            if sieve[i / 2 - 1] {
                res.push(i);
                for j in (i * i..=n).step_by(i * 2) {
                    sieve[j / 2 - 1] = false;
                }
            }
        }
        for i in (std::cmp::max(3, sqrt_n + 2)..=n).step_by(2) {
            if sieve[i / 2 - 1] {
                res.push(i);
            }
        }
        res
    }

    /// Independent segmented sieve, cross-checked against [`primes`].
    fn segmented_sieve_primes(n: usize) -> Vec<usize> {
        if n < 2 {
            return Vec::new();
        }
        let seg_size = ((n as f64).sqrt() as usize + 2) >> 1;
        let mut primes = vec![2];
        // In segment `s`, `table[i]` stands for the odd number `(s + i) * 2 + 1`.
        let mut table = vec![true; seg_size];
        for i in 1..seg_size {
            if table[i] {
                let p = i * 2 + 1;
                primes.push(p);
                for j in (p * p / 2..seg_size).step_by(p) {
                    table[j] = false;
                }
            }
        }
        for s in (seg_size..=n / 2).step_by(seg_size) {
            let m = seg_size.min((n + 1) / 2 - s);
            table.clear();
            table.resize(m, true);
            let plen = primes[1..]
                .binary_search(&((((s + m) * 2 + 1) as f64).sqrt() as usize + 1))
                .unwrap_or_else(|x| x);
            for &p in primes[1..plen + 1].iter() {
                for k in (((s * 2 + p * 3) / (p * 2) * p * 2 - p) / 2 - s..m).step_by(p) {
                    table[k] = false;
                }
            }
            primes.extend((s..m + s).filter(|k| table[k - s]).map(|k| k * 2 + 1));
        }
        primes
    }

    /// Reference implementation: divisors of `n` by trial division up to `sqrt(n)`.
    pub fn divisors(n: u64) -> Vec<u64> {
        let mut res = vec![];
        for i in 1..(n as f64).sqrt() as u64 + 1 {
            if n % i == 0 {
                res.push(i);
                if i * i != n {
                    res.push(n / i);
                }
            }
        }
        res.sort_unstable();
        res
    }

    #[test]
    fn test_prime_list() {
        let mut rng = Xorshift::default();

        for n in (0..1000).chain(rng.random_iter(0..=20000).take(100)) {
            let pl = PrimeList::new(n);
            let ps: Vec<_> = primes(n as _).into_iter().map(|p| p as u64).collect();
            assert_eq!(pl.primes(), ps.as_slice());
        }

        for _ in 0..100 {
            let b = rng.randf() * 0.0001;
            let mut pl = PrimeList::new(0);
            for n in 0..20000 {
                if rng.gen_bool(b) {
                    pl.reserve(n);
                    let ps: Vec<_> = primes(n as _).into_iter().map(|p| p as u64).collect();
                    assert_eq!(pl.primes(), ps.as_slice());
                }
            }
        }

        let pl = PrimeList::new(100_000);
        for n in (0..1000).chain(rng.random_iter(0..=1_000_000_000).take(100)) {
            assert_eq!(prime_factors(n), pl.prime_factors(n));
        }
    }

    #[test]
    fn test_primes() {
        let pl = PrimeList::new(2000);
        for i in 0..=2000 {
            assert_eq!(
                primes(i),
                (2..=i).filter(|&i| pl.is_prime(i as _)).collect::<Vec<_>>(),
            );
            assert_eq!(
                primes(i).iter().map(|&p| p as _).collect::<Vec<u64>>(),
                pl.primes_lte(i as _)
            );
        }
    }

    #[test]
    fn test_segmented_sieve_primes() {
        for i in 0..300 {
            assert_eq!(primes(i), segmented_sieve_primes(i));
        }
        assert_eq!(primes(1_000_000), segmented_sieve_primes(1_000_000));
    }

    #[test]
    fn test_divisors() {
        let mut rng = Xorshift::default();
        let pl = PrimeList::new(20000);
        for n in (1..1000).chain(rng.random_iter(1..=20000000).take(100)) {
            assert_eq!(pl.divisors(n), divisors(n));
        }
    }
}