// competitive/num/barrett_reduction.rs

/// Precomputed divisor for fast division and remainder by a fixed modulus
/// using Barrett reduction.
///
/// `m` is the modulus and `im` is `T::MAX / m`. This is a fixed-point
/// approximation of `2^BITS / m`, which lets a division be replaced by a
/// widening multiply, a shift and at most one correction step. Build it with
/// `BarrettReduction::<T>::new`; inherent implementations are generated for
/// `u32`, `u64` and `u128`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BarrettReduction<T> {
    /// The modulus (divisor).
    m: T,
    /// `T::MAX / m`, precomputed at construction.
    im: T,
}
6macro_rules! impl_barrett {
7 ($basety:ty, |$a:ident, $im:ident| $quotient:expr) => {
8 impl BarrettReduction<$basety> {
9 pub const fn new(m: $basety) -> Self {
10 Self { m, im: !0 / m }
11 }
12 pub const fn get_mod(&self) -> $basety {
13 self.m
14 }
15 pub const fn div_rem(&self, $a: $basety) -> ($basety, $basety) {
16 if self.m == 1 {
17 return ($a, 0);
18 }
19 let $im = self.im;
20 let mut q = $quotient;
21 let mut r = $a - q * self.m;
22 if self.m <= r {
23 r -= self.m;
24 q += 1;
25 }
26 (q, r)
27 }
28 pub const fn div(&self, a: $basety) -> $basety {
29 self.div_rem(a).0
30 }
31 pub const fn rem(&self, a: $basety) -> $basety {
32 self.div_rem(a).1
33 }
34 }
35 };
36}
37impl_barrett!(u32, |a, im| ((a as u64 * im as u64) >> 32) as u32);
38impl_barrett!(u64, |a, im| ((a as u128 * im as u128) >> 64) as u64);
39impl_barrett!(u128, |a, im| {
40 const MASK64: u128 = 0xffff_ffff_ffff_ffff;
41 let au = a >> 64;
42 let ad = a & MASK64;
43 let imu = im >> 64;
44 let imd = im & MASK64;
45 let mut res = au * imu;
46 let x = (ad * imd) >> 64;
47 let (x, c) = x.overflowing_add(au * imd);
48 res += c as u128;
49 let (x, c) = x.overflowing_add(ad * imu);
50 res += c as u128;
51 res + (x >> 64)
52});

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Checks `div`/`rem` against native `/` and `%` on `Q` random `(a, b)`
    /// pairs produced by `$res` (which must never yield `b == 0`).
    macro_rules! test_barrett {
        ($test_name:ident, $ty:ty, |$rng:ident| $res:expr) => {
            #[test]
            fn $test_name() {
                let mut $rng = Xorshift::default();
                const Q: usize = 10_000;
                for _ in 0..Q {
                    let (a, b): ($ty, $ty) = $res;
                    let barrett = BarrettReduction::<$ty>::new(b);
                    assert_eq!(a / b, barrett.div(a));
                    assert_eq!(a % b, barrett.rem(a));
                }
            }
        };
    }
    test_barrett!(test_barrett_u32_small, u32, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u64_small, u64, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u128_small, u128, |rng| {
        (
            rng.random(..=100u64) as u128 * rng.random(..=100u64) as u128,
            rng.random(1..=100u64) as u128 * rng.random(1..=100u64) as u128,
        )
    });

    test_barrett!(test_barrett_u32_large, u32, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u64_large, u64, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u128_large, u128, |rng| {
        (
            rng.random(..=!0u64) as u128 * rng.random(..=!0u64) as u128,
            rng.random(1..=!0u64) as u128 * rng.random(1..=!0u64) as u128,
        )
    });

    test_barrett!(test_barrett_u32_max, u32, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u64_max, u64, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u128_max, u128, |rng| {
        (
            rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
            rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
        )
    });

    test_barrett!(test_barrett_u128_mul, u128, |rng| {
        (
            rng.random(0u64..) as u128 * rng.random(0u64..) as u128,
            // The divisor must be non-zero: `new(0)` divides by zero.
            rng.random(1u64..) as u128,
        )
    });

    /// Deterministically checks every pair drawn from a set of boundary values
    /// (tiny divisors, half-width boundaries, values near MAX) against native
    /// `/` and `%`. Random inputs almost never hit cases such as
    /// `u128::MAX / 2`, which exercise the carry paths of the wide multiply.
    macro_rules! test_barrett_edge {
        ($test_name:ident, $ty:ty) => {
            #[test]
            fn $test_name() {
                let max = <$ty>::MAX;
                let half = <$ty>::BITS / 2;
                let values: [$ty; 10] = [
                    0,
                    1,
                    2,
                    3,
                    max >> half,
                    (max >> half) + 1,
                    max / 3,
                    max / 2,
                    max - 1,
                    max,
                ];
                for &b in values.iter().filter(|&&b| b != 0) {
                    let barrett = BarrettReduction::<$ty>::new(b);
                    assert_eq!(barrett.get_mod(), b);
                    for &a in &values {
                        assert_eq!(barrett.div_rem(a), (a / b, a % b), "a = {}, b = {}", a, b);
                    }
                }
            }
        };
    }
    test_barrett_edge!(test_barrett_u32_edge, u32);
    test_barrett_edge!(test_barrett_u64_edge, u64);
    test_barrett_edge!(test_barrett_u128_edge, u128);
}
125}