// competitive/math/prime_list.rs
use std::{
    cell::{RefCell, UnsafeCell},
    mem::replace,
    slice::Iter,
};

/// A growable, sorted table of every prime up to some bound.
///
/// The table is filled by an odd-only segmented sieve of Eratosthenes (see
/// [`PrimeList::reserve`]) and can be extended later without re-sieving the part that
/// is already known.
#[derive(Debug, Clone)]
pub struct PrimeList {
    /// Every prime `<= max_n`, in increasing order.
    primes: Vec<u64>,
    /// Inclusive bound up to which `primes` is complete. After the first sieving step
    /// this is always even, so `max_n + 1` (the next number to sieve) is odd.
    max_n: u64,
}

impl Default for PrimeList {
    /// An empty list that is complete up to `1` (there are no primes `<= 1`).
    fn default() -> Self {
        Self {
            primes: Vec::new(),
            max_n: 1,
        }
    }
}

impl PrimeList {
    /// Creates a list containing every prime `<= max_n`.
    pub fn new(max_n: u64) -> Self {
        let mut list = Self::default();
        list.reserve(max_n);
        list
    }

    /// All primes currently known, in increasing order.
    pub fn primes(&self) -> &[u64] {
        &self.primes
    }

    /// All primes `<= n`, in increasing order.
    ///
    /// # Panics
    ///
    /// Panics if `n` is larger than the bound the list currently covers.
    pub fn primes_lte(&self, n: u64) -> &[u64] {
        assert!(n <= self.max_n, "expected `n={} <= {}`", n, self.max_n);
        let len = self.primes.partition_point(|&p| p <= n);
        &self.primes[..len]
    }

    /// Returns whether `n` is prime.
    ///
    /// # Panics
    ///
    /// Panics if `n` is larger than the bound the list currently covers.
    pub fn is_prime(&self, n: u64) -> bool {
        assert!(n <= self.max_n, "expected `n={} <= {}`", n, self.max_n);
        self.primes.binary_search(&n).is_ok()
    }

