competitive/num/
barrett_reduction.rs

/// Precomputed data for repeated division by one fixed divisor.
///
/// The divisor `m` is stored together with `im = T::MAX / m`, a fixed-point
/// approximation of `2^BITS / m`. With it, a quotient can be estimated using a
/// widening multiplication and a shift instead of a hardware division.
#[derive(Debug, Clone, Copy)]
pub struct BarrettReduction<T> {
    /// The divisor (modulus).
    m: T,
    /// `T::MAX / m`, the precomputed reciprocal estimate.
    im: T,
}

// Generates the inherent impl of `BarrettReduction` for one unsigned integer
// type. The second argument is a closure-like expression `|a, im| expr` that
// must evaluate to the high half of the double-width product `a * im`.
macro_rules! impl_barrett {
    ($basety:ty, |$a:ident, $im:ident| $quotient:expr) => {
        impl BarrettReduction<$basety> {
            /// Precomputes the reciprocal estimate for the divisor `m`.
            ///
            /// # Panics
            ///
            /// Panics with a division-by-zero error when `m == 0`.
            pub const fn new(m: $basety) -> Self {
                Self { m, im: !0 / m }
            }

            /// Returns the divisor this instance was built for.
            pub const fn get_mod(&self) -> $basety {
                self.m
            }

            /// Returns `(a / m, a % m)`.
            pub const fn div_rem(&self, $a: $basety) -> ($basety, $basety) {
                // Division by one is answered directly.
                if self.m == 1 {
                    return ($a, 0);
                }
                let $im = self.im;
                // `im <= 2^BITS / m` and `im > 2^BITS / m - 1` together with
                // `a < 2^BITS` mean the estimate never exceeds the true
                // quotient and falls short of it by at most one.
                let mut q = $quotient;
                let mut r = $a - q * self.m;
                // A single correction step therefore suffices.
                if self.m <= r {
                    r -= self.m;
                    q += 1;
                }
                (q, r)
            }

            /// Returns `a / m`.
            pub const fn div(&self, a: $basety) -> $basety {
                self.div_rem(a).0
            }

            /// Returns `a % m`.
            pub const fn rem(&self, a: $basety) -> $basety {
                self.div_rem(a).1
            }
        }
    };
}

// For 32- and 64-bit operands the next wider native integer holds the full
// product, so the high half is a plain shift.
impl_barrett!(u32, |a, im| ((a as u64 * im as u64) >> 32) as u32);
impl_barrett!(u64, |a, im| ((a as u128 * im as u128) >> 64) as u64);

// There is no native 256-bit integer, so the high 128 bits of `a * im` are
// assembled from four 64x64 -> 128 bit partial products (schoolbook method).
impl_barrett!(u128, |a, im| {
    const MASK64: u128 = 0xffff_ffff_ffff_ffff;
    let (a_hi, a_lo) = (a >> 64, a & MASK64);
    let (im_hi, im_lo) = (im >> 64, im & MASK64);

    // `mid` accumulates the column weighted by 2^64: the two cross products
    // plus the upper half of the low product.
    let low_carry = (a_lo * im_lo) >> 64;
    let (mid, c1) = low_carry.overflowing_add(a_hi * im_lo);
    let (mid, c2) = mid.overflowing_add(a_lo * im_hi);

    // An overflow of `mid` stands for 2^128 * 2^64 = 2^192 in the full
    // product, which is 2^64 (not 1) in the high 128 bits.
    let carries = (c1 as u128 + c2 as u128) << 64;

    // The true high half is below 2^128, so this sum cannot overflow.
    a_hi * im_hi + (mid >> 64) + carries
});
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    // Generates a randomized test that draws `(a, b)` pairs from the given
    // expression and checks Barrett `div`/`rem` against the native `/` and `%`.
    // The divisor `b` must never be zero: both the native operators and
    // `BarrettReduction::new` would panic.
    macro_rules! test_barrett {
        ($test_name:ident, $ty:ty, |$rng:ident| $res:expr) => {
            #[test]
            fn $test_name() {
                let mut $rng = Xorshift::default();
                const Q: usize = 10_000;
                for _ in 0..Q {
                    let (a, b): ($ty, $ty) = $res;
                    let barrett = BarrettReduction::<$ty>::new(b);
                    assert_eq!(a / b, barrett.div(a));
                    assert_eq!(a % b, barrett.rem(a));
                }
            }
        };
    }

    // Small dividends and divisors.
    test_barrett!(test_barrett_u32_small, u32, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u64_small, u64, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u128_small, u128, |rng| {
        (
            rng.random(..=100u64) as u128 * rng.random(..=100u64) as u128,
            rng.random(1..=100u64) as u128 * rng.random(1..=100u64) as u128,
        )
    });

    // Operands spread over the whole value range.
    test_barrett!(test_barrett_u32_large, u32, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u64_large, u64, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u128_large, u128, |rng| {
        (
            rng.random(..=!0u64) as u128 * rng.random(..=!0u64) as u128,
            rng.random(1..=!0u64) as u128 * rng.random(1..=!0u64) as u128,
        )
    });

    // Operands close to the maximum representable value.
    test_barrett!(test_barrett_u32_max, u32, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u64_max, u64, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u128_max, u128, |rng| {
        (
            rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
            rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
        )
    });

    // 128-bit dividend against a 64-bit divisor. The divisor range starts at
    // 1 so that a zero divisor can never be drawn.
    test_barrett!(test_barrett_u128_mul, u128, |rng| {
        (
            rng.random(0u64..) as u128 * rng.random(0u64..) as u128,
            rng.random(1u64..) as u128,
        )
    });
}