competitive/num/
barrett_reduction.rs1use super::{One, Zero};
2use std::ops::{Add, Mul, Sub};
3
/// Precomputed state for repeatedly dividing by one fixed modulus.
///
/// Stores the modulus alongside its approximate inverse so that
/// [`Barrettable::barrett_reduce`] can be applied without recomputing it.
#[derive(Clone, Copy, Debug)]
pub struct BarrettReduction<T> {
    /// The modulus (divisor).
    m: T,
    /// Approximate inverse of `m`; `new` fills it from `Barrettable::inv_mod_approx`.
    im: T,
}

10impl<T> BarrettReduction<T>
11where
12 T: Barrettable,
13{
14 pub fn new(m: T) -> Self {
15 Self {
16 m,
17 im: T::inv_mod_approx(m),
18 }
19 }
20 pub const fn new_with_im(m: T, im: T) -> Self {
21 Self { m, im }
22 }
23 pub const fn get_mod(&self) -> T {
24 self.m
25 }
26 pub fn div_rem(&self, a: T) -> (T, T) {
27 T::barrett_reduce(a, self.m, self.im)
28 }
29 pub fn div(&self, a: T) -> T {
30 self.div_rem(a).0
31 }
32 pub fn rem(&self, a: T) -> T {
33 self.div_rem(a).1
34 }
35}
36
37pub trait Barrettable:
38 Sized
39 + Copy
40 + PartialOrd
41 + Zero
42 + One
43 + Add<Output = Self>
44 + Sub<Output = Self>
45 + Mul<Output = Self>
46{
47 fn inv_mod_approx(m: Self) -> Self;
48 fn div_approx(self, im: Self) -> Self;
49 fn barrett_reduce(self, m: Self, im: Self) -> (Self, Self) {
50 if m == Self::one() {
51 return (self, Self::zero());
52 }
53 let q = self.div_approx(im);
54 let r = self - q * m;
55 if m <= r {
56 (q + Self::one(), r - m)
57 } else {
58 (q, r)
59 }
60 }
61}
62
63impl Barrettable for u32 {
64 fn inv_mod_approx(m: Self) -> Self {
65 !0 / m
66 }
67 fn div_approx(self, im: Self) -> Self {
68 ((self as u64 * im as u64) >> 32) as u32
69 }
70}
71
72impl Barrettable for u64 {
73 fn inv_mod_approx(m: Self) -> Self {
74 !0 / m
75 }
76 fn div_approx(self, im: Self) -> Self {
77 ((self as u128 * im as u128) >> 64) as u64
78 }
79}
80
81impl Barrettable for u128 {
82 fn inv_mod_approx(m: Self) -> Self {
83 !0 / m
84 }
85 fn div_approx(self, im: Self) -> Self {
86 const MASK64: u128 = 0xffff_ffff_ffff_ffff;
87 let au = self >> 64;
88 let ad = self & MASK64;
89 let imu = im >> 64;
90 let imd = im & MASK64;
91 let mut res = au * imu;
92 let x = (ad * imd) >> 64;
93 let (x, c) = x.overflowing_add(au * imd);
94 res += c as u128;
95 let (x, c) = x.overflowing_add(ad * imu);
96 res += c as u128;
97 res + (x >> 64)
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    // Defines a randomized test comparing `BarrettReduction::<$ty>` with the
    // built-in `/` and `%` on `Q` pairs `(a, b)` produced by `$res`.
    // `$res` must never yield `b == 0`: both `new` and the reference `/` panic.
    macro_rules! test_barrett {
        ($test_name:ident, $ty:ty, |$rng:ident| $res:expr) => {
            #[test]
            fn $test_name() {
                let mut $rng = Xorshift::default();
                const Q: usize = 10_000;
                for _ in 0..Q {
                    let (a, b): ($ty, $ty) = $res;
                    let barrett = BarrettReduction::<$ty>::new(b);
                    assert_eq!(a / b, barrett.div(a));
                    assert_eq!(a % b, barrett.rem(a));
                }
            }
        };
    }
    test_barrett!(test_barrett_u32_small, u32, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u64_small, u64, |rng| (
        rng.random(..=100),
        rng.random(1..=100)
    ));
    test_barrett!(test_barrett_u128_small, u128, |rng| {
        (
            rng.random(..=100u64) as u128 * rng.random(..=100u64) as u128,
            rng.random(1..=100u64) as u128 * rng.random(1..=100u64) as u128,
        )
    });

    test_barrett!(test_barrett_u32_large, u32, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u64_large, u64, |rng| (
        rng.random(..=!0),
        rng.random(1..=!0)
    ));
    test_barrett!(test_barrett_u128_large, u128, |rng| {
        (
            rng.random(..=!0u64) as u128 * rng.random(..=!0u64) as u128,
            rng.random(1..=!0u64) as u128 * rng.random(1..=!0u64) as u128,
        )
    });

    test_barrett!(test_barrett_u32_max, u32, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u64_max, u64, |rng| (
        rng.random(!0 - 100..=!0),
        rng.random(!0 - 100..=!0)
    ));
    test_barrett!(test_barrett_u128_max, u128, |rng| {
        (
            rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
            rng.random(!0 - 100..=!0u64) as u128 * rng.random(!0 - 100..=!0u64) as u128,
        )
    });

    // Divisor range starts at 1 (was `0u64..`) so `b` can never be zero.
    test_barrett!(test_barrett_u128_mul, u128, |rng| {
        (
            rng.random(0u64..) as u128 * rng.random(0u64..) as u128,
            rng.random(1u64..) as u128,
        )
    });

    // Deterministic boundaries: m == 1 (special-cased), a power of two (where
    // `!0 / m` is one below `2^64 / m`), and the maximum modulus.
    #[test]
    fn test_barrett_u64_edge_cases() {
        for m in [1u64, 2, 3, 1 << 32, (1 << 63) + 1, !0] {
            let barrett = BarrettReduction::<u64>::new(m);
            for a in [0, 1, m - 1, m, !0 - 1, !0] {
                assert_eq!(a / m, barrett.div(a));
                assert_eq!(a % m, barrett.rem(a));
            }
        }
    }
}
172}