    /// Iterates over the prime factorization of `n` as `(prime, exponent)` pairs in
    /// increasing order of prime. `0` and `1` yield nothing.
    ///
    /// # Panics
    ///
    /// Panics if `n > bound * bound` (saturating), where `bound` is the limit the list
    /// covers. Below that limit, whatever is left after dividing out every listed prime
    /// `<= sqrt(n)` is either `1` or a prime.
    pub fn trial_division(&self, n: u64) -> PrimeListTrialDivision<'_> {
        let bound = self.max_n.saturating_mul(self.max_n);
        assert!(n <= bound, "expected `n={} <= {}`", n, bound);
        PrimeListTrialDivision {
            primes: self.primes.iter(),
            n,
        }
    }

    /// The prime factorization of `n` as `(prime, exponent)` pairs, increasing by prime.
    ///
    /// # Panics
    ///
    /// Same condition as [`PrimeList::trial_division`].
    pub fn prime_factors(&self, n: u64) -> Vec<(u64, u32)> {
        self.trial_division(n).collect()
    }

    /// Number of divisors of `n`, computed as the product of `exponent + 1` over its
    /// factorization (so `0` and `1` both give `1`).
    ///
    /// # Panics
    ///
    /// Same condition as [`PrimeList::trial_division`].
    pub fn count_divisors(&self, n: u64) -> u64 {
        self.trial_division(n)
            .map(|(_, cnt)| cnt as u64 + 1)
            .product()
    }

    /// All divisors of `n` in increasing order (`0` and `1` both give `[1]`).
    ///
    /// # Panics
    ///
    /// Same condition as [`PrimeList::trial_division`].
    pub fn divisors(&self, n: u64) -> Vec<u64> {
        let mut divisors = vec![1u64];
        for (p, cnt) in self.trial_division(n) {
            // Multiply each divisor built from the smaller primes by p, p^2, ..., p^cnt.
            let known = divisors.len();
            let mut power = 1;
            for _ in 0..cnt {
                power *= p;
                for i in 0..known {
                    divisors.push(divisors[i] * power);
                }
            }
        }
        divisors.sort_unstable();
        divisors
    }

    /// Extends the list so that it contains every prime `<= max_n`.
    ///
    /// Primes that are already known are kept. Only the range `(self.max_n, max_n]` is
    /// sieved, in segments of about `sqrt(max_n)` numbers. The call does nothing if the
    /// list already covers `max_n`.
    pub fn reserve(&mut self, max_n: u64) {
        if max_n <= self.max_n || max_n < 2 {
            return;
        }
        if self.primes.is_empty() {
            self.primes.push(2);
            self.max_n = 2;
        }
        if max_n == 2 {
            return;
        }

        // Round up to even. There is no even prime above 2, so the set of primes is
        // unchanged, and the stored bound stays even (keeping `self.max_n + 1` odd).
        let max_n = (max_n + 1) / 2 * 2;
        // An even segment length keeps every segment start odd. About sqrt(max_n)
        // keeps the scratch table small. `max_n >= 4` here, so this is at least 2.
        let seg_len = ((max_n as f64).sqrt() as u64 + 1) / 2 * 2;
        let mut table = Vec::with_capacity((seg_len / 2) as usize);
        for start in (self.max_n + 1..=max_n).step_by(seg_len as usize) {
            let end = (start + seg_len).min(max_n + 1);
            self.sieve_odd_segment(&mut table, start, end);
        }
        self.max_n = max_n;
    }

    /// Sieves the odd numbers in `[start, end)` and appends the primes among them to
    /// `self.primes`. `table` is scratch space reused across segments.
    ///
    /// Requires `start` and `end` to be odd with `start < end`, and `self.primes` to hold
    /// every prime below `start` already.
    fn sieve_odd_segment(&mut self, table: &mut Vec<bool>, start: u64, end: u64) {
        // `table[i]` stands for the odd number `start + 2 * i`; `true` means composite.
        let len = ((end - start) / 2) as usize;
        table.clear();
        table.resize(len, false);

        // Cross off odd multiples of the known odd primes `p <= sqrt(end)`.
        let sqrt_end = (end as f64).sqrt() as u64;
        let base = self.primes[1..].partition_point(|&p| p <= sqrt_end);
        for &p in &self.primes[1..1 + base] {
            // Smallest odd multiplier `k` with `k * p >= start`, but never below `p`.
            // Smaller multiples of `p` have a smaller prime factor that crosses them off.
            let k = p.max((start + p - 1) / (2 * p) * 2 + 1);
            let first = ((k * p - start) / 2) as usize;
            // Consecutive odd multiples of `p` are `2p` apart, which is `p` table slots.
            for i in (first..len).step_by(p as usize) {
                table[i] = true;
            }
        }

        // Primes inside this segment can themselves be needed to sieve it (when
        // `start * start < end`, i.e. the first segments of a fresh list). Scanning in
        // increasing order means every smaller factor of `start + 2 * i` has already
        // crossed it off by the time slot `i` is read.
        for i in 0..len {
            let m = start + 2 * i as u64;
            if m * m >= end {
                break;
            }
            if !table[i] {
                let first = ((m * m - start) / 2) as usize;
                for j in (first..len).step_by(m as usize) {
                    table[j] = true;
                }
            }
        }

        self.primes.extend(
            table
                .iter()
                .enumerate()
                .filter(|&(_, &composite)| !composite)
                .map(|(i, _)| start + 2 * i as u64),
        );
    }
}

/// Iterator returned by [`PrimeList::trial_division`]. It yields `(prime, exponent)`
/// pairs of the factorization in increasing order of prime.
#[derive(Debug, Clone)]
pub struct PrimeListTrialDivision<'p> {
    /// Candidate primes that have not been tried yet.
    primes: Iter<'p, u64>,
    /// The part of the original number that has not been factored out yet.
    n: u64,
}

impl Iterator for PrimeListTrialDivision<'_> {
    type Item = (u64, u32);

    fn next(&mut self) -> Option<Self::Item> {
        if self.n <= 1 {
            return None;
        }
        for &p in self.primes.by_ref() {
            // `p > n / p` is `p * p > n` without overflowing `p * p` for `p > 2^32`.
            if p > self.n / p {
                break;
            }
            if self.n % p == 0 {
                let mut cnt = 0u32;
                while self.n % p == 0 {
                    self.n /= p;
                    cnt += 1;
                }
                return Some((p, cnt));
            }
        }
        // No listed prime <= sqrt(n) divides what is left, so it is 1 or a prime.
        (self.n > 1).then(|| (replace(&mut self.n, 1), 1))
    }
}

