competitive/math/
quotient_array.rs1use super::{Group, Invertible, One, Ring, Zero, with_prime_list};
2use std::ops::{Index, IndexMut};
3
/// Table holding one value of type `T` per entry produced by
/// `QuotientArray::index_iter(n, isqrtn)`.
///
/// The entries are kept in `data` in exactly the order that iterator yields
/// them: first `n / 1, n / 2, ..., n / isqrtn`, then the remaining smaller
/// arguments counting down to `1`.
#[derive(Clone, Debug)]
pub struct QuotientArray<T> {
    /// Upper bound of the index domain; valid indices are `1..=n`.
    n: u64,
    /// Threshold separating the two lookup formulas used by `quotient_index`
    /// (set from the square root of `n` by `from_fn`).
    isqrtn: u64,
    /// Stored values, in `index_iter` order.
    data: Vec<T>,
}
11
12impl<T> QuotientArray<T>
13where
14 T: Zero,
15{
16 pub fn zeros(n: u64) -> Self {
17 Self::from_fn(n, |_| T::zero())
18 }
19}
20
21impl<T> QuotientArray<T> {
22 pub fn index_iter(n: u64, isqrtn: u64) -> impl Iterator<Item = u64> {
23 (1..=isqrtn)
24 .map(move |i| n / i)
25 .chain((1..n / isqrtn).rev())
26 }
27
28 pub fn map<U>(&self, f: impl FnMut(&T) -> U) -> QuotientArray<U> {
29 let data = self.data.iter().map(f).collect();
30 QuotientArray {
31 n: self.n,
32 isqrtn: self.isqrtn,
33 data,
34 }
35 }
36
37 pub fn quotient_index(&self, i: u64) -> usize {
38 assert!(
39 i <= self.n,
40 "index out of bounds: the len is {} but the index is {}",
41 self.n,
42 i
43 );
44 assert_ne!(i, 0, "index out of bounds: the index is 0");
45 if i <= self.isqrtn {
46 self.data.len() - i as usize
47 } else {
48 (self.n / i) as usize - 1
49 }
50 }
51
52 pub fn from_fn(n: u64, f: impl FnMut(u64) -> T) -> Self {
53 let isqrtn = (n as f64).sqrt().floor() as u64;
54 let data = Self::index_iter(n, isqrtn).map(f).collect();
55 Self { n, isqrtn, data }
56 }
57
58 pub fn lucy_dp<G>(mut self, mut mul_p: impl FnMut(T, u64) -> T) -> Self
62 where
63 G: Group<T = T>,
64 {
65 with_prime_list(self.isqrtn, |pl| {
66 for &p in pl.primes_lte(self.isqrtn) {
67 let k = self.quotient_index(p - 1);
68 let p2 = p * p;
69 for (i, q) in Self::index_iter(self.n, self.isqrtn).enumerate() {
70 if q < p2 {
71 break;
72 }
73 let diff = mul_p(G::rinv_operate(&self[q / p], &self.data[k]), p);
74 G::rinv_operate_assign(&mut self.data[i], &diff);
75 }
76 }
77 });
78 self
79 }
80
81 pub fn min_25_sieve<R>(&self, mut f: impl FnMut(u64, u32) -> T) -> Self
83 where
84 T: Clone + One,
85 R: Ring<T = T, Additive: Invertible>,
86 {
87 let mut dp = self.clone();
88 with_prime_list(self.isqrtn, |pl| {
89 for &p in pl.primes_lte(self.isqrtn).iter().rev() {
90 let k = self.quotient_index(p);
91 for (i, q) in Self::index_iter(self.n, self.isqrtn).enumerate() {
92 let mut pc = p;
93 if pc * p > q {
94 break;
95 }
96 let mut c = 1;
97 while q / p >= pc {
98 let x = R::mul(&f(p, c), &(R::sub(&dp[q / pc], &self.data[k])));
99 let x = R::add(&x, &f(p, c + 1));
100 dp.data[i] = R::add(&dp.data[i], &x);
101 c += 1;
102 pc *= p;
103 }
104 }
105 }
106 });
107 for x in &mut dp.data {
108 *x = R::add(x, &T::one());
109 }
110 dp
111 }
112}
113
114impl<T> Index<u64> for QuotientArray<T> {
115 type Output = T;
116 fn index(&self, i: u64) -> &Self::Output {
117 unsafe { self.data.get_unchecked(self.quotient_index(i)) }
118 }
119}
120
121impl<T> IndexMut<u64> for QuotientArray<T> {
122 fn index_mut(&mut self, index: u64) -> &mut Self::Output {
123 let i = self.quotient_index(index);
124 unsafe { self.data.get_unchecked_mut(i) }
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AddMulOperation, AdditiveOperation, ArrayOperation},
        math::{PrimeList, PrimeTable},
        tools::Xorshift,
    };

    /// `lucy_dp` seeded with `i - 1` and an identity `mul_p` must reproduce
    /// the number of primes up to `n` reported by `PrimeList`.
    #[test]
    fn prime_count() {
        let mut rng = Xorshift::default();
        let pl = PrimeList::new(100_000);
        for round in 1..=100 {
            // The first ten rounds use n = 1..=10 directly; later rounds draw
            // n from the random generator.
            let n = if round <= 10 {
                round
            } else {
                rng.random(1..10_000)
            };
            let sieved = QuotientArray::from_fn(n, |i| i as i64 - 1)
                .lucy_dp::<AdditiveOperation<_>>(|x, _p| x);
            let expected = pl.primes_lte(n).len();
            assert_eq!(expected, sieved[n] as usize);
        }
    }

    /// `lucy_dp` followed by `min_25_sieve` must reproduce the sum of all
    /// divisors of `1..=n` obtained by enumerating divisors via `PrimeTable`.
    #[test]
    fn divisor_sum() {
        let mut rng = Xorshift::default();
        let pt = PrimeTable::new(10_000);
        for round in 1..=100 {
            let n = if round <= 10 {
                round
            } else {
                rng.random(1..10_000)
            };
            // Two lanes per slot, seeded with `i - 1` and `i (i + 1) / 2 - 1`;
            // after the sieve the lanes are added together.
            let prime_part =
                QuotientArray::from_fn(n, |i| [i as i64, i as i64 * (i as i64 + 1) / 2])
                    .map(|[x, y]| [x - 1, y - 1])
                    .lucy_dp::<ArrayOperation<AdditiveOperation<_>, 2>>(|[x, y], p| {
                        [x, y * p as i64]
                    })
                    .map(|[x, y]| x + y);
            // f(p, c) = 1 + p + p^2 + ... + p^c.
            let qa = prime_part.min_25_sieve::<AddMulOperation<_>>(|p, c| {
                let mut power = 1;
                let mut total = 1;
                for _ in 0..c {
                    power *= p as i64;
                    total += power;
                }
                total
            });
            let brute_force = (1..=n)
                .flat_map(|i| pt.divisors(i as _))
                .map(|d| d as u64)
                .sum::<u64>();
            assert_eq!(brute_force, qa[n] as u64);
        }
    }
}