1use super::Xorshift;
2use std::{
3 marker::PhantomData,
4 mem::swap,
5 ops::{Bound, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
6};
7
8pub trait RandomSpec<T>: Sized {
10 fn rand(&self, rng: &mut Xorshift) -> T;
12 fn rand_iter(self, rng: &mut Xorshift) -> RandIter<'_, T, Self> {
14 RandIter {
15 spec: self,
16 rng,
17 _marker: PhantomData,
18 }
19 }
20}
21
22impl Xorshift {
23 pub fn random<T, R>(&mut self, spec: R) -> T
24 where
25 R: RandomSpec<T>,
26 {
27 spec.rand(self)
28 }
29 pub fn random_iter<T, R>(&mut self, spec: R) -> RandIter<'_, T, R>
30 where
31 R: RandomSpec<T>,
32 {
33 spec.rand_iter(self)
34 }
35}
36
37#[derive(Debug)]
38pub struct RandIter<'r, T, R>
39where
40 R: RandomSpec<T>,
41{
42 spec: R,
43 rng: &'r mut Xorshift,
44 _marker: PhantomData<fn() -> T>,
45}
46
47impl<T, R> Iterator for RandIter<'_, T, R>
48where
49 R: RandomSpec<T>,
50{
51 type Item = T;
52 fn next(&mut self) -> Option<Self::Item> {
53 Some(self.spec.rand(self.rng))
54 }
55}
56
57macro_rules! impl_random_spec_range_full {
58 ($($t:ty)*) => {
59 $(impl RandomSpec<$t> for RangeFull {
60 fn rand(&self, rng: &mut Xorshift) -> $t {
61 rng.rand64() as _
62 }
63 })*
64 };
65}
66impl_random_spec_range_full!(u8 u16 u32 u64 usize i8 i16 i32 i64 isize);
67
68impl RandomSpec<u128> for RangeFull {
69 fn rand(&self, rng: &mut Xorshift) -> u128 {
70 ((rng.rand64() as u128) << 64) | rng.rand64() as u128
71 }
72}
73impl RandomSpec<i128> for RangeFull {
74 fn rand(&self, rng: &mut Xorshift) -> i128 {
75 rng.random::<u128, _>(..) as i128
76 }
77}
78
79macro_rules! impl_random_spec_ranges {
80 ($($u:ident $i:ident)*) => {
81 $(
82 impl RandomSpec<$u> for Range<$u> {
83 fn rand(&self, rng: &mut Xorshift) -> $u {
84 assert!(self.start < self.end);
85 let len = self.end - self.start;
86 (self.start + rng.random::<$u, _>(..) % len)
87 }
88 }
89 impl RandomSpec<$i> for Range<$i> {
90 fn rand(&self, rng: &mut Xorshift) -> $i {
91 assert!(self.start < self.end);
92 let len = self.end.abs_diff(self.start);
93 self.start.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
94 }
95 }
96 impl RandomSpec<$u> for RangeFrom<$u> {
97 fn rand(&self, rng: &mut Xorshift) -> $u {
98 let len = ($u::MAX - self.start).wrapping_add(1);
99 let x = rng.random::<$u, _>(..);
100 self.start + if len != 0 { x % len } else { x }
101 }
102 }
103 impl RandomSpec<$i> for RangeFrom<$i> {
104 fn rand(&self, rng: &mut Xorshift) -> $i {
105 let len = ($i::MAX.abs_diff(self.start)).wrapping_add(1);
106 let x = rng.random::<$u, _>(..);
107 self.start.wrapping_add_unsigned(if len != 0 { x % len } else { x })
108 }
109 }
110 impl RandomSpec<$u> for RangeInclusive<$u> {
111 fn rand(&self, rng: &mut Xorshift) -> $u {
112 assert!(self.start() <= self.end());
113 let len = (self.end() - self.start()).wrapping_add(1);
114 let x = rng.random::<$u, _>(..);
115 self.start() + if len != 0 { x % len } else { x }
116 }
117 }
118 impl RandomSpec<$i> for RangeInclusive<$i> {
119 fn rand(&self, rng: &mut Xorshift) -> $i {
120 assert!(self.start() <= self.end());
121 let len = (self.end().abs_diff(*self.start())).wrapping_add(1);
122 let x = rng.random::<$u, _>(..);
123 self.start().wrapping_add_unsigned(if len != 0 { x % len } else { x })
124 }
125 }
126 impl RandomSpec<$u> for RangeTo<$u> {
127 fn rand(&self, rng: &mut Xorshift) -> $u {
128 let len = self.end;
129 rng.random::<$u, _>(..) % len
130 }
131 }
132 impl RandomSpec<$i> for RangeTo<$i> {
133 fn rand(&self, rng: &mut Xorshift) -> $i {
134 let len = self.end.abs_diff($i::MIN);
135 $i::MIN.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
136 }
137 }
138 impl RandomSpec<$u> for RangeToInclusive<$u> {
139 fn rand(&self, rng: &mut Xorshift) -> $u {
140 let len = (self.end).wrapping_add(1);
141 let x = rng.random::<$u, _>(..);
142 if len != 0 { x % len } else { x }
143 }
144 }
145 impl RandomSpec<$i> for RangeToInclusive<$i> {
146 fn rand(&self, rng: &mut Xorshift) -> $i {
147 let len = (self.end.abs_diff($i::MIN)).wrapping_add(1);
148 let x = rng.random::<$u, _>(..);
149 $i::MIN.wrapping_add_unsigned(if len != 0 { x % len } else { x })
150 }
151 }
152 )*
153 };
154}
155impl_random_spec_ranges!(u8 i8 u16 i16 u32 i32 u64 i64 u128 i128 usize isize);
156
157macro_rules! impl_random_spec_tuple {
158 ($($T:ident)*, $($R:ident)*, $($v:ident)*) => {
159 impl<$($T),*, $($R),*> RandomSpec<($($T,)*)> for ($($R,)*)
160 where
161 $($R: RandomSpec<$T>),*
162 {
163 fn rand(&self, rng: &mut Xorshift) -> ($($T,)*) {
164 let ($($v,)*) = self;
165 ($(($v).rand(rng),)*)
166 }
167 }
168 };
169}
170impl_random_spec_tuple!(A, RA, a);
171impl_random_spec_tuple!(A B, RA RB, a b);
172impl_random_spec_tuple!(A B C, RA RB RC, a b c);
173impl_random_spec_tuple!(A B C D, RA RB RC RD, a b c d);
174impl_random_spec_tuple!(A B C D E, RA RB RC RD RE, a b c d e);
175impl_random_spec_tuple!(A B C D E F, RA RB RC RD RE RF, a b c d e f);
176impl_random_spec_tuple!(A B C D E F G, RA RB RC RD RE RF RG, a b c d e f g);
177impl_random_spec_tuple!(A B C D E F G H, RA RB RC RD RE RF RG RH, a b c d e f g h);
178impl_random_spec_tuple!(A B C D E F G H I, RA RB RC RD RE RF RG RH RI, a b c d e f g h i);
179impl_random_spec_tuple!(A B C D E F G H I J, RA RB RC RD RE RF RG RH RI RJ, a b c d e f g h i j);
180
181macro_rules! impl_random_spec_primitive {
182 ($($t:ty)*) => {
183 $(impl RandomSpec<$t> for $t {
184 fn rand(&self, _rng: &mut Xorshift) -> $t {
185 *self
186 }
187 })*
188 };
189}
190impl_random_spec_primitive!(() u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize bool char);
191
192impl<T, R> RandomSpec<T> for &R
193where
194 R: RandomSpec<T>,
195{
196 fn rand(&self, rng: &mut Xorshift) -> T {
197 <R as RandomSpec<T>>::rand(self, rng)
198 }
199}
200impl<T, R> RandomSpec<T> for &mut R
201where
202 R: RandomSpec<T>,
203{
204 fn rand(&self, rng: &mut Xorshift) -> T {
205 <R as RandomSpec<T>>::rand(self, rng)
206 }
207}
208
209#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
210pub struct NotEmptySegment<T>(pub T);
212impl<T> RandomSpec<(usize, usize)> for NotEmptySegment<T>
213where
214 T: RandomSpec<usize>,
215{
216 fn rand(&self, rng: &mut Xorshift) -> (usize, usize) {
217 let n = rng.random(&self.0) as u64;
218 let k = randint_uniform(rng, n);
219 let l = randint_uniform(rng, n - k) as usize;
220 (l, l + k as usize + 1)
221 }
222}
223
/// Spec that draws a random pair of [`Bound`]s whose endpoints come from `data`.
///
/// `T` is the endpoint type. It is carried only as a marker, because `Q` may implement
/// `RandomSpec` for several endpoint types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RandRange<Q, T> {
    data: Q,
    _marker: PhantomData<fn() -> T>,
}

