competitive/num/
barrett_reduction.rs

1use super::{One, Zero};
2use std::ops::{Add, Mul, Sub};
3
/// Precomputed state for repeatedly dividing by one fixed modulus.
///
/// Holds the modulus together with an approximate reciprocal so that
/// quotient/remainder queries can be answered through
/// `Barrettable::barrett_reduce`.
#[derive(Clone, Copy, Debug)]
pub struct BarrettReduction<T> {
    /// The fixed modulus (divisor).
    m: T,
    /// Approximate reciprocal of `m`: either derived by
    /// `Barrettable::inv_mod_approx` or supplied by the caller.
    im: T,
}
9
10impl<T> BarrettReduction<T>
11where
12    T: Barrettable,
13{
14    pub fn new(m: T) -> Self {
15        Self {
16            m,
17            im: T::inv_mod_approx(m),
18        }
19    }
20    pub const fn new_with_im(m: T, im: T) -> Self {
21        Self { m, im }
22    }
23    pub const fn get_mod(&self) -> T {
24        self.m
25    }
26    pub fn div_rem(&self, a: T) -> (T, T) {
27        T::barrett_reduce(a, self.m, self.im)
28    }
29    pub fn div(&self, a: T) -> T {
30        self.div_rem(a).0
31    }
32    pub fn rem(&self, a: T) -> T {
33        self.div_rem(a).1
34    }
35}
36
37pub trait Barrettable:
38    Sized
39    + Copy
40    + PartialOrd
41    + Zero
42    + One
43    + Add<Output = Self>
44    + Sub<Output = Self>
45    + Mul<Output = Self>
46{
47    fn inv_mod_approx(m: Self) -> Self;
48    fn div_approx(self, im: Self) -> Self;
49    fn barrett_reduce(self, m: Self, im: Self) -> (Self, Self) {
50        if m == Self::one() {
51            return (self, Self::zero());
52        }
53        let q = self.div_approx(im);
54        let r = self - q * m;
55        if m <= r {
56            (q + Self::one(), r - m)
57        } else {
58            (q, r)
59        }
60    }
61}
62
63impl Barrettable for u32 {
64    fn inv_mod_approx(m: Self) -> Self {
65        !0 / m
66    }
67    fn div_approx(self, im: Self) -> Self {
68        ((self as u64 * im as u64) >> 32) as u32
69    }
70}
71
72impl Barrettable for u64 {
73    fn inv_mod_approx(m: Self) -> Self {
74        !0 / m
75    }
76    fn div_approx(self, im: Self) -> Self {
77        ((self as u128 * im as u128) >> 64) as u64
78    }
79}
80
81impl Barrettable for u128 {
82    fn inv_mod_approx(m: Self) -> Self {
83        !0 / m
84    }
85    fn div_approx(self, im: Self) -> Self {
86        const MASK64: u128 = 0xffff_ffff_ffff_ffff;
87        let au = self >> 64;
88        let ad = self & MASK64;
89        let imu = im >> 64;
90        let imd = im & MASK64;
91        let mut res = au * imu;
92        let x = (ad * imd) >> 64;
93        let (x, c) = x.overflowing_add(au * imd);
94        res += c as u128;
95        let (x, c) = x.overflowing_add(ad * imu);
96        res += c as u128;
97        res + (x >> 64)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Generates a `#[test]` named `$name` that draws `ROUNDS` samples of
    /// `(dividend, modulus)` from `$sample` (evaluated with the generator
    /// bound to `$rng`) and compares `BarrettReduction::div` / `rem` with
    /// the native `/` and `%` operators of `$int`.
    macro_rules! test_barrett {
        ($name:ident, $int:ty, |$rng:ident| $sample:expr) => {
            #[test]
            fn $name() {
                const ROUNDS: usize = 10_000;
                let mut $rng = Xorshift::default();
                for _round in 0..ROUNDS {
                    let (dividend, modulus): ($int, $int) = $sample;
                    let reducer = BarrettReduction::<$int>::new(modulus);
                    assert_eq!(dividend / modulus, reducer.div(dividend));
                    assert_eq!(dividend % modulus, reducer.rem(dividend));
                }
            }
        };
    }

    // Operands drawn from ranges capped at 100.
    test_barrett!(test_barrett_u32_small, u32, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u64_small, u64, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u128_small, u128, |rng| (
        rng.random(..=100u64) as u128 * rng.random(..=100u64) as u128,
        rng.random(1..=100u64) as u128 * rng.random(1..=100u64) as u128,
    ));

    // Operands drawn from ranges spanning the whole type.
    test_barrett!(test_barrett_u32_large, u32, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u64_large, u64, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u128_large, u128, |rng| (
        rng.random(..=!0u64) as u128 * rng.random(..=!0u64) as u128,
        rng.random(1..=!0u64) as u128 * rng.random(1..=!0u64) as u128,
    ));

    // Operands drawn from the top 101 values of the type (u128: products of
    // two such u64 draws).
    test_barrett!(test_barrett_u32_max, u32, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u64_max, u64, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u128_max, u128, |rng| (
        rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
        rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
    ));

    // 128-bit dividend (product of two u64 draws) against a 64-bit modulus.
    // NOTE(review): the modulus range `0u64..` looks like it admits zero,
    // which would make `dividend / modulus` panic - confirm against
    // `Xorshift::random`.
    test_barrett!(test_barrett_u128_mul, u128, |rng| (
        rng.random(0u64..) as u128 * rng.random(0u64..) as u128,
        rng.random(0u64..) as u128,
    ));
}