1use super::Xorshift;
2use std::{
3 marker::PhantomData,
4 mem::swap,
5 ops::{Bound, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
6};
7
8pub trait RandomSpec<T>: Sized {
10 fn rand(&self, rng: &mut Xorshift) -> T;
12 fn rand_iter(self, rng: &mut Xorshift) -> RandIter<'_, T, Self> {
14 RandIter {
15 spec: self,
16 rng,
17 _marker: PhantomData,
18 }
19 }
20}
21
22impl Xorshift {
23 pub fn random<T, R>(&mut self, spec: R) -> T
24 where
25 R: RandomSpec<T>,
26 {
27 spec.rand(self)
28 }
29 pub fn random_iter<T, R>(&mut self, spec: R) -> RandIter<'_, T, R>
30 where
31 R: RandomSpec<T>,
32 {
33 spec.rand_iter(self)
34 }
35}
36
37#[derive(Debug)]
38pub struct RandIter<'r, T, R>
39where
40 R: RandomSpec<T>,
41{
42 spec: R,
43 rng: &'r mut Xorshift,
44 _marker: PhantomData<fn() -> T>,
45}
46
47impl<T, R> Iterator for RandIter<'_, T, R>
48where
49 R: RandomSpec<T>,
50{
51 type Item = T;
52 fn next(&mut self) -> Option<Self::Item> {
53 Some(self.spec.rand(self.rng))
54 }
55}
56
57macro_rules! impl_random_spec_range_full {
58 ($($t:ty)*) => {
59 $(impl RandomSpec<$t> for RangeFull {
60 fn rand(&self, rng: &mut Xorshift) -> $t {
61 rng.rand64() as _
62 }
63 })*
64 };
65}
66impl_random_spec_range_full!(u8 u16 u32 u64 usize i8 i16 i32 i64 isize);
67
68impl RandomSpec<u128> for RangeFull {
69 fn rand(&self, rng: &mut Xorshift) -> u128 {
70 ((rng.rand64() as u128) << 64) | rng.rand64() as u128
71 }
72}
73impl RandomSpec<i128> for RangeFull {
74 fn rand(&self, rng: &mut Xorshift) -> i128 {
75 rng.random::<u128, _>(..) as i128
76 }
77}
78
79macro_rules! impl_random_spec_ranges {
80 ($($u:ident $i:ident)*) => {
81 $(
82 impl RandomSpec<$u> for Range<$u> {
83 fn rand(&self, rng: &mut Xorshift) -> $u {
84 assert!(self.start < self.end);
85 let len = self.end - self.start;
86 (self.start + rng.random::<$u, _>(..) % len)
87 }
88 }
89 impl RandomSpec<$i> for Range<$i> {
90 fn rand(&self, rng: &mut Xorshift) -> $i {
91 assert!(self.start < self.end);
92 let len = self.end.abs_diff(self.start);
93 self.start.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
94 }
95 }
96 impl RandomSpec<$u> for RangeFrom<$u> {
97 fn rand(&self, rng: &mut Xorshift) -> $u {
98 let len = ($u::MAX - self.start).wrapping_add(1);
99 let x = rng.random::<$u, _>(..);
100 self.start + if len != 0 { x % len } else { x }
101 }
102 }
103 impl RandomSpec<$i> for RangeFrom<$i> {
104 fn rand(&self, rng: &mut Xorshift) -> $i {
105 let len = ($i::MAX.abs_diff(self.start)).wrapping_add(1);
106 let x = rng.random::<$u, _>(..);
107 self.start.wrapping_add_unsigned(if len != 0 { x % len } else { x })
108 }
109 }
110 impl RandomSpec<$u> for RangeInclusive<$u> {
111 fn rand(&self, rng: &mut Xorshift) -> $u {
112 assert!(self.start() <= self.end());
113 let len = (self.end() - self.start()).wrapping_add(1);
114 let x = rng.random::<$u, _>(..);
115 self.start() + if len != 0 { x % len } else { x }
116 }
117 }
118 impl RandomSpec<$i> for RangeInclusive<$i> {
119 fn rand(&self, rng: &mut Xorshift) -> $i {
120 assert!(self.start() <= self.end());
121 let len = (self.end().abs_diff(*self.start())).wrapping_add(1);
122 let x = rng.random::<$u, _>(..);
123 self.start().wrapping_add_unsigned(if len != 0 { x % len } else { x })
124 }
125 }
126 impl RandomSpec<$u> for RangeTo<$u> {
127 fn rand(&self, rng: &mut Xorshift) -> $u {
128 let len = self.end;
129 rng.random::<$u, _>(..) % len
130 }
131 }
132 impl RandomSpec<$i> for RangeTo<$i> {
133 fn rand(&self, rng: &mut Xorshift) -> $i {
134 let len = self.end.abs_diff($i::MIN);
135 $i::MIN.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
136 }
137 }
138 impl RandomSpec<$u> for RangeToInclusive<$u> {
139 fn rand(&self, rng: &mut Xorshift) -> $u {
140 let len = (self.end).wrapping_add(1);
141 let x = rng.random::<$u, _>(..);
142 if len != 0 { x % len } else { x }
143 }
144 }
145 impl RandomSpec<$i> for RangeToInclusive<$i> {
146 fn rand(&self, rng: &mut Xorshift) -> $i {
147 let len = (self.end.abs_diff($i::MIN)).wrapping_add(1);
148 let x = rng.random::<$u, _>(..);
149 $i::MIN.wrapping_add_unsigned(if len != 0 { x % len } else { x })
150 }
151 }
152 )*
153 };
154}
155impl_random_spec_ranges!(u8 i8 u16 i16 u32 i32 u64 i64 u128 i128 usize isize);
156
157macro_rules! impl_random_spec_tuple {
158 ($($T:ident)*, $($R:ident)*, $($v:ident)*) => {
159 impl<$($T),*, $($R),*> RandomSpec<($($T,)*)> for ($($R,)*)
160 where
161 $($R: RandomSpec<$T>),*
162 {
163 fn rand(&self, rng: &mut Xorshift) -> ($($T,)*) {
164 let ($($v,)*) = self;
165 ($(($v).rand(rng),)*)
166 }
167 }
168 };
169}
170impl_random_spec_tuple!(A, RA, a);
171impl_random_spec_tuple!(A B, RA RB, a b);
172impl_random_spec_tuple!(A B C, RA RB RC, a b c);
173impl_random_spec_tuple!(A B C D, RA RB RC RD, a b c d);
174impl_random_spec_tuple!(A B C D E, RA RB RC RD RE, a b c d e);
175impl_random_spec_tuple!(A B C D E F, RA RB RC RD RE RF, a b c d e f);
176impl_random_spec_tuple!(A B C D E F G, RA RB RC RD RE RF RG, a b c d e f g);
177impl_random_spec_tuple!(A B C D E F G H, RA RB RC RD RE RF RG RH, a b c d e f g h);
178impl_random_spec_tuple!(A B C D E F G H I, RA RB RC RD RE RF RG RH RI, a b c d e f g h i);
179impl_random_spec_tuple!(A B C D E F G H I J, RA RB RC RD RE RF RG RH RI RJ, a b c d e f g h i j);
180
181macro_rules! impl_random_spec_primitive {
182 ($($t:ty)*) => {
183 $(impl RandomSpec<$t> for $t {
184 fn rand(&self, _rng: &mut Xorshift) -> $t {
185 *self
186 }
187 })*
188 };
189}
190impl_random_spec_primitive!(() u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize bool char);
191
192impl<T, R> RandomSpec<T> for &R
193where
194 R: RandomSpec<T>,
195{
196 fn rand(&self, rng: &mut Xorshift) -> T {
197 <R as RandomSpec<T>>::rand(self, rng)
198 }
199}
200impl<T, R> RandomSpec<T> for &mut R
201where
202 R: RandomSpec<T>,
203{
204 fn rand(&self, rng: &mut Xorshift) -> T {
205 <R as RandomSpec<T>>::rand(self, rng)
206 }
207}
208
209#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
210pub struct NotEmptySegment<T>(pub T);
212impl<T> RandomSpec<(usize, usize)> for NotEmptySegment<T>
213where
214 T: RandomSpec<usize>,
215{
216 fn rand(&self, rng: &mut Xorshift) -> (usize, usize) {
217 let n = rng.random(&self.0) as u64;
218 let k = randint_uniform(rng, n);
219 let l = randint_uniform(rng, n - k) as usize;
220 (l, l + k as usize + 1)
221 }
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
225pub struct WithEmptySegment<T>(pub T);
227impl<T> RandomSpec<(usize, usize)> for WithEmptySegment<T>
228where
229 T: RandomSpec<usize>,
230{
231 fn rand(&self, rng: &mut Xorshift) -> (usize, usize) {
232 let n = rng.random(&self.0) as u64;
233 let k = randint_uniform(rng, n + 1);
234 let l = randint_uniform(rng, n - k + 1) as usize;
235 (l, l + k as usize)
236 }
237}
238
/// Spec producing a random `(Bound<T>, Bound<T>)` pair whose endpoint values
/// are drawn from the wrapped spec `data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RandRange<Q, T> {
    data: Q,
    _marker: PhantomData<fn() -> T>,
}
impl<Q, T> RandRange<Q, T> {
    /// Wraps `data`, the spec used to draw both endpoint values.
    pub fn new(data: Q) -> Self {
        let _marker = PhantomData;
        Self { data, _marker }
    }
}
252impl<Q, T> RandomSpec<(Bound<T>, Bound<T>)> for RandRange<Q, T>
253where
254 Q: RandomSpec<T>,
255 T: Ord,
256{
257 fn rand(&self, rng: &mut Xorshift) -> (Bound<T>, Bound<T>) {
258 let mut l = rng.random(&self.data);
259 let mut r = rng.random(&self.data);
260 if l > r {
261 swap(&mut l, &mut r);
262 }
263 (
264 match rng.rand(3) {
265 0 => Bound::Excluded(l),
266 1 => Bound::Included(l),
267 _ => Bound::Unbounded,
268 },
269 match rng.rand(3) {
270 0 => Bound::Excluded(r),
271 1 => Bound::Included(r),
272 _ => Bound::Unbounded,
273 },
274 )
275 }
276}
277
278#[inline]
279fn randint_uniform(rng: &mut Xorshift, k: u64) -> u64 {
280 let mut v = rng.rand64();
281 if k > 0 {
282 v %= k;
283 }
284 v
285}
286
/// Alias-method table for sampling indices `0..n` proportionally to weights.
///
/// Derives `Debug` and `Clone` so a built table can be inspected and reused.
#[derive(Debug, Clone)]
pub struct WeightedSampler {
    // Number of columns (equals the number of weights).
    n: usize,
    // `prob[i]`: probability of keeping column `i` rather than its alias.
    prob: Vec<f64>,
    // `alias[i]`: index returned when column `i` is not kept.
    alias: Vec<usize>,
}
292
293impl WeightedSampler {
294 pub fn new(weights: impl IntoIterator<Item = f64>) -> Self {
295 let mut weights: Vec<_> = weights.into_iter().collect();
296 let n = weights.len();
297 assert!(n > 0, "weights must be non-empty");
298 let mut prob = vec![0.0; n];
299 let mut alias = vec![0; n];
300 let mut small = vec![];
301 let mut large = vec![];
302 let sum: f64 = weights.iter().sum();
303 assert!(sum > 0.0, "sum of weights must be positive");
304 for (i, weight) in weights.iter_mut().enumerate() {
305 assert!(*weight >= 0.0, "weights must be non-negative");
306 *weight *= n as f64 / sum;
307 if *weight < 1.0 {
308 small.push(i);
309 } else {
310 large.push(i);
311 }
312 }
313 loop {
314 match (small.pop(), large.pop()) {
315 (Some(l), Some(g)) => {
316 prob[l] = weights[l];
317 alias[l] = g;
318 weights[g] -= 1.0 - weights[l];
319 if weights[g] < 1.0 {
320 small.push(g);
321 } else {
322 large.push(g);
323 }
324 }
325 (Some(g), None) | (None, Some(g)) => {
326 prob[g] = 1.0;
327 alias[g] = g;
328 }
329 (None, None) => break,
330 }
331 }
332 Self { n, prob, alias }
333 }
334}
335
336impl RandomSpec<usize> for WeightedSampler {
337 fn rand(&self, rng: &mut Xorshift) -> usize {
338 let i = rng.rand(self.n as u64) as usize;
339 if rng.randf() < self.prob[i] {
340 i
341 } else {
342 self.alias[i]
343 }
344 }
345}
346
/// Evaluates a compact "spec" expression into random value(s).
///
/// Entry points:
/// - `rand_value!(rng, spec)` draws from the generator expression `rng`;
/// - `rand_value!(seed = s, spec)` first builds `Xorshift::new_with_seed(s)`.
///   NOTE(review): `Xorshift` is not `$crate`-qualified here, so it must be in
///   scope at the call site.
///
/// Spec forms (matched by the `@$tag` arms, in order):
/// - `(a, b, ...)` draws each component; see the `@tuple` note below;
/// - `[spec; const len]` builds an array via `$crate::array!`;
/// - `[spec; len]` collects `len` draws into a `Vec`;
/// - `[spec]` yields an endless `repeat_with` iterator (no `take`/`collect`);
/// - any other expression `e` becomes `rng.random(e)`.
///
/// At the top level (`@inner`) exactly one spec must remain. Leftover multiple
/// values hit the final `compile_error!` arm.
#[macro_export]
macro_rules! rand_value {
    // Endless iterator of draws, optionally truncated to `$len` and collected into a Vec.
    (@repeat $rng:expr, [$($t:tt)*] $($len:expr)?) => { ::std::iter::repeat_with(|| $crate::rand_value!(@inner $rng, [] $($t)*)) $(.take($len).collect::<Vec<_>>())? };
    // Fixed-size array of draws, built by the crate's `array!` macro.
    (@array $rng:expr, [$($t:tt)*] $len:expr) => { $crate::array![|| $crate::rand_value!(@inner $rng, [] $($t)*); $len] };
    // Emits the accumulated values as a parenthesized, comma-separated group.
    // NOTE(review): with exactly one value this expands to `(x)`, a plain
    // parenthesized value rather than a 1-tuple, so `(spec,)` does not yield
    // `(T,)`. Confirm whether that is intended.
    (@tuple $rng:expr, [$([$($args:tt)*])*]) => { ($($($args)*),*) };
    // Input exhausted with a single accumulated value: emit it unwrapped.
    (@$tag:ident $rng:expr, [[$($args:tt)*]]) => { $($args)* };
    // Parenthesized group: recurse into it as a tuple spec.
    (@$tag:ident $rng:expr, [$($args:tt)*] ($($tuple:tt)*) $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@tuple $rng, [] $($tuple)*)]] $($t)*) };
    // Bracketed element spec, with `; const len` (array) or `; len` (Vec).
    (@$tag:ident $rng:expr, [$($args:tt)*] [[$($tt:tt)*]; const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [[$($tt)*]] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [[$($tt:tt)*]; $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [[$($tt)*]] $len)]] $($t)*) };
    // Parenthesized (tuple) element spec, with `; const len` or `; len`.
    (@$tag:ident $rng:expr, [$($args:tt)*] [($($tt:tt)*); const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [($($tt)*)] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [($($tt:tt)*); $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [($($tt)*)] $len)]] $($t)*) };
    // Plain expression element spec, with `; const len` or `; len`.
    (@$tag:ident $rng:expr, [$($args:tt)*] [$ty:expr; const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [$ty] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [$ty:expr; $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [$ty] $len)]] $($t)*) };
    // Bracketed spec without a length: endless iterator.
    (@$tag:ident $rng:expr, [$($args:tt)*] [$($tt:tt)*] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [$($tt)*])]] $($t)*) };
    // Plain expression spec (last item, or followed by a comma): draw with `random`.
    (@$tag:ident $rng:expr, [$($args:tt)*] $ty:expr) => { $crate::rand_value!(@$tag $rng, [$($args)* [($rng).random($ty)]]) };
    (@$tag:ident $rng:expr, [$($args:tt)*] $ty:expr, $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [($rng).random($ty)]] $($t)*) };
    // Skip a separating comma.
    (@$tag:ident $rng:expr, [$($args:tt)*] , $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)*] $($t)*) };
    // Anything else (e.g. several values left at `@inner`) is rejected at compile time.
    (@$tag:ident $rng:expr, [$($args:tt)*]) => { ::std::compile_error!(::std::stringify!($($args)*)) };
    // Public entry: fresh generator from a seed.
    (seed = $src:expr, $($t:tt)*) => { { let mut __rng = Xorshift::new_with_seed($src); $crate::rand_value!(@inner __rng, [] $($t)*) } };
    // Public entry: caller-supplied generator expression.
    ($rng:expr, $($t:tt)*) => { $crate::rand_value!(@inner $rng, [] $($t)*) }
}
368}
/// Declares `let` bindings filled with random values:
/// `rand!(rng, pat: spec, pat2: spec2, ...)`.
///
/// Each `pat: spec` pair expands to `let pat = rand_value!(rng, spec);`. The
/// generator expression `rng` is substituted into every binding's expansion.
/// The pattern tokens are collected up to `:` (identifiers, `::`, `&`, and
/// delimited groups). They are checked with a `$p:pat` matcher, and a
/// `compile_error!` is emitted if they do not form a pattern.
#[macro_export]
macro_rules! rand {
    // Pattern validation: accept anything that parses as a pattern...
    (@assert $p:pat) => {};
    // ...and report everything else as a compile error.
    (@assert $($p:tt)*) => { ::std::compile_error!(::std::concat!("expected pattern, found `", ::std::stringify!($($p)*), "`")); };
    // `@pat`: accumulate pattern tokens (first bracket) until `:` is reached.
    (@pat $rng:expr, [] []) => {};
    (@pat $rng:expr, [] [] , $($t:tt)*) => { $crate::rand!(@pat $rng, [] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] $x:ident $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* $x] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] :: $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* ::] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] & $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* &] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] ($($x:tt)*) $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* ($($x)*)] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] [$($x:tt)*] $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* [$($x)*]] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] {$($x:tt)*} $($t:tt)*) => { $crate::rand!(@pat $rng, [$($p)* {$($x)*}] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] : $($t:tt)*) => { $crate::rand!(@ty $rng, [$($p)*] [] $($t)*) };
    // `@ty`: take one spec (a delimited group, an expression, or a single token)
    // into the second bracket, keeping any trailing `, ...` for the next binding.
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] ($($x:tt)*) $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* ($($x)*)] $($t)*) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] [$($x:tt)*] $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* [$($x)*]] $($t)*) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:expr) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e]) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:expr, $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e], $($t)*) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:tt) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e]) };
    (@ty $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:tt, $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e], $($t)*) };
    // `@let`: validate the pattern, emit the binding, continue with the rest.
    (@let $rng:expr, [$($p:tt)*] [$($tt:tt)*] $($t:tt)*) => {
        $crate::rand!{@assert $($p)*}
        let $($p)* = $crate::rand_value!($rng, $($tt)*);
        $crate::rand!(@pat $rng, [] [] $($t)*)
    };
    // Public entries: no bindings, or a comma-separated list of `pat: spec`.
    ($rng:expr) => {};
    ($rng:expr, $($t:tt)*) => { $crate::rand!(@pat $rng, [] [] $($t)*) };
}
397
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ranges that contain exactly one value must always yield that value.
    #[test]
    fn test_random_range() {
        let mut rng = Xorshift::default();
        assert_eq!(rng.random(1i32..2), 1);
        assert_eq!(rng.random(1u32..2), 1);
        assert_eq!(rng.random(1i32..=1), 1);
        assert_eq!(rng.random(1u32..=1), 1);
        assert_eq!(rng.random(i32::MAX..), i32::MAX);
        assert_eq!(rng.random(u32::MAX..), u32::MAX);
        assert_eq!(rng.random(..=i32::MIN), i32::MIN);
        assert_eq!(rng.random(..=u32::MIN), u32::MIN);
    }

    /// Non-empty segments stay in bounds, and for a small domain every one of
    /// the n(n+1)/2 segments is eventually produced.
    #[test]
    fn test_random_segment() {
        let mut rng = Xorshift::default();
        for _ in 0..100_000 {
            let n = (1..=1_000_000).rand(&mut rng);
            let (l, r) = NotEmptySegment(n).rand(&mut rng);
            assert!(l < r && r <= n);
        }

        const N_SMALL: usize = 100;
        let mut seen = HashSet::new();
        for _ in 0..100_000 {
            let (l, r) = NotEmptySegment(N_SMALL).rand(&mut rng);
            assert!(l < r && r <= N_SMALL);
            seen.insert((l, r));
        }
        assert_eq!(seen.len(), N_SMALL * (N_SMALL + 1) / 2);
    }

    /// Smoke test: every spec form accepted by `rand!` expands and compiles.
    #[test]
    fn test_rand_macro() {
        let mut rng = Xorshift::default();
        rand!(
            rng,
            _x: ..10,
            _lr: NotEmptySegment(10),
            _a: [..10; 10],
            _t: (..10,),
            _r: (&(..10),&mut (..10)),
            _p: [(1..=10,2..=10); 2]
        );
    }

    /// Empirical frequencies of the alias sampler match the weights to within 1%.
    #[test]
    fn test_weighted_sampler() {
        const TRIALS: usize = 1_000_000;
        let mut rng = Xorshift::default();
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let total: f64 = weights.iter().sum();
        let sampler = WeightedSampler::new(weights.clone());
        let mut counts = vec![0usize; weights.len()];
        for _ in 0..TRIALS {
            counts[sampler.rand(&mut rng)] += 1;
        }
        for (w, &c) in weights.iter().zip(&counts) {
            let expected = w / total;
            let actual = c as f64 / TRIALS as f64;
            assert!((expected - actual).abs() < 0.01);
        }
    }
}