impl<Q, T> RandRange<Q, T> {
    /// Wraps `data`, the spec each endpoint is drawn from.
    pub fn new(data: Q) -> Self {
        let _marker = PhantomData;
        Self { data, _marker }
    }
}
237impl<Q, T> RandomSpec<(Bound<T>, Bound<T>)> for RandRange<Q, T>
238where
239 Q: RandomSpec<T>,
240 T: Ord,
241{
242 fn rand(&self, rng: &mut Xorshift) -> (Bound<T>, Bound<T>) {
243 let mut l = rng.random(&self.data);
244 let mut r = rng.random(&self.data);
245 if l > r {
246 swap(&mut l, &mut r);
247 }
248 (
249 match rng.rand(3) {
250 0 => Bound::Excluded(l),
251 1 => Bound::Included(l),
252 _ => Bound::Unbounded,
253 },
254 match rng.rand(3) {
255 0 => Bound::Excluded(r),
256 1 => Bound::Included(r),
257 _ => Bound::Unbounded,
258 },
259 )
260 }
261}
262
263#[inline]
264fn randint_uniform(rng: &mut Xorshift, k: u64) -> u64 {
265 let mut v = rng.rand64();
266 if k > 0 {
267 v %= k;
268 }
269 v
270}
271
/// Samples indices `0..n` with probability proportional to the given weights in O(1)
/// per draw, using Vose's alias method.
///
/// Column `i` keeps its own index with probability `prob[i]` and otherwise redirects
/// to `alias[i]`.
pub struct WeightedSampler {
    n: usize,
    prob: Vec<f64>,
    alias: Vec<usize>,
}

