// competitive/tools/random_generator.rs

1use super::Xorshift;
2use std::{
3    marker::PhantomData,
4    mem::swap,
5    ops::{Bound, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
6};
7
8/// Trait for spec of generating random value.
9pub trait RandomSpec<T>: Sized {
10    /// Return a random value.
11    fn rand(&self, rng: &mut Xorshift) -> T;
12    /// Return an iterator that generates random values.
13    fn rand_iter(self, rng: &mut Xorshift) -> RandIter<'_, T, Self> {
14        RandIter {
15            spec: self,
16            rng,
17            _marker: PhantomData,
18        }
19    }
20}
21
22impl Xorshift {
23    pub fn random<T, R>(&mut self, spec: R) -> T
24    where
25        R: RandomSpec<T>,
26    {
27        spec.rand(self)
28    }
29    pub fn random_iter<T, R>(&mut self, spec: R) -> RandIter<'_, T, R>
30    where
31        R: RandomSpec<T>,
32    {
33        spec.rand_iter(self)
34    }
35}
36
37#[derive(Debug)]
38pub struct RandIter<'r, T, R>
39where
40    R: RandomSpec<T>,
41{
42    spec: R,
43    rng: &'r mut Xorshift,
44    _marker: PhantomData<fn() -> T>,
45}
46
47impl<T, R> Iterator for RandIter<'_, T, R>
48where
49    R: RandomSpec<T>,
50{
51    type Item = T;
52    fn next(&mut self) -> Option<Self::Item> {
53        Some(self.spec.rand(self.rng))
54    }
55}
56
57macro_rules! impl_random_spec_range_full {
58    ($($t:ty)*) => {
59        $(impl RandomSpec<$t> for RangeFull {
60            fn rand(&self, rng: &mut Xorshift) -> $t {
61                rng.rand64() as _
62            }
63        })*
64    };
65}
66impl_random_spec_range_full!(u8 u16 u32 u64 usize i8 i16 i32 i64 isize);
67
68impl RandomSpec<u128> for RangeFull {
69    fn rand(&self, rng: &mut Xorshift) -> u128 {
70        ((rng.rand64() as u128) << 64) | rng.rand64() as u128
71    }
72}
73impl RandomSpec<i128> for RangeFull {
74    fn rand(&self, rng: &mut Xorshift) -> i128 {
75        rng.random::<u128, _>(..) as i128
76    }
77}
78
79macro_rules! impl_random_spec_ranges {
80    ($($u:ident $i:ident)*) => {
81        $(
82            impl RandomSpec<$u> for Range<$u> {
83                fn rand(&self, rng: &mut Xorshift) -> $u {
84                    assert!(self.start < self.end);
85                    let len = self.end - self.start;
86                    (self.start + rng.random::<$u, _>(..) % len)
87                }
88            }
89            impl RandomSpec<$i> for Range<$i> {
90                fn rand(&self, rng: &mut Xorshift) -> $i {
91                    assert!(self.start < self.end);
92                    let len = self.end.abs_diff(self.start);
93                    self.start.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
94                }
95            }
96            impl RandomSpec<$u> for RangeFrom<$u> {
97                fn rand(&self, rng: &mut Xorshift) -> $u {
98                    let len = ($u::MAX - self.start).wrapping_add(1);
99                    let x = rng.random::<$u, _>(..);
100                    self.start + if len != 0 { x % len } else { x }
101                }
102            }
103            impl RandomSpec<$i> for RangeFrom<$i> {
104                fn rand(&self, rng: &mut Xorshift) -> $i {
105                    let len = ($i::MAX.abs_diff(self.start)).wrapping_add(1);
106                    let x = rng.random::<$u, _>(..);
107                    self.start.wrapping_add_unsigned(if len != 0 { x % len } else { x })
108                }
109            }
110            impl RandomSpec<$u> for RangeInclusive<$u> {
111                fn rand(&self, rng: &mut Xorshift) -> $u {
112                    assert!(self.start() <= self.end());
113                    let len = (self.end() - self.start()).wrapping_add(1);
114                    let x = rng.random::<$u, _>(..);
115                    self.start() + if len != 0 { x % len } else { x }
116                }
117            }
118            impl RandomSpec<$i> for RangeInclusive<$i> {
119                fn rand(&self, rng: &mut Xorshift) -> $i {
120                    assert!(self.start() <= self.end());
121                    let len = (self.end().abs_diff(*self.start())).wrapping_add(1);
122                    let x = rng.random::<$u, _>(..);
123                    self.start().wrapping_add_unsigned(if len != 0 { x % len } else { x })
124                }
125            }
126            impl RandomSpec<$u> for RangeTo<$u> {
127                fn rand(&self, rng: &mut Xorshift) -> $u {
128                    let len = self.end;
129                    rng.random::<$u, _>(..) % len
130                }
131            }
132            impl RandomSpec<$i> for RangeTo<$i> {
133                fn rand(&self, rng: &mut Xorshift) -> $i {
134                    let len = self.end.abs_diff($i::MIN);
135                    $i::MIN.wrapping_add_unsigned(rng.random::<$u, _>(..) % len)
136                }
137            }
138            impl RandomSpec<$u> for RangeToInclusive<$u> {
139                fn rand(&self, rng: &mut Xorshift) -> $u {
140                    let len = (self.end).wrapping_add(1);
141                    let x = rng.random::<$u, _>(..);
142                    if len != 0 { x % len } else { x }
143                }
144            }
145            impl RandomSpec<$i> for RangeToInclusive<$i> {
146                fn rand(&self, rng: &mut Xorshift) -> $i {
147                    let len = (self.end.abs_diff($i::MIN)).wrapping_add(1);
148                    let x = rng.random::<$u, _>(..);
149                    $i::MIN.wrapping_add_unsigned(if len != 0 { x % len } else { x })
150                }
151            }
152        )*
153    };
154}
155impl_random_spec_ranges!(u8 i8 u16 i16 u32 i32 u64 i64 u128 i128 usize isize);
156
157macro_rules! impl_random_spec_tuple {
158    ($($T:ident)*, $($R:ident)*, $($v:ident)*) => {
159        impl<$($T),*, $($R),*> RandomSpec<($($T,)*)> for ($($R,)*)
160        where
161            $($R: RandomSpec<$T>),*
162        {
163            fn rand(&self, rng: &mut Xorshift) -> ($($T,)*) {
164                let ($($v,)*) = self;
165                ($(($v).rand(rng),)*)
166            }
167        }
168    };
169}
170impl_random_spec_tuple!(A, RA, a);
171impl_random_spec_tuple!(A B, RA RB, a b);
172impl_random_spec_tuple!(A B C, RA RB RC, a b c);
173impl_random_spec_tuple!(A B C D, RA RB RC RD, a b c d);
174impl_random_spec_tuple!(A B C D E, RA RB RC RD RE, a b c d e);
175impl_random_spec_tuple!(A B C D E F, RA RB RC RD RE RF, a b c d e f);
176impl_random_spec_tuple!(A B C D E F G, RA RB RC RD RE RF RG, a b c d e f g);
177impl_random_spec_tuple!(A B C D E F G H, RA RB RC RD RE RF RG RH, a b c d e f g h);
178impl_random_spec_tuple!(A B C D E F G H I, RA RB RC RD RE RF RG RH RI, a b c d e f g h i);
179impl_random_spec_tuple!(A B C D E F G H I J, RA RB RC RD RE RF RG RH RI RJ, a b c d e f g h i j);
180
181macro_rules! impl_random_spec_primitive {
182    ($($t:ty)*) => {
183        $(impl RandomSpec<$t> for $t {
184            fn rand(&self, _rng: &mut Xorshift) -> $t {
185                *self
186            }
187        })*
188    };
189}
190impl_random_spec_primitive!(() u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize bool char);
191
192impl<T, R> RandomSpec<T> for &R
193where
194    R: RandomSpec<T>,
195{
196    fn rand(&self, rng: &mut Xorshift) -> T {
197        <R as RandomSpec<T>>::rand(self, rng)
198    }
199}
200impl<T, R> RandomSpec<T> for &mut R
201where
202    R: RandomSpec<T>,
203{
204    fn rand(&self, rng: &mut Xorshift) -> T {
205        <R as RandomSpec<T>>::rand(self, rng)
206    }
207}
208
209#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
210/// Left-close Right-open No Empty Segment
211pub struct NotEmptySegment<T>(pub T);
212impl<T> RandomSpec<(usize, usize)> for NotEmptySegment<T>
213where
214    T: RandomSpec<usize>,
215{
216    fn rand(&self, rng: &mut Xorshift) -> (usize, usize) {
217        let n = rng.random(&self.0) as u64;
218        let k = randint_uniform(rng, n);
219        let l = randint_uniform(rng, n - k) as usize;
220        (l, l + k as usize + 1)
221    }
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
225pub struct RandRange<Q, T> {
226    data: Q,
227    _marker: PhantomData<fn() -> T>,
228}
229impl<Q, T> RandRange<Q, T> {
230    pub fn new(data: Q) -> Self {
231        Self {
232            data,
233            _marker: PhantomData,
234        }
235    }
236}
237impl<Q, T> RandomSpec<(Bound<T>, Bound<T>)> for RandRange<Q, T>
238where
239    Q: RandomSpec<T>,
240    T: Ord,
241{
242    fn rand(&self, rng: &mut Xorshift) -> (Bound<T>, Bound<T>) {
243        let mut l = rng.random(&self.data);
244        let mut r = rng.random(&self.data);
245        if l > r {
246            swap(&mut l, &mut r);
247        }
248        (
249            match rng.rand(3) {
250                0 => Bound::Excluded(l),
251                1 => Bound::Included(l),
252                _ => Bound::Unbounded,
253            },
254            match rng.rand(3) {
255                0 => Bound::Excluded(r),
256                1 => Bound::Included(r),
257                _ => Bound::Unbounded,
258            },
259        )
260    }
261}
262
263#[inline]
264fn randint_uniform(rng: &mut Xorshift, k: u64) -> u64 {
265    let mut v = rng.rand64();
266    if k > 0 {
267        v %= k;
268    }
269    v
270}
271
/// Alias table for picking an index with probability proportional to its
/// weight.
///
/// The table has one slot per weight: slot `i` keeps index `i` with
/// probability `prob[i]` and otherwise redirects to `alias[i]`.
#[derive(Debug, Clone)]
pub struct WeightedSampler {
    /// Number of weights, i.e. the number of slots in the table.
    n: usize,
    /// `prob[i]` is the probability of keeping slot `i` once it is picked.
    prob: Vec<f64>,
    /// `alias[i]` is the index used when slot `i` is picked but not kept.
    alias: Vec<usize>,
}

impl WeightedSampler {
    /// Builds the alias table from `weights`.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty, if the sum of the weights is not
    /// strictly positive (this includes a NaN sum), or if any weight is
    /// negative.
    pub fn new(weights: impl IntoIterator<Item = f64>) -> Self {
        let mut weights: Vec<f64> = weights.into_iter().collect();
        let n = weights.len();
        assert!(n > 0, "weights must be non-empty");
        let sum: f64 = weights.iter().sum();
        assert!(sum > 0.0, "sum of weights must be positive");

        // Rescale so that the weights average to 1; a slot is "small" when it
        // needs to borrow mass from another index and "large" otherwise.
        let scale = n as f64 / sum;
        let mut prob = vec![0.0; n];
        let mut alias = vec![0; n];
        let mut small = Vec::new();
        let mut large = Vec::new();
        for (i, weight) in weights.iter_mut().enumerate() {
            assert!(*weight >= 0.0, "weights must be non-negative");
            *weight *= scale;
            if *weight < 1.0 {
                small.push(i);
            } else {
                large.push(i);
            }
        }

        loop {
            match (small.pop(), large.pop()) {
                (Some(under), Some(over)) => {
                    // Fill the rest of the small slot with mass from the large one.
                    prob[under] = weights[under];
                    alias[under] = over;
                    weights[over] -= 1.0 - weights[under];
                    if weights[over] < 1.0 {
                        small.push(over);
                    } else {
                        large.push(over);
                    }
                }
                // Only one kind is left (possible through rounding): such a
                // slot always keeps its own index.
                (Some(rest), None) | (None, Some(rest)) => {
                    prob[rest] = 1.0;
                    alias[rest] = rest;
                }
                (None, None) => break,
            }
        }
        Self { n, prob, alias }
    }
}
320
321impl RandomSpec<usize> for WeightedSampler {
322    fn rand(&self, rng: &mut Xorshift) -> usize {
323        let i = rng.rand(self.n as u64) as usize;
324        if rng.randf() < self.prob[i] {
325            i
326        } else {
327            self.alias[i]
328        }
329    }
330}
331
#[macro_export]
/// Return a random value using [`RandomSpec`].
///
/// Invoked as `rand_value!(rng, SPEC)` or `rand_value!(seed = SEED, SPEC)`, where `SPEC` is one of:
///
/// - an expression, expanded to `(rng).random(expr)`;
/// - `(SPEC, SPEC, ...)`, expanded to a tuple with one value per element;
/// - `[SPEC; len]`, expanded to `len` values collected into a `Vec`;
/// - `[SPEC; const len]`, expanded through `$crate::array!`;
/// - `[SPEC]`, expanded to an unbounded `::std::iter::repeat_with` iterator.
///
/// The `@...` arms are internal. They accumulate each generated value as a `[...]` group in the
/// first bracket and consume the remaining spec tokens from left to right; the arm order matters.
macro_rules! rand_value {
    // `@repeat`: `repeat_with` over the spec; with a length the stream is taken and collected into
    // a `Vec`, without one the iterator itself is the result.
    (@repeat $rng:expr, [$($t:tt)*] $($len:expr)?)                                    => { ::std::iter::repeat_with(|| $crate::rand_value!(@inner $rng, [] $($t)*)) $(.take($len).collect::<Vec<_>>())? };
    // `@array`: hands a closure evaluating the spec and the length to `$crate::array!` (defined elsewhere).
    (@array $rng:expr, [$($t:tt)*] $len:expr)                                         => { $crate::array![|| $crate::rand_value!(@inner $rng, [] $($t)*); $len] };
    // `@tuple`, input exhausted: every accumulated group becomes one tuple element.
    (@tuple $rng:expr, [$([$($args:tt)*])*])                                          => { ($($($args)*,)*) };
    // Any other tag, input exhausted with exactly one accumulated group: that group is the value.
    (@$tag:ident $rng:expr, [[$($args:tt)*]])                                         => { $($args)* };
    // Parenthesised spec: generated as a tuple and accumulated.
    (@$tag:ident $rng:expr, [$($args:tt)*] ($($tuple:tt)*) $($t:tt)*)                 => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@tuple $rng, [] $($tuple)*)]] $($t)*) };
    // Bracketed spec with a length: the element spec is itself a bracket group, a tuple, or an
    // expression; `const len` goes through `@array`, a plain `len` through `@repeat`.
    (@$tag:ident $rng:expr, [$($args:tt)*] [[$($tt:tt)*]; const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [[$($tt)*]] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [[$($tt:tt)*]; $len:expr] $($t:tt)*)       => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [[$($tt)*]] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [($($tt:tt)*); const $len:expr] $($t:tt)*) => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [($($tt)*)] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [($($tt:tt)*); $len:expr] $($t:tt)*)       => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [($($tt)*)] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [$ty:expr; const $len:expr] $($t:tt)*)     => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@array $rng, [$ty] $len)]] $($t)*) };
    (@$tag:ident $rng:expr, [$($args:tt)*] [$ty:expr; $len:expr] $($t:tt)*)           => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [$ty] $len)]] $($t)*) };
    // Bracketed spec without a length: `@repeat` with no length, i.e. an unbounded iterator.
    (@$tag:ident $rng:expr, [$($args:tt)*] [$($tt:tt)*] $($t:tt)*)                    => { $crate::rand_value!(@$tag $rng, [$($args)* [$crate::rand_value!(@repeat $rng, [$($tt)*])]] $($t)*) };
    // Plain expression spec, either last or followed by a comma and more input.
    (@$tag:ident $rng:expr, [$($args:tt)*] $ty:expr)                                  => { $crate::rand_value!(@$tag $rng, [$($args)* [($rng).random($ty)]]) };
    (@$tag:ident $rng:expr, [$($args:tt)*] $ty:expr, $($t:tt)*)                       => { $crate::rand_value!(@$tag $rng, [$($args)* [($rng).random($ty)]] $($t)*) };
    // Separator comma left over after a group: skipped.
    (@$tag:ident $rng:expr, [$($args:tt)*] , $($t:tt)*)                               => { $crate::rand_value!(@$tag $rng, [$($args)*] $($t)*) };
    // Input exhausted without matching one of the final arms above: compile error that prints
    // whatever was accumulated.
    (@$tag:ident $rng:expr, [$($args:tt)*])                                           => { ::std::compile_error!(::std::stringify!($($args)*)) };
    // Entry point with a seed: builds a local generator first.
    // NOTE(review): `Xorshift` is written without a `$crate::` path here, so it has to be
    // resolvable at the call site.
    (seed = $src:expr, $($t:tt)*)                                                     => { { let mut __rng = Xorshift::new_with_seed($src); $crate::rand_value!(@inner __rng, [] $($t)*) } };
    // Entry point with an existing generator expression.
    ($rng:expr, $($t:tt)*)                                                            => { $crate::rand_value!(@inner $rng, [] $($t)*) }
}
#[macro_export]
/// Declare random values using [`RandomSpec`].
///
/// `rand!(rng, PAT: SPEC, PAT: SPEC, ...)` expands, entry by entry and in order, to
/// `let PAT = rand_value!(rng, SPEC);`. `SPEC` uses the syntax accepted by `rand_value!`.
///
/// The `@...` arms are internal states: `@pat` collects the pattern tokens up to the `:`,
/// `@ty` picks up the spec, and `@let` emits the binding and restarts with the remaining input.
macro_rules! rand {
    // `@assert`: accepts the collected tokens if they parse as a pattern, otherwise reports them.
    (@assert $p:pat) => {};
    (@assert $($p:tt)*) => { ::std::compile_error!(::std::concat!("expected pattern, found `", ::std::stringify!($($p)*), "`")); };
    // `@pat`: nothing left to declare.
    (@pat $rng:expr, [] [])                                          => {};
    // `@pat`: separator comma before the next entry is skipped.
    (@pat $rng:expr, [] [] , $($t:tt)*)                              => { $crate::rand!(@pat $rng, [] [] $($t)*) };
    // `@pat`: pattern tokens (identifiers, `::`, `&`, and `()`/`[]`/`{}` groups) are appended to
    // the first bracket one at a time.
    (@pat $rng:expr, [$($p:tt)*] [] $x:ident $($t:tt)*)              => { $crate::rand!(@pat $rng, [$($p)* $x] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] :: $($t:tt)*)                    => { $crate::rand!(@pat $rng, [$($p)* ::] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] & $($t:tt)*)                     => { $crate::rand!(@pat $rng, [$($p)* &] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] ($($x:tt)*) $($t:tt)*)           => { $crate::rand!(@pat $rng, [$($p)* ($($x)*)] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] [$($x:tt)*] $($t:tt)*)           => { $crate::rand!(@pat $rng, [$($p)* [$($x)*]] [] $($t)*) };
    (@pat $rng:expr, [$($p:tt)*] [] {$($x:tt)*} $($t:tt)*)           => { $crate::rand!(@pat $rng, [$($p)* {$($x)*}] [] $($t)*) };
    // `@pat`: the `:` ends the pattern and switches to reading the spec.
    (@pat $rng:expr, [$($p:tt)*] [] : $($t:tt)*)                     => { $crate::rand!(@ty  $rng, [$($p)*] [] $($t)*) };
    // `@ty`: a spec starting with a `()` or `[]` group is taken as that single group.
    (@ty  $rng:expr, [$($p:tt)*] [$($tt:tt)*] ($($x:tt)*) $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* ($($x)*)] $($t)*) };
    (@ty  $rng:expr, [$($p:tt)*] [$($tt:tt)*] [$($x:tt)*] $($t:tt)*) => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* [$($x)*]] $($t)*) };
    // `@ty`: otherwise the spec is one expression, either last or followed by a comma.
    (@ty  $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:expr)               => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e]) };
    (@ty  $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:expr, $($t:tt)*)    => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e], $($t)*) };
    // `@ty`: same two shapes with a single token tree in place of the expression.
    (@ty  $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:tt)                 => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e]) };
    (@ty  $rng:expr, [$($p:tt)*] [$($tt:tt)*] $e:tt, $($t:tt)*)      => { $crate::rand!(@let $rng, [$($p)*] [$($tt)* $e], $($t)*) };
    // `@let`: validates the pattern, emits the `let` binding, then continues with the rest.
    (@let $rng:expr, [$($p:tt)*] [$($tt:tt)*] $($t:tt)*) => {
        $crate::rand!{@assert $($p)*}
        let $($p)* = $crate::rand_value!($rng, $($tt)*);
        $crate::rand!(@pat $rng, [] [] $($t)*)
    };
    // Entry points: a generator alone declares nothing; otherwise start collecting a pattern.
    ($rng:expr) => {};
    ($rng:expr, $($t:tt)*) => { $crate::rand!(@pat $rng, [] [] $($t)*) };
}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Ranges that hold exactly one value must always produce that value,
    /// including the ones pinned to the limits of the type.
    #[test]
    fn test_random_range() {
        let mut rng = Xorshift::default();
        // Half-open and closed single-element ranges.
        assert_eq!(rng.random(1i32..2), 1);
        assert_eq!(rng.random(1u32..2), 1);
        assert_eq!(rng.random(1i32..=1), 1);
        assert_eq!(rng.random(1u32..=1), 1);
        // One-sided ranges whose only member is the extreme value.
        assert_eq!(rng.random(i32::MAX..), i32::MAX);
        assert_eq!(rng.random(u32::MAX..), u32::MAX);
        assert_eq!(rng.random(..=i32::MIN), i32::MIN);
        assert_eq!(rng.random(..=u32::MIN), u32::MIN);
    }

    /// Segments must be non-empty and inside `[0, n)`, and for a small `n`
    /// every possible segment must show up.
    #[test]
    fn test_random_segment() {
        const N_SMALL: usize = 100;
        const ROUNDS: usize = 100_000;

        let mut rng = Xorshift::default();
        for _ in 0..ROUNDS {
            let n = (1..=1_000_000).rand(&mut rng);
            let (l, r) = NotEmptySegment(n).rand(&mut rng);
            assert!(l < r);
            assert!(r <= n);
        }

        let mut set = std::collections::HashSet::new();
        for _ in 0..ROUNDS {
            let (l, r) = NotEmptySegment(N_SMALL).rand(&mut rng);
            assert!(l < r);
            assert!(r <= N_SMALL);
            set.insert((l, r));
        }
        assert!(set.len() == N_SMALL * (N_SMALL + 1) / 2);
    }

    /// Smoke test: every spec shape accepted by `rand!` must expand and compile.
    #[test]
    fn test_rand_macro() {
        let mut rng = Xorshift::default();
        rand!(
            rng,
            _x: ..10,
            _lr: NotEmptySegment(10),
            _a: [..10; 10],
            _t: (..10,),
            _r: (&(..10),&mut (..10)),
            _p: [(1..=10,2..=10); 2]
        );
    }

    /// Observed frequencies must match the normalized weights within 0.01.
    #[test]
    fn test_weighted_sampler() {
        const DRAWS: usize = 1_000_000;

        let mut rng = Xorshift::default();
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let sampler = WeightedSampler::new(weights.clone());
        let total = weights.iter().sum::<f64>();

        let mut counts = vec![0; weights.len()];
        for _ in 0..DRAWS {
            let idx = sampler.rand(&mut rng);
            counts[idx] += 1;
        }

        for (weight, count) in weights.iter().zip(counts.iter()) {
            let expected = weight / total;
            let actual = *count as f64 / DRAWS as f64;
            assert!((expected - actual).abs() < 0.01);
        }
    }
}