competitive/math/
quotient_array.rs1use super::{Group, Invertible, One, Ring, Zero, with_prime_list};
2use std::ops::{Index, IndexMut};
3
/// Values keyed by the quotients `n / i` (integer division, `1 <= i <= n`).
///
/// `data` holds, in order, the entries for `n / i` with `i` in `1..=isqrtn`,
/// followed by the entries for every integer below `n / isqrtn`, in decreasing
/// order down to `1`.
#[derive(Debug, Clone)]
pub struct QuotientArray<T> {
    /// The upper bound whose quotients are tracked.
    n: u64,
    /// Square root of `n` rounded down; splits the large keys from the small ones.
    isqrtn: u64,
    /// One entry per stored key, in the order described above.
    data: Vec<T>,
}
11
12impl<T> QuotientArray<T>
13where
14 T: Zero,
15{
16 pub fn zeros(n: u64) -> Self {
17 Self::from_fn(n, |_| T::zero())
18 }
19}
20
21impl<T> QuotientArray<T> {
22 pub fn index_iter(n: u64, isqrtn: u64) -> impl Iterator<Item = u64> {
23 (1..=isqrtn)
24 .map(move |i| n / i)
25 .chain((1..n / isqrtn).rev())
26 }
27
28 pub fn map<U>(&self, f: impl FnMut(&T) -> U) -> QuotientArray<U> {
29 let data = self.data.iter().map(f).collect();
30 QuotientArray {
31 n: self.n,
32 isqrtn: self.isqrtn,
33 data,
34 }
35 }
36
37 pub fn quotient_index(&self, i: u64) -> usize {
38 assert!(
39 i <= self.n,
40 "index out of bounds: the len is {} but the index is {}",
41 self.n,
42 i
43 );
44 assert_ne!(i, 0, "index out of bounds: the index is 0");
45 if i <= self.isqrtn {
46 self.data.len() - i as usize
47 } else {
48 (self.n / i) as usize - 1
49 }
50 }
51
52 pub fn from_fn(n: u64, f: impl FnMut(u64) -> T) -> Self {
53 let isqrtn = (n as f64).sqrt().floor() as u64;
54 let data = Self::index_iter(n, isqrtn).map(f).collect();
55 Self { n, isqrtn, data }
56 }
57
58 pub fn lucy_dp<G>(mut self, mut mul_p: impl FnMut(T, u64) -> T) -> Self
62 where
63 G: Group<T = T>,
64 {
65 with_prime_list(self.isqrtn, |pl| {
66 for &p in pl.primes_lte(self.isqrtn) {
67 let k = self.quotient_index(p - 1);
68 let p2 = p * p;
69 for (i, q) in Self::index_iter(self.n, self.isqrtn).enumerate() {
70 if q < p2 {
71 break;
72 }
73 let diff = mul_p(G::rinv_operate(&self[q / p], &self.data[k]), p);
74 G::rinv_operate_assign(&mut self.data[i], &diff);
75 }
76 }
77 });
78 self
79 }
80
81 pub fn min_25_sieve<R>(&self, mut f: impl FnMut(u64, u32) -> T) -> Self
83 where
84 T: Clone + One,
85 R: Ring<T = T>,
86 R::Additive: Invertible,
87 {
88 let mut dp = self.clone();
89 with_prime_list(self.isqrtn, |pl| {
90 for &p in pl.primes_lte(self.isqrtn).iter().rev() {
91 let k = self.quotient_index(p);
92 for (i, q) in Self::index_iter(self.n, self.isqrtn).enumerate() {
93 let mut pc = p;
94 if pc * p > q {
95 break;
96 }
97 let mut c = 1;
98 while q / p >= pc {
99 let x = R::mul(&f(p, c), &(R::sub(&dp[q / pc], &self.data[k])));
100 let x = R::add(&x, &f(p, c + 1));
101 dp.data[i] = R::add(&dp.data[i], &x);
102 c += 1;
103 pc *= p;
104 }
105 }
106 }
107 });
108 for x in &mut dp.data {
109 *x = R::add(x, &T::one());
110 }
111 dp
112 }
113}
114
115impl<T> Index<u64> for QuotientArray<T> {
116 type Output = T;
117 fn index(&self, i: u64) -> &Self::Output {
118 unsafe { self.data.get_unchecked(self.quotient_index(i)) }
119 }
120}
121
122impl<T> IndexMut<u64> for QuotientArray<T> {
123 fn index_mut(&mut self, index: u64) -> &mut Self::Output {
124 let i = self.quotient_index(index);
125 unsafe { self.data.get_unchecked_mut(i) }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AddMulOperation, AdditiveOperation, ArrayOperation},
        math::{PrimeList, PrimeTable},
        tools::Xorshift,
    };

    /// Prime counting via `lucy_dp` with `f(i) = 1`, checked against a
    /// reference prime list. Small `n` are tested exhaustively, then random ones.
    #[test]
    fn prime_count() {
        let mut rng = Xorshift::default();
        let reference = PrimeList::new(100_000);
        for trial in 1..=100 {
            let n = if trial > 10 { rng.random(1..10_000) } else { trial };
            let counts = QuotientArray::from_fn(n, |key| key as i64 - 1)
                .lucy_dp::<AdditiveOperation<_>>(|acc, _p| acc);
            let expected = reference.primes_lte(n).len();
            assert_eq!(expected, counts[n] as usize);
        }
    }

    /// Sum of `sigma(i)` for `i <= n` via `lucy_dp` + `min_25_sieve`,
    /// checked against enumerating divisors directly.
    #[test]
    fn divisor_sum() {
        let mut rng = Xorshift::default();
        let table = PrimeTable::new(10_000);
        for trial in 1..=100 {
            let n = if trial > 10 { rng.random(1..10_000) } else { trial };
            // Prime sums of f(i) = 1 and f(i) = i, sieved from sums over 2..=key.
            let prime_sums = QuotientArray::from_fn(n, |key| {
                let key = key as i64;
                [key, key * (key + 1) / 2]
            })
            .map(|[cnt, sum]| [cnt - 1, sum - 1])
            .lucy_dp::<ArrayOperation<AdditiveOperation<_>, 2>>(|[cnt, sum], p| {
                [cnt, sum * p as i64]
            });
            // sigma(p) = 1 + p, and sigma(p^c) = 1 + p + ... + p^c.
            let sigma_sums = prime_sums
                .map(|[cnt, sum]| cnt + sum)
                .min_25_sieve::<AddMulOperation<_>>(|p, c| {
                    (0..=c).map(|e| (p as i64).pow(e)).sum::<i64>()
                });
            let expected: u64 = (1..=n)
                .flat_map(|i| table.divisors(i as _))
                .map(|d| d as u64)
                .sum();
            assert_eq!(expected, sigma_sums[n] as u64);
        }
    }
}