competitive/math/
garner.rs

1use super::Unsigned;
2
3/// Garner's algorithm with precomputation for fixed moduli.
4pub struct Garner<T>
5where
6    T: Unsigned,
7{
8    moduli: Vec<T>,
9    coeff: Vec<T>,
10    inv: Vec<T>,
11}
12
13impl<T> Garner<T>
14where
15    T: Unsigned,
16{
17    pub fn new<M>(moduli: M, modulo: T) -> Option<Self>
18    where
19        M: IntoIterator<Item = T>,
20    {
21        if modulo == T::zero() {
22            return None;
23        }
24        let moduli: Vec<_> = moduli.into_iter().collect();
25        if moduli.iter().any(|&m| m.is_zero()) {
26            return None;
27        }
28        let n = moduli.len();
29        for i in 0..n {
30            for j in 0..i {
31                if moduli[i].gcd(moduli[j]) != T::one() {
32                    return None;
33                }
34            }
35        }
36        Some(Self::new_unchecked(moduli, modulo))
37    }
38
39    pub fn new_unchecked<M>(moduli: M, modulo: T) -> Self
40    where
41        M: IntoIterator<Item = T>,
42    {
43        let mut moduli: Vec<_> = moduli.into_iter().collect();
44        let n = moduli.len();
45        moduli.push(modulo);
46        let coeff_len = n * (n + 1) / 2;
47        let mut coeff = Vec::with_capacity(coeff_len);
48        let mut inv = Vec::with_capacity(n);
49        let mut prefix = vec![T::one(); moduli.len()];
50        for i in 0..n {
51            let modulus = moduli[i];
52            inv.push(prefix[i].mod_inv(modulus));
53            for j in i + 1..=n {
54                coeff.push(prefix[j]);
55                prefix[j] = prefix[j].mod_mul(modulus, moduli[j]);
56            }
57        }
58        Self { moduli, coeff, inv }
59    }
60
61    pub fn solve<B, I>(&self, residues: B) -> Option<T>
62    where
63        B: IntoIterator<Item = T, IntoIter = I>,
64        I: ExactSizeIterator<Item = T>,
65    {
66        let residues = residues.into_iter();
67        if residues.len() != self.inv.len() {
68            return None;
69        }
70        let n = residues.len();
71        let mut constants = vec![T::zero(); n + 1];
72        let mut start = 0;
73        for (((i, residue), &modulus), &inv) in residues
74            .into_iter()
75            .enumerate()
76            .zip(&self.moduli)
77            .zip(&self.inv)
78        {
79            debug_assert!(residue < modulus);
80            let t = residue.mod_sub(constants[i], modulus).mod_mul(inv, modulus);
81            let coeff = &self.coeff[start..start + n - i];
82            start += n - i;
83            for ((constant, &modulus), &coeff) in constants
84                .iter_mut()
85                .zip(&self.moduli)
86                .skip(i + 1)
87                .zip(coeff)
88            {
89                *constant = coeff.mod_mul(t, modulus).mod_add(*constant, modulus);
90            }
91        }
92        Some(constants[n])
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{math::solve_simultaneous_linear_congruence, tools::Xorshift};

    /// Randomized check: for random subsets of small primes and random output
    /// moduli, `Garner::solve` must agree with the general congruence solver
    /// after reducing the latter's answer by the output modulus.
    #[test]
    fn test_garner() {
        const PRIMES: [u64; 13] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41];

        let mut rng = Xorshift::default();
        for _ in 0..200 {
            // Pick a random non-empty subset of the primes as the moduli.
            let mut pool = PRIMES;
            rng.shuffle(&mut pool);
            let len = rng.random(1..=pool.len());
            let moduli = &pool[..len];
            let product: u64 = moduli.iter().product();

            // Collected up front because the generator is needed again below.
            let final_mods: Vec<_> = rng.random_iter(2..=product).take(10).collect();
            for final_mod in final_mods {
                let solver = Garner::new(moduli.iter().copied(), final_mod).unwrap();
                for _ in 0..10 {
                    let residues: Vec<_> = moduli.iter().map(|&m| rng.random(0..m)).collect();
                    let value = solver.solve(residues.iter().copied()).unwrap();

                    // Reference answer from x ≡ b (mod m), written as 1 * x ≡ b.
                    let congruences = residues
                        .iter()
                        .zip(moduli)
                        .map(|(&b, &m)| (1u64, b, m));
                    let (expected, _) =
                        solve_simultaneous_linear_congruence(congruences).unwrap();

                    assert_eq!(value, expected % final_mod);
                }
            }
        }
    }
}