Skip to main content

competitive/tools/
random_generator.rs

1use super::Xorshift;
2use std::{
3    marker::PhantomData,
4    mem::swap,
5    ops::{Bound, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
6};
7
8/// Trait for spec of generating random value.
9pub trait RandomSpec<T>: Sized {
10    /// Return a random value.
11    fn rand(&self, rng: &mut Xorshift) -> T;
12    /// Return an iterator that generates random values.
13    fn rand_iter(self, rng: &mut Xorshift) -> RandIter<'_, T, Self> {
14        RandIter {
15            spec: self,
16            rng,
17            _marker: PhantomData,
18        }
19    }
20}
21
22impl Xorshift {
23    pub fn random<T, R>(&mut self, spec: R) -> T
24    where
25        R: RandomSpec<T>,
26    {
27        spec.rand(self)
28    }
29    pub fn random_iter<T, R>(&mut self, spec: R) -> RandIter<'_, T, R>
30    where
31        R: RandomSpec<T>,
32    {
33        spec.rand_iter(self)
34    }
35}
36
37#[derive(Debug)]
38pub struct RandIter<'r, T, R>
39where
40    R: RandomSpec<T>,
41{
42    spec: R,
43    rng: &'r mut Xorshift,
44    _marker: PhantomData<fn() -> T>,
45}
46
47impl<T, R> Iterator for RandIter<'_, T, R>
48where
49    R: RandomSpec<T>,
50{
51    type Item = T;
52    fn next(&mut self) -> Option<Self::Item> {
53        Some(self.spec.rand(self.rng))
54    }
55}
56
57macro_rules! impl_random_spec_range_full {
58    ($($t:ty)*) => {
59        $(impl RandomSpec<$t> for RangeFull {
60            fn rand(&self, rng: &mut Xorshift) -> $t {
61                rng.rand64() as _
62            }
63        })*
64    };
65}
66impl_random_spec_range_full!(u8 u16 u32 u64 usize i8 i16 i32 i64 isize);
67
68impl RandomSpec<u128> for RangeFull {
69    fn rand(&self, rng: &mut Xorshift) -> u128 {
70        ((rng.rand64() as u128) << 64) | rng.rand64() as u128
71    }
72}
73impl RandomSpec<i128> for RangeFull {
74    fn rand(&self, rng: &mut Xorshift) -> i128 {
75        rng.random::<u128, _>(..) as i128
76    }
77}
78
79macro_rules! impl_random_spec_ranges {
80    ($($u:ident $i:ident)*) => {
81        $(
82            impl RandomSpec<$u> for Range<$u> {
83                fn rand(&self, rng: &mut Xorshift) -> $u {
84                    assert!(self.start < self.end);
85                    let len = self.end - self.start;
86                    (self.start + rng.random::<$u, _>(..) % len)
87                }
88            }
89            impl RandomSpec<$i> for Range<$i> {
90                fn rand(&self, rng: &mut Xorshift) -> $i {
91                    assert!(self.start < self.end);
92                    let len = self.end.abs_diff(self.start);
93                    self.start.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
94                }
95            }
96            impl RandomSpec<$u> for RangeFrom<$u> {
97                fn rand(&self, rng: &mut Xorshift) -> $u {
98                    let len = ($u::MAX - self.start).wrapping_add(1);
99                    let x = rng.random::<$u, _>(..);
100                    self.start + if len != 0 { x % len } else { x }
101                }
102            }
103            impl RandomSpec<$i> for RangeFrom<$i> {
104                fn rand(&self, rng: &mut Xorshift) -> $i {
105                    let len = ($i::MAX.abs_diff(self.start)).wrapping_add(1);
106                    let x = rng.random::<$u, _>(..);
107                    self.start.wrapping_add_unsigned(if len != 0 { x % len } else { x })
108                }
109            }
110            impl RandomSpec<$u> for RangeInclusive<$u> {
111                fn rand(&self, rng: &mut Xorshift) -> $u {
112                    assert!(self.start() <= self.end());
113                    let len = (self.end() - self.start()).wrapping_add(1);
114                    let x = rng.random::<$u, _>(..);
115                    self.start() + if len != 0 { x % len } else { x }
116                }
117            }
118            impl RandomSpec<$i> for RangeInclusive<$i> {
119                fn rand(&self, rng: &mut Xorshift) -> $i {
120                    assert!(self.start() <= self.end());
121                    let len = (self.end().abs_diff(*self.start())).wrapping_add(1);
122                    let x = rng.random::<$u, _>(..);
123                    self.start().wrapping_add_unsigned(if len != 0 { x % len } else { x })
124                }
125            }
126            impl RandomSpec<$u> for RangeTo<$u> {
127                fn rand(&self, rng: &mut Xorshift) -> $u {
128                    let len = self.end;
129                    rng.random::<$u, _>(..) % len
130                }
131            }
132            impl RandomSpec<$i> for RangeTo<$i> {
133                fn rand(&self, rng: &mut Xorshift) -> $i {
134                    let len = self.end.abs_diff($i::MIN);
135                    $i::MIN.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
136                }
137            }
138            impl RandomSpec<$u> for RangeToInclusive<$u> {
139                fn rand(&self, rng: &mut Xorshift) -> $u {
140                    let len = (self.end).wrapping_add(1);
141                    let x = rng.random::<$u, _>(..);
142                    if len != 0 { x % len } else { x }
143                }
144            }
145            impl RandomSpec<$i> for RangeToInclusive<$i> {
146                fn rand(&self, rng: &mut Xorshift) -> $i {
147                    let len = (self.end.abs_diff($i::MIN)).wrapping_add(1);
148                    let x = rng.random::<$u, _>(..);
149                    $i::MIN.wrapping_add_unsigned(if len != 0 { x % len } else { x })
150                }
151            }
152        )*
153    };
154}
155impl_random_spec_ranges!(u8 i8 u16 i16 u32 i32 u64 i64 u128 i128 usize isize);
156
157macro_rules! impl_random_spec_tuple {
158    ($($T:ident)*, $($R:ident)*, $($v:ident)*) => {
159        impl<$($T),*, $($R),*> RandomSpec<($($T,)*)> for ($($R,)*)
160        where
161            $($R: RandomSpec<$T>),*
162        {
163            fn rand(&self, rng: &mut Xorshift) -> ($($T,)*) {
164                let ($($v,)*) = self;
165                ($(($v).rand(rng),)*)
166            }
167        }
168    };
169}
170impl_random_spec_tuple!(A, RA, a);
171impl_random_spec_tuple!(A B, RA RB, a b);
172impl_random_spec_tuple!(A B C, RA RB RC, a b c);
173impl_random_spec_tuple!(A B C D, RA RB RC RD, a b c d);
174impl_random_spec_tuple!(A B C D E, RA RB RC RD RE, a b c d e);
175impl_random_spec_tuple!(A B C D E F, RA RB RC RD RE RF, a b c d e f);
176impl_random_spec_tuple!(A B C D E F G, RA RB RC RD RE RF RG, a b c d e f g);
177impl_random_spec_tuple!(A B C D E F G H, RA RB RC RD RE RF RG RH, a b c d e f g h);
178impl_random_spec_tuple!(A B C D E F G H I, RA RB RC RD RE RF RG RH RI, a b c d e f g h i);
179impl_random_spec_tuple!(A B C D E F G H I J, RA RB RC RD RE RF RG RH RI RJ, a b c d e f g h i j);
180
181macro_rules! impl_random_spec_primitive {
182    ($($t:ty)*) => {
183        $(impl RandomSpec<$t> for $t {
184            fn rand(&self, _rng: &mut Xorshift) -> $t {
185                *self
186            }
187        })*
188    };
189}
190impl_random_spec_primitive!(() u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize bool char);
191
192impl<T, R> RandomSpec<T> for &R
193where
194    R: RandomSpec<T>,
195{
196    fn rand(&self, rng: &mut Xorshift) -> T {
197        <R as RandomSpec<T>>::rand(self, rng)
198    }
199}
200impl<T, R> RandomSpec<T> for &mut R
201where
202    R: RandomSpec<T>,
203{
204    fn rand(&self, rng: &mut Xorshift) -> T {
205        <R as RandomSpec<T>>::rand(self, rng)
206    }
207}
208
209#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
210/// Left-close Right-open No Empty Segment
211pub struct NotEmptySegment<T>(pub T);
212impl<T> RandomSpec<(usize, usize)> for NotEmptySegment<T>
213where
214    T: RandomSpec<usize>,
215{
216    fn rand(&self, rng: &mut Xorshift) -> (usize, usize) {
217        let n = rng.random(&self.0) as u64;
218        let k = randint_uniform(rng, n);
219        let l = randint_uniform(rng, n - k) as usize;
220        (l, l + k as usize + 1)
221    }
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
225/// Left-close Right-open Segment
226pub struct WithEmptySegment<T>(pub T);
227impl<T> RandomSpec<(usize, usize)> for WithEmptySegment<T>
228where
229    T: RandomSpec<usize>,
230{
231    fn rand(&self, rng: &mut Xorshift) -> (usize, usize) {
232        let n = rng.random(&self.0) as u64;
233        let k = randint_uniform(rng, n + 1);
234        let l = randint_uniform(rng, n - k + 1) as usize;
235        (l, l + k as usize)
236    }
237}
238
/// Spec for a pair of [`Bound`]s describing a random range over `T`.
///
/// Both endpoints are drawn from `data` and ordered so that the smaller one is
/// the start. Each side is then independently made excluded, included, or
/// unbounded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RandRange<Q, T> {
    data: Q,
    _marker: PhantomData<fn() -> T>,
}
impl<Q, T> RandRange<Q, T> {
    /// Wraps `data`, the spec used to draw both endpoints.
    pub fn new(data: Q) -> Self {
        RandRange {
            data,
            _marker: PhantomData,
        }
    }
}
252impl<Q, T> RandomSpec<(Bound<T>, Bound<T>)> for RandRange<Q, T>
253where
254    Q: RandomSpec<T>,
255    T: Ord,
256{
257    fn rand(&self, rng: &mut Xorshift) -> (Bound<T>, Bound<T>) {
258        let mut l = rng.random(&self.data);
259        let mut r = rng.random(&self.data);
260        if l > r {
261            swap(&mut l, &mut r);
262        }
263        (
264            match rng.rand(3) {
265                0 => Bound::Excluded(l),
266                1 => Bound::Included(l),
267                _ => Bound::Unbounded,
268            },
269            match rng.rand(3) {
270                0 => Bound::Excluded(r),
271                1 => Bound::Included(r),
272                _ => Bound::Unbounded,
273            },
274        )
275    }
276}
277
278#[inline]
279fn randint_uniform(rng: &mut Xorshift, k: u64) -> u64 {
280    let mut v = rng.rand64();
281    if k > 0 {
282        v %= k;
283    }
284    v
285}
286
/// Samples indices `0..n` with probability proportional to the given weights
/// using the alias method: O(n) construction, O(1) per draw.
///
/// Build with [`WeightedSampler::new`] and draw through [`RandomSpec::rand`].
#[derive(Debug, Clone)]
pub struct WeightedSampler {
    /// Number of outcomes (columns).
    n: usize,
    /// `prob[i]`: probability of keeping column `i` instead of its alias.
    prob: Vec<f64>,
    /// `alias[i]`: outcome returned when column `i` is not kept.
    alias: Vec<usize>,
}
292
293impl WeightedSampler {
294    pub fn new(weights: impl IntoIterator<Item = f64>) -> Self {
295        let mut weights: Vec<_> = weights.into_iter().collect();
296        let n = weights.len();
297        assert!(n > 0, "weights must be non-empty");
298        let mut prob = vec![0.0; n];
299        let mut alias = vec![0; n];
300        let mut small = vec![];
301        let mut large = vec![];
302        let sum: f64 = weights.iter().sum();
303        assert!(sum > 0.0, "sum of weights must be positive");
304        for (i, weight) in weights.iter_mut().enumerate() {
305            assert!(*weight >= 0.0, "weights must be non-negative");
306            *weight *= n as f64 / sum;
307            if *weight < 1.0 {
308                small.push(i);
309            } else {
310                large.push(i);
311            }
312        }
313        loop {
314            match (small.pop(), large.pop()) {
315                (Some(l), Some(g)) => {
316                    prob[l] = weights[l];
317                    alias[l] = g;
318                    weights[g] -= 1.0 - weights[l];
319                    if weights[g] < 1.0 {
320                        small.push(g);
321                    } else {
322                        large.push(g);
323                    }
324                }
325                (Some(g), None) | (None, Some(g)) => {
326                    prob[g] = 1.0;
327                    alias[g] = g;
328                }
329                (None, None) => break,
330            }
331        }
332        Self { n, prob, alias }
333    }
334}
335
336impl RandomSpec<usize> for WeightedSampler {
337    fn rand(&self, rng: &mut Xorshift) -> usize {
338        let i = rng.rand(self.n as u64) as usize;
339        if rng.randf() < self.prob[i] {
340            i
341        } else {
342            self.alias[i]
343        }
344    }
345}
346
#[macro_export]
/// Return a random value using [`RandomSpec`].
///
/// Usage: `rand_value!(rng, SPEC)` or `rand_value!(seed = SEED, SPEC)`.
/// `SPEC` is any [`RandomSpec`] expression or one of these composite forms,
/// which may be nested:
/// - `(SPEC, SPEC, ...)`: a tuple whose components are drawn in order;
/// - `[SPEC; LEN]`: a `Vec` of `LEN` draws;
/// - `[SPEC; const LEN]`: `LEN` draws expanded through the crate's `array!`
///   macro;
/// - `[SPEC]`: an endless iterator of draws.
///
/// Note: a parenthesised group holding a single element, such as `(SPEC,)`,
/// expands to the bare value rather than a one-element tuple.
///
/// The `seed = ...` form calls `Xorshift::new_with_seed` by its bare name, so
/// `Xorshift` must be in scope at the call site.
macro_rules! rand_value {
    // Endless iterator of draws, or a `Vec` of `$count` draws when given.
    (@repeat $source:expr, [$($rest:tt)*] $($count:expr)?) => {
        ::std::iter::repeat_with(|| $crate::rand_value!(@inner $source, [] $($rest)*))
            $(.take($count).collect::<Vec<_>>())?
    };
    // `$count` draws built through the crate's `array!` macro.
    (@array $source:expr, [$($rest:tt)*] $count:expr) => {
        $crate::array![|| $crate::rand_value!(@inner $source, [] $($rest)*); $count]
    };
    // Every element of a parenthesised group is parsed: emit them as a tuple.
    (@tuple $source:expr, [$([$($acc:tt)*])*]) => {
        ($($($acc)*),*)
    };
    // Exactly one value accumulated: emit it.
    (@$mode:ident $source:expr, [[$($acc:tt)*]]) => {
        $($acc)*
    };
    // Nested parenthesised group.
    (@$mode:ident $source:expr, [$($acc:tt)*] ($($items:tt)*) $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@tuple $source, [] $($items)*)]] $($rest)*)
    };
    // Repeated bracketed element: `[[...]; const N]` and `[[...]; N]`.
    (@$mode:ident $source:expr, [$($acc:tt)*] [[$($inner:tt)*]; const $count:expr] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@array $source, [[$($inner)*]] $count)]] $($rest)*)
    };
    (@$mode:ident $source:expr, [$($acc:tt)*] [[$($inner:tt)*]; $count:expr] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@repeat $source, [[$($inner)*]] $count)]] $($rest)*)
    };
    // Repeated tuple element: `[(...); const N]` and `[(...); N]`.
    (@$mode:ident $source:expr, [$($acc:tt)*] [($($inner:tt)*); const $count:expr] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@array $source, [($($inner)*)] $count)]] $($rest)*)
    };
    (@$mode:ident $source:expr, [$($acc:tt)*] [($($inner:tt)*); $count:expr] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@repeat $source, [($($inner)*)] $count)]] $($rest)*)
    };
    // Repeated plain spec: `[SPEC; const N]` and `[SPEC; N]`.
    (@$mode:ident $source:expr, [$($acc:tt)*] [$spec:expr; const $count:expr] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@array $source, [$spec] $count)]] $($rest)*)
    };
    (@$mode:ident $source:expr, [$($acc:tt)*] [$spec:expr; $count:expr] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@repeat $source, [$spec] $count)]] $($rest)*)
    };
    // Bracketed group without a count: endless iterator.
    (@$mode:ident $source:expr, [$($acc:tt)*] [$($inner:tt)*] $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [$crate::rand_value!(@repeat $source, [$($inner)*])]] $($rest)*)
    };
    // Plain spec as the last element.
    (@$mode:ident $source:expr, [$($acc:tt)*] $spec:expr) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [($source).random($spec)]])
    };
    // Plain spec followed by more elements.
    (@$mode:ident $source:expr, [$($acc:tt)*] $spec:expr, $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)* [($source).random($spec)]] $($rest)*)
    };
    // Separator after a bracketed or parenthesised element.
    (@$mode:ident $source:expr, [$($acc:tt)*] , $($rest:tt)*) => {
        $crate::rand_value!(@$mode $source, [$($acc)*] $($rest)*)
    };
    // Zero or several top-level values: usage error.
    (@$mode:ident $source:expr, [$($acc:tt)*]) => {
        ::std::compile_error!(::std::stringify!($($acc)*))
    };
    (seed = $seed:expr, $($rest:tt)*) => {
        {
            let mut __seeded_rng = Xorshift::new_with_seed($seed);
            $crate::rand_value!(@inner __seeded_rng, [] $($rest)*)
        }
    };
    ($source:expr, $($rest:tt)*) => {
        $crate::rand_value!(@inner $source, [] $($rest)*)
    };
}
#[macro_export]
/// Declare random values using [`RandomSpec`].
///
/// `rand!(rng, PAT: SPEC, PAT: SPEC, ...)` expands to one
/// `let PAT = rand_value!(rng, SPEC);` per entry, in order, so later entries
/// draw after earlier ones. `SPEC` accepts every form `rand_value!` does.
/// Tokens before a `:` that do not form a pattern are reported as
/// "expected pattern, found `...`".
macro_rules! rand {
    // Reject binding token sequences that do not parse as a pattern.
    (@assert $bind:pat) => {};
    (@assert $($bind:tt)*) => {
        ::std::compile_error!(::std::concat!("expected pattern, found `", ::std::stringify!($($bind)*), "`"));
    };
    // Pattern collection; nothing left means we are done.
    (@pat $source:expr, [] []) => {};
    (@pat $source:expr, [] [] , $($rest:tt)*) => {
        $crate::rand!(@pat $source, [] [] $($rest)*)
    };
    (@pat $source:expr, [$($bind:tt)*] [] $piece:ident $($rest:tt)*) => {
        $crate::rand!(@pat $source, [$($bind)* $piece] [] $($rest)*)
    };
    (@pat $source:expr, [$($bind:tt)*] [] :: $($rest:tt)*) => {
        $crate::rand!(@pat $source, [$($bind)* ::] [] $($rest)*)
    };
    (@pat $source:expr, [$($bind:tt)*] [] & $($rest:tt)*) => {
        $crate::rand!(@pat $source, [$($bind)* &] [] $($rest)*)
    };
    (@pat $source:expr, [$($bind:tt)*] [] ($($piece:tt)*) $($rest:tt)*) => {
        $crate::rand!(@pat $source, [$($bind)* ($($piece)*)] [] $($rest)*)
    };
    (@pat $source:expr, [$($bind:tt)*] [] [$($piece:tt)*] $($rest:tt)*) => {
        $crate::rand!(@pat $source, [$($bind)* [$($piece)*]] [] $($rest)*)
    };
    (@pat $source:expr, [$($bind:tt)*] [] {$($piece:tt)*} $($rest:tt)*) => {
        $crate::rand!(@pat $source, [$($bind)* {$($piece)*}] [] $($rest)*)
    };
    // `:` ends the pattern; switch to collecting the spec.
    (@pat $source:expr, [$($bind:tt)*] [] : $($rest:tt)*) => {
        $crate::rand!(@ty $source, [$($bind)*] [] $($rest)*)
    };
    // Spec collection: a delimited group, an expression, or a single token tree.
    (@ty $source:expr, [$($bind:tt)*] [$($spec:tt)*] ($($piece:tt)*) $($rest:tt)*) => {
        $crate::rand!(@let $source, [$($bind)*] [$($spec)* ($($piece)*)] $($rest)*)
    };
    (@ty $source:expr, [$($bind:tt)*] [$($spec:tt)*] [$($piece:tt)*] $($rest:tt)*) => {
        $crate::rand!(@let $source, [$($bind)*] [$($spec)* [$($piece)*]] $($rest)*)
    };
    (@ty $source:expr, [$($bind:tt)*] [$($spec:tt)*] $last:expr) => {
        $crate::rand!(@let $source, [$($bind)*] [$($spec)* $last])
    };
    (@ty $source:expr, [$($bind:tt)*] [$($spec:tt)*] $last:expr, $($rest:tt)*) => {
        $crate::rand!(@let $source, [$($bind)*] [$($spec)* $last], $($rest)*)
    };
    (@ty $source:expr, [$($bind:tt)*] [$($spec:tt)*] $last:tt) => {
        $crate::rand!(@let $source, [$($bind)*] [$($spec)* $last])
    };
    (@ty $source:expr, [$($bind:tt)*] [$($spec:tt)*] $last:tt, $($rest:tt)*) => {
        $crate::rand!(@let $source, [$($bind)*] [$($spec)* $last], $($rest)*)
    };
    // Emit one `let`, then continue with the remaining declarations.
    (@let $source:expr, [$($bind:tt)*] [$($spec:tt)*] $($rest:tt)*) => {
        $crate::rand!{@assert $($bind)*}
        let $($bind)* = $crate::rand_value!($source, $($spec)*);
        $crate::rand!(@pat $source, [] [] $($rest)*)
    };
    ($source:expr) => {};
    ($source:expr, $($rest:tt)*) => {
        $crate::rand!(@pat $source, [] [] $($rest)*)
    };
}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_range() {
        let mut rng = Xorshift::default();
        // Single-value ranges pin the result exactly.
        assert_eq!(rng.random(1i32..2), 1);
        assert_eq!(rng.random(1u32..2), 1);
        assert_eq!(rng.random(1i32..=1), 1);
        assert_eq!(rng.random(1u32..=1), 1);
        assert_eq!(rng.random(i32::MAX..), i32::MAX);
        assert_eq!(rng.random(u32::MAX..), u32::MAX);
        assert_eq!(rng.random(..=i32::MIN), i32::MIN);
        assert_eq!(rng.random(..=u32::MIN), u32::MIN);
        assert_eq!(rng.random(..1u32), 0);
        assert_eq!(rng.random(..i32::MIN + 1), i32::MIN);
        assert_eq!(rng.random(i8::MIN..=i8::MIN), i8::MIN);
        assert_eq!(rng.random(u8::MAX..=u8::MAX), u8::MAX);
    }

    #[test]
    fn test_random_range_bounds() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let x: i64 = rng.random(-5i64..5);
            assert!((-5..5).contains(&x));
            let y: u8 = rng.random(250u8..);
            assert!(y >= 250);
            let z: i16 = rng.random(..=-100i16);
            assert!(z <= -100);
            let w: i8 = rng.random(i8::MIN..i8::MAX);
            assert!(w < i8::MAX);
        }
    }

    #[test]
    fn test_random_segment() {
        let mut rng = Xorshift::default();
        for _ in 0..100_000 {
            let n = (1..=1_000_000).rand(&mut rng);
            let (l, r) = NotEmptySegment(n).rand(&mut rng);
            assert!(l < r);
            assert!(r <= n);
        }

        const N_SMALL: usize = 100;
        let mut set = std::collections::HashSet::new();
        for _ in 0..100_000 {
            let (l, r) = NotEmptySegment(N_SMALL).rand(&mut rng);
            assert!(l < r);
            assert!(r <= N_SMALL);
            set.insert((l, r));
        }
        assert!(set.len() == N_SMALL * (N_SMALL + 1) / 2);
    }

    #[test]
    fn test_random_segment_with_empty() {
        let mut rng = Xorshift::default();
        assert_eq!(WithEmptySegment(0usize).rand(&mut rng), (0, 0));

        const N: usize = 10;
        let mut set = std::collections::HashSet::new();
        for _ in 0..10_000 {
            let (l, r) = WithEmptySegment(N).rand(&mut rng);
            assert!(l <= r);
            assert!(r <= N);
            set.insert((l, r));
        }
        // Every `l <= r <= N` pair, including the empty ones, shows up.
        assert_eq!(set.len(), (N + 1) * (N + 2) / 2);
    }

    #[test]
    fn test_rand_range_bounds_ordered() {
        let mut rng = Xorshift::default();
        let spec = RandRange::<_, u32>::new(0u32..10);
        let value = |b: &Bound<u32>| match *b {
            Bound::Included(x) | Bound::Excluded(x) => Some(x),
            Bound::Unbounded => None,
        };
        for _ in 0..1000 {
            let (l, r) = spec.rand(&mut rng);
            if let (Some(a), Some(b)) = (value(&l), value(&r)) {
                assert!(a <= b);
            }
            if let Some(a) = value(&l) {
                assert!(a < 10);
            }
            if let Some(b) = value(&r) {
                assert!(b < 10);
            }
        }
    }

    #[test]
    fn test_rand_macro() {
        let mut rng = Xorshift::default();
        rand!(
            rng,
            _x: ..10,
            _lr: NotEmptySegment(10),
            _a: [..10; 10],
            _t: (..10,),
            _r: (&(..10),&mut (..10)),
            _p: [(1..=10,2..=10); 2]
        );
    }

    #[test]
    fn test_weighted_sampler() {
        let mut rng = Xorshift::default();
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let sampler = WeightedSampler::new(weights.clone());
        let mut counts = vec![0; weights.len()];
        for _ in 0..1_000_000 {
            let idx = sampler.rand(&mut rng);
            counts[idx] += 1;
        }
        let total: f64 = weights.iter().sum();
        for i in 0..weights.len() {
            let expected = weights[i] / total;
            let actual = counts[i] as f64 / 1_000_000.0;
            assert!((expected - actual).abs() < 0.01);
        }
    }

    #[test]
    fn test_weighted_sampler_skips_zero_weights() {
        let mut rng = Xorshift::default();
        let sampler = WeightedSampler::new(vec![0.0, 1.0, 0.0]);
        for _ in 0..1000 {
            assert_eq!(sampler.rand(&mut rng), 1);
        }
    }

    #[test]
    #[should_panic(expected = "weights must be non-empty")]
    fn test_weighted_sampler_rejects_empty() {
        let _ = WeightedSampler::new(Vec::<f64>::new());
    }
}