182pub fn with_prime_list<F>(max_n: u64, f: F)
183where
184 F: FnOnce(&PrimeList),
185{
186 thread_local!(static PRIME_LIST: UnsafeCell<PrimeList> = Default::default());
187 PRIME_LIST.with(|cell| {
188 unsafe {
189 let pl = &mut *cell.get();
190 pl.reserve(max_n);
191 f(pl);
192 };
193 });
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::prime_factors;
    use crate::tools::Xorshift;

    /// Reference implementation: plain odd-only sieve of Eratosthenes, primes `<= n`.
    fn primes(n: usize) -> Vec<usize> {
        if n < 2 {
            return vec![];
        }
        let mut res = vec![2];
        let sqrt_n = (n as f32).sqrt() as usize | 1;
        // `sieve[i / 2 - 1]` stands for the odd number `i >= 3`.
        let mut sieve = vec![true; n / 2];
        for i in (3..=sqrt_n).step_by(2) {
            if sieve[i / 2 - 1] {
                res.push(i);
                for j in (i * i..=n).step_by(i * 2) {
                    sieve[j / 2 - 1] = false;
                }
            }
        }
        for i in (std::cmp::max(3, sqrt_n + 2)..=n).step_by(2) {
            if sieve[i / 2 - 1] {
                res.push(i);
            }
        }
        res
    }

    /// Reference implementation: odd-only segmented sieve, primes `<= n`.
    fn segmented_sieve_primes(n: usize) -> Vec<usize> {
        if n < 2 {
            return Vec::new();
        }
        let seg_size = ((n as f32).sqrt() as usize + 2) >> 1;
        let mut primes = vec![2];
        let mut table = vec![true; seg_size];
        for i in 1..seg_size {
            if table[i] {
                let p = i * 2 + 1;
                primes.push(p);
                for j in (p * p / 2..seg_size).step_by(p) {
                    table[j] = false;
                }
            }
        }
        for s in (seg_size..=n / 2).step_by(seg_size) {
            let m = seg_size.min((n + 1) / 2 - s);
            table.clear();
            table.resize(m, true);
            let plen = primes[1..]
                .binary_search(&((((s + m) * 2 + 1) as f32).sqrt() as usize + 1))
                .unwrap_or_else(|x| x);
            for &p in primes[1..plen + 1].iter() {
                for k in (((s * 2 + p * 3) / (p * 2) * p * 2 - p) / 2 - s..m).step_by(p) {
                    table[k] = false;
                }
            }
            primes.extend((s..m + s).filter(|k| table[k - s]).map(|k| k * 2 + 1));
        }
        primes
    }

    /// Reference implementation: all divisors of `n` by trial division, sorted.
    pub fn divisors(n: u64) -> Vec<u64> {
        let mut res = vec![];
        // Exact integer bound. `(n as f32).sqrt()` is wrong above 2^24, where f32 cannot
        // hold every integer: an odd square `k * k` may round down and lose divisor `k`.
        let mut i = 1;
        while i * i <= n {
            if n % i == 0 {
                res.push(i);
                if i * i != n {
                    res.push(n / i);
                }
            }
            i += 1;
        }
        res.sort_unstable();
        res
    }

    #[test]
    fn test_prime_list() {
        let mut rng = Xorshift::default();

        for n in (0..1000).chain(rng.random_iter(0..=20000).take(100)) {
            let pl = PrimeList::new(n);
            let ps: Vec<_> = primes(n as _).into_iter().map(|p| p as u64).collect();
            assert_eq!(pl.primes(), ps.as_slice());
        }

        for _ in 0..100 {
            let b = rng.randf() * 0.0001;
            let mut pl = PrimeList::new(0);
            for n in 0..20000 {
                if rng.gen_bool(b) {
                    pl.reserve(n);
                    let ps: Vec<_> = primes(n as _).into_iter().map(|p| p as u64).collect();
                    assert_eq!(pl.primes(), ps.as_slice());
                }
            }
        }

        let pl = PrimeList::new(100_000);
        for n in (0..1000).chain(rng.random_iter(0..=1_000_000_000).take(100)) {
            assert_eq!(prime_factors(n), pl.prime_factors(n));
        }
    }

    #[test]
    fn test_primes() {
        let pl = PrimeList::new(2000);
        for i in 0..=2000 {
            assert_eq!(
                primes(i),
                (2..=i).filter(|&i| pl.is_prime(i as _)).collect::<Vec<_>>(),
            );
            assert_eq!(
                primes(i).iter().map(|&p| p as _).collect::<Vec<u64>>(),
                pl.primes_lte(i as _)
            );
        }
    }

    #[test]
    fn test_segmented_sieve_primes() {
        for i in 0..300 {
            assert_eq!(primes(i), segmented_sieve_primes(i));
        }
        assert_eq!(primes(1_000_000), segmented_sieve_primes(1_000_000));
    }

    #[test]
    fn test_divisors() {
        let mut rng = Xorshift::default();
        let pl = PrimeList::new(20000);
        for n in (1..1000).chain(rng.random_iter(1..=20000000).take(100)) {
            assert_eq!(pl.divisors(n), divisors(n));
        }
    }
}