impl WeightedSampler {
    /// Builds the alias table for `weights`.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `weights` is empty.
    /// - The weights do not sum to a positive value.
    /// - Any weight is negative.
    pub fn new(weights: impl IntoIterator<Item = f64>) -> Self {
        let mut scaled: Vec<f64> = weights.into_iter().collect();
        let n = scaled.len();
        assert!(n > 0, "weights must be non-empty");
        let total: f64 = scaled.iter().sum();
        assert!(total > 0.0, "sum of weights must be positive");

        // Rescale so the average weight is 1.0, i.e. exactly one column's worth.
        let factor = n as f64 / total;
        for w in scaled.iter_mut() {
            assert!(*w >= 0.0, "weights must be non-negative");
            *w *= factor;
        }

        // Columns below 1.0 need topping up from an alias; the others can donate.
        let (mut small, mut large): (Vec<usize>, Vec<usize>) =
            (0..n).partition(|&i| scaled[i] < 1.0);
        let mut prob = vec![0.0; n];
        let mut alias = vec![0; n];

        loop {
            match (small.pop(), large.pop()) {
                (None, None) => break,
                (Some(s), Some(g)) => {
                    prob[s] = scaled[s];
                    alias[s] = g;
                    // `g` donates exactly the mass that `s` is missing.
                    scaled[g] -= 1.0 - scaled[s];
                    let bucket = if scaled[g] < 1.0 { &mut small } else { &mut large };
                    bucket.push(g);
                }
                // Only one side is left; its columns are (up to rounding) full,
                // so each keeps itself.
                (Some(only), None) | (None, Some(only)) => {
                    prob[only] = 1.0;
                    alias[only] = only;
                }
            }
        }

        Self { n, prob, alias }
    }
}
320
321impl RandomSpec<usize> for WeightedSampler {
322 fn rand(&self, rng: &mut Xorshift) -> usize {
323 let i = rng.rand(self.n as u64) as usize;
324 if rng.randf() < self.prob[i] {
325 i
326 } else {
327 self.alias[i]
328 }
329 }
330}
331
/// Draws a random value whose shape is described by a small spec language.
///
/// - `rand_value!(rng, spec)` expands to `(rng).random(spec)` for any `RandomSpec` `spec`.
/// - `(s1, s2, ...)` draws a tuple, one element per spec, left to right.
/// - `[spec; n]` collects `n` draws into a `Vec`.
/// - `[spec; const N]` builds the result with this crate's `array!` macro
///   (presumably a fixed-size array — confirm against `array!`).
/// - `[spec]` with no length yields the endless iterator `std::iter::repeat_with(...)`.
/// - `rand_value!(seed = s, spec)` first creates a local generator with
///   `Xorshift::new_with_seed(s)`.
///   NOTE(review): that path is not `$crate`-qualified, so `Xorshift` must be in
///   scope at the call site.
///
/// Specs nest, e.g. `[(..10, [0..5; 3]); 4]`. At top level exactly one spec is
/// expected; anything else ends in a `compile_error!` listing the accumulated pieces.
#[macro_export]
macro_rules! rand_value {
    // Internal: `n` draws collected into a Vec, or an endless iterator when no length is given.
    (@repeat $rng:expr, [$($t:tt)*] $($len:expr)?) => { ::std::iter::repeat_with(|| $crate::rand_value!(@inner $rng, [] $($t)*)) $(.take($len).collect::<Vec<_>>())? };
    // Internal: `const N` draws via the crate's `array!` macro.
    (@array $rng:expr, [$($t:tt)*] $len:expr) => { $crate::array![|| $crate::rand_value!(@inner $rng, [] $($t)*); $len] };
    // Internal: all specs of a tuple consumed; each accumulated `[...]` becomes one element.
    // Listed before the single-value terminal below so a 1-tuple still emits `(x,)`.
    (@tuple $rng:expr, [$([$($args:tt)*])*]) => { ($($($args)*,)*) };
    // Terminal: exactly one accumulated value and no input left.
    (@$tag:ident $rng:expr, [[$($args:tt)*]]) => { $($args)* };
    // Parenthesised and bracketed specs are matched here, before the generic `expr` arms,
    // so nested `[...]` specs inside them are expanded by this macro rather than parsed
    // as ordinary tuple/array expressions.
    (@$tag:ident $rng:expr, [$($args:tt)*] ($($tuple:tt)*) $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@tuple $rng, [] $($tuple)*)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [[$($tt:tt)*]; const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [[$($tt)*]] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [[$($tt:tt)*]; $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [[$($tt)*]] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [($($tt:tt)*); const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [($($tt)*)] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [($($tt:tt)*); $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [($($tt)*)] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [$ty:expr; const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [$ty] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [$ty:expr; $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [$ty] $len)]] $($t)*) };
    // `[spec]` without a length: endless iterator.
    (@$tag:ident $rng:expr, [$($args:tt)*] [$($tt:tt)*] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [$($tt)*])]] $($t)*) };
    // Plain spec expression (last one, or followed by a comma).
    (@$tag:ident $rng:expr, [$($args:tt)*] $ty:expr) => { $crate::rand_value!(@$tag $rng, [$($args)* [($rng).random($ty)]]) };
    (@$tag:ident $rng:expr, [$($args:tt)*] $ty:expr, $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [($rng).random($ty)]] $($t)*) };
    // Skips the comma left behind after a parenthesised/bracketed spec.
    (@$tag:ident $rng:expr, [$($args:tt)*] , $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)*] $($t)*) };
    // Input exhausted without reducing to a single value (e.g. two top-level specs).
    (@$tag:ident $rng:expr, [$($args:tt)*]) => { ::std::compile_error!(::std::stringify!($($args)*)) };
    // Must precede the `$rng:expr` arm: `seed = x` would otherwise parse as an assignment expression.
    (seed = $src:expr, $($t:tt)*) => { { let mut __rng = Xorshift::new_with_seed($src); $crate::rand_value!(@inner __rng, [] $($t)*) } };
    ($rng:expr, $($t:tt)*) => { $crate::rand_value!(@inner $rng, [] $($t)*) }
}
/// Declares `let` bindings initialised from random specs.
///
/// `rand!(rng, pat1: spec1, pat2: spec2, ...)` expands, in order, to one
/// `let pat = rand_value!(rng, spec);` per pair, so specs use the `rand_value!` syntax.
/// Patterns are collected token by token from identifiers, `::`, `&` and delimited
/// groups `(...)`, `[...]`, `{...}`. The collected tokens are then checked to form a `pat`.
///
/// The `rng` expression is substituted into every generated binding, so pass a place
/// such as a variable rather than an expression with side effects.
#[macro_export]
macro_rules! rand {
    // Pattern validation: accept anything that parses as `pat`, otherwise emit a readable error.
    (@assert $p:pat) => {};
    (@assert $($p:tt)*) => { ::std::compile_error!(::std::concat!("expected pattern, found `", ::std::stringify!($($p)*), "`")); };
    // Done: no pattern in progress and no input left.
    (@pat $rng:expr, [] []) => {};
    // Skip the comma separating two bindings.
    (@pat $rng:expr, [] [] , $($t:tt)*) => { $crate::rand!(@pat $rng, [] [] $($t)*) };
    // Accumulate pattern tokens until the `:` that introduces the spec.
    (@pat $rng:expr, [$($p:tt)*] [] $x:ident $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* $x] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] :: $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* ::] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] & $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* &] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] ($($x:tt)*) $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* ($($x)*)] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] [$($x:tt)*] $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* [$($x)*]] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] {$($x:tt)*} $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* {$($x)*}] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] : $($t:tt)*) => { $crate::rand!(@ty $rng, [$($p)*] [] $($t)*) };
    // Spec collection: delimited groups are forwarded as one token tree so `rand_value!`
    // can expand tuple/array syntax itself.
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] ($($x:tt)*) $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* ($($x)*)] $($t)*) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] [$($x:tt)*] $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* [$($x)*]] $($t)*) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:expr) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e]) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:expr, $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e], $($t)*) };
    // Single-token fallbacks, presumably for specs that cannot start an `expr`.
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:tt) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e]) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:tt, $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e], $($t)*) };
    // Emit one binding, then continue with the remaining `pat: spec` pairs.
    (@let $rng:expr, [$($p:tt)*] [$($tt:tt)*] $($t:tt)*) => {
        $crate::rand!{@assert $($p)*}
        let $($p)* = $crate::rand_value!($rng, $($tt)*);
        $crate::rand!(@pat $rng, [] [] $($t)*)
    };
    // Public entry points.
    ($rng:expr) => {};
    ($rng:expr, $($t:tt)*) => { $crate::rand!(@pat $rng, [] [] $($t)*) };
}
382
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_random_range() {
        let mut rng = Xorshift::default();
        // Single-value ranges must always yield their only element.
        assert_eq!(rng.random(1i32..2), 1);
        assert_eq!(rng.random(1u32..2), 1);
        assert_eq!(rng.random(1i32..=1), 1);
        assert_eq!(rng.random(1u32..=1), 1);
        assert_eq!(rng.random(i32::MAX..), i32::MAX);
        assert_eq!(rng.random(u32::MAX..), u32::MAX);
        assert_eq!(rng.random(..=i32::MIN), i32::MIN);
        assert_eq!(rng.random(..=u32::MIN), u32::MIN);
    }

    #[test]
    fn test_random_segment() {
        let mut rng = Xorshift::default();
        for _ in 0..100_000 {
            let n = (1..=1_000_000).rand(&mut rng);
            let (l, r) = NotEmptySegment(n).rand(&mut rng);
            assert!(l < r && r <= n);
        }

        // With a small base, every one of the n * (n + 1) / 2 segments should appear.
        const N_SMALL: usize = 100;
        let seen: HashSet<(usize, usize)> = (0..100_000)
            .map(|_| NotEmptySegment(N_SMALL).rand(&mut rng))
            .inspect(|&(l, r)| assert!(l < r && r <= N_SMALL))
            .collect();
        assert_eq!(seen.len(), N_SMALL * (N_SMALL + 1) / 2);
    }

    #[test]
    fn test_rand_macro() {
        let mut rng = Xorshift::default();
        rand!(
            rng,
            _x: ..10,
            _lr: NotEmptySegment(10),
            _a: [..10; 10],
            _t: (..10,),
            _r: (&(..10),&mut (..10)),
            _p: [(1..=10,2..=10); 2]
        );
    }

    #[test]
    fn test_weighted_sampler() {
        const TRIALS: usize = 1_000_000;
        let mut rng = Xorshift::default();
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let total: f64 = weights.iter().sum();
        let sampler = WeightedSampler::new(weights.clone());
        let mut counts = vec![0usize; weights.len()];
        for _ in 0..TRIALS {
            counts[sampler.rand(&mut rng)] += 1;
        }
        for (&w, &c) in weights.iter().zip(&counts) {
            let expected = w / total;
            let actual = c as f64 / TRIALS as f64;
            assert!((expected - actual).abs() < 0.01);
        }
    }
}