competitive/string/
rolling_hash.rs

1use super::{Gf2_63, Invertible, Mersenne61, Monoid, Ring, SemiRing, Xorshift};
2use std::{
3    cmp::Ordering,
4    fmt::{self, Debug},
5    marker::PhantomData,
6    ops::{Bound, RangeBounds, RangeInclusive},
7};
8
9pub trait RollingHasher {
10    type T;
11    type Hash: Copy + Eq;
12    fn init_with_rng(len: usize, rng: &mut Xorshift);
13    fn init(len: usize) {
14        let mut rng = Xorshift::new();
15        Self::init_with_rng(len, &mut rng);
16    }
17    fn ensure(len: usize);
18    fn hash_sequence<I>(iter: I) -> HashedSequence<Self>
19    where
20        I: IntoIterator<Item = Self::T>;
21    fn hash_substr(hashed: &[Self::Hash]) -> Hashed<Self>;
22    fn concat_hash(x: &Hashed<Self>, y: &Hashed<Self>) -> Hashed<Self>;
23    fn empty_hash() -> Hashed<Self>;
24}
25
26#[derive(Debug)]
27pub struct HashedSequence<Hasher>
28where
29    Hasher: RollingHasher + ?Sized,
30{
31    hashed: Vec<Hasher::Hash>,
32    _marker: PhantomData<fn() -> Hasher>,
33}
34
35impl<Hasher> HashedSequence<Hasher>
36where
37    Hasher: RollingHasher + ?Sized,
38{
39    fn new(hashed: Vec<Hasher::Hash>) -> Self {
40        Self {
41            hashed,
42            _marker: PhantomData,
43        }
44    }
45    pub fn len(&self) -> usize {
46        self.hashed.len() - 1
47    }
48    pub fn is_empty(&self) -> bool {
49        self.len() == 0
50    }
51    pub fn range<R>(&self, range: R) -> HashedRange<'_, Hasher>
52    where
53        R: RangeBounds<usize>,
54    {
55        HashedRange::new(&self.hashed[to_range(range, self.len())])
56    }
57    pub fn hash_range<R>(&self, range: R) -> Hashed<Hasher>
58    where
59        R: RangeBounds<usize>,
60    {
61        self.range(range).hash()
62    }
63}
64
65#[derive(Debug)]
66pub struct HashedRange<'a, Hasher>
67where
68    Hasher: RollingHasher + ?Sized,
69{
70    hashed: &'a [Hasher::Hash],
71    _marker: PhantomData<fn() -> Hasher>,
72}
73
74impl<Hasher> Clone for HashedRange<'_, Hasher>
75where
76    Hasher: RollingHasher + ?Sized,
77{
78    fn clone(&self) -> Self {
79        *self
80    }
81}
82
83impl<Hasher> Copy for HashedRange<'_, Hasher> where Hasher: RollingHasher + ?Sized {}
84
85impl<Hasher> PartialEq for HashedRange<'_, Hasher>
86where
87    Hasher: RollingHasher + ?Sized,
88{
89    fn eq(&self, other: &Self) -> bool {
90        self.hash() == other.hash()
91    }
92}
93
94impl<Hasher> Eq for HashedRange<'_, Hasher> where Hasher: RollingHasher + ?Sized {}
95
96impl<Hasher> PartialOrd for HashedRange<'_, Hasher>
97where
98    Hasher: RollingHasher<Hash: PartialOrd> + ?Sized,
99{
100    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
101        let n = self.longest_common_prefix(other);
102        match (self.len() > n, other.len() > n) {
103            (true, true) => {
104                let x = self.hash_range(n..=n);
105                let y = other.hash_range(n..=n);
106                x.hash.partial_cmp(&y.hash)
107            }
108            (x, y) => Some(x.cmp(&y)),
109        }
110    }
111}
112
113impl<Hasher> Ord for HashedRange<'_, Hasher>
114where
115    Hasher: RollingHasher<Hash: Ord> + ?Sized,
116{
117    fn cmp(&self, other: &Self) -> Ordering {
118        let n = self.longest_common_prefix(other);
119        match (self.len() > n, other.len() > n) {
120            (true, true) => {
121                let x = self.hash_range(n..=n);
122                let y = other.hash_range(n..=n);
123                x.hash.cmp(&y.hash)
124            }
125            (x, y) => x.cmp(&y),
126        }
127    }
128}
129
130impl<'a, Hasher> HashedRange<'a, Hasher>
131where
132    Hasher: RollingHasher + ?Sized,
133{
134    fn new(hashed: &'a [Hasher::Hash]) -> Self {
135        Self {
136            hashed,
137            _marker: PhantomData,
138        }
139    }
140    pub fn len(&self) -> usize {
141        self.hashed.len() - 1
142    }
143    pub fn is_empty(&self) -> bool {
144        self.len() == 0
145    }
146    pub fn range<R>(&self, range: R) -> HashedRange<'a, Hasher>
147    where
148        R: RangeBounds<usize>,
149    {
150        HashedRange::new(&self.hashed[to_range(range, self.len())])
151    }
152    pub fn hash_range<R>(&self, range: R) -> Hashed<Hasher>
153    where
154        R: RangeBounds<usize>,
155    {
156        self.range(range).hash()
157    }
158    pub fn hash(&self) -> Hashed<Hasher> {
159        Hasher::hash_substr(self.hashed)
160    }
161    pub fn longest_common_prefix(&self, other: &Self) -> usize {
162        let n = self.len().min(other.len());
163        let mut ok = 0usize;
164        let mut err = n + 1;
165        while ok + 1 < err {
166            let mid = ok.midpoint(err);
167            if self.range(..mid).hash() == other.range(..mid).hash() {
168                ok = mid;
169            } else {
170                err = mid;
171            }
172        }
173        ok
174    }
175    pub fn chainable(self) -> HashedRangeChained<'a, Hasher> {
176        vec![self].into()
177    }
178}
179
180pub struct HashedRangeChained<'a, Hasher>
181where
182    Hasher: RollingHasher + ?Sized,
183{
184    chained: Vec<HashedRange<'a, Hasher>>,
185    _marker: PhantomData<fn() -> Hasher>,
186}
187
188impl<Hasher: Debug> Debug for HashedRangeChained<'_, Hasher>
189where
190    Hasher: RollingHasher<Hash: Debug> + ?Sized,
191{
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        f.debug_struct("HashedRangeChained")
194            .field("chained", &self.chained)
195            .finish()
196    }
197}
198
199impl<Hasher: Default> Default for HashedRangeChained<'_, Hasher>
200where
201    Hasher: RollingHasher,
202{
203    fn default() -> Self {
204        Self {
205            chained: Default::default(),
206            _marker: Default::default(),
207        }
208    }
209}
210
211impl<Hasher: Clone> Clone for HashedRangeChained<'_, Hasher>
212where
213    Hasher: RollingHasher,
214{
215    fn clone(&self) -> Self {
216        Self {
217            chained: self.chained.clone(),
218            _marker: self._marker,
219        }
220    }
221}
222
223impl<Hasher> PartialEq for HashedRangeChained<'_, Hasher>
224where
225    Hasher: RollingHasher + ?Sized,
226{
227    fn eq(&self, other: &Self) -> bool {
228        let mut a = self.chained.iter().cloned();
229        let mut b = other.chained.iter().cloned();
230        macro_rules! next {
231            ($iter:expr) => {
232                loop {
233                    if let Some(x) = $iter.next() {
234                        if x.len() > 0 {
235                            break Some(x);
236                        }
237                    } else {
238                        break None;
239                    }
240                }
241            };
242        }
243        let mut x: Option<HashedRange<'_, Hasher>> = None;
244        let mut y: Option<HashedRange<'_, Hasher>> = None;
245        loop {
246            if x.is_none_or(|x| x.is_empty()) {
247                x = next!(a);
248            }
249            if y.is_none_or(|y| y.is_empty()) {
250                y = next!(b);
251            }
252            if let (Some(x), Some(y)) = (&mut x, &mut y) {
253                let k = x.len().min(y.len());
254                if x.range(..k) != y.range(..k) {
255                    return false;
256                }
257                *x = x.range(k..);
258                *y = y.range(k..);
259            } else {
260                break x.is_none() == y.is_none();
261            }
262        }
263    }
264}
265
266impl<Hasher> Eq for HashedRangeChained<'_, Hasher> where Hasher: RollingHasher + ?Sized {}
267
268impl<Hasher> PartialOrd for HashedRangeChained<'_, Hasher>
269where
270    Hasher: RollingHasher<Hash: PartialOrd> + ?Sized,
271{
272    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
273        let mut a = self.chained.iter().cloned();
274        let mut b = other.chained.iter().cloned();
275        macro_rules! next {
276            ($iter:expr) => {
277                loop {
278                    if let Some(x) = $iter.next() {
279                        if x.len() > 0 {
280                            break Some(x);
281                        }
282                    } else {
283                        break None;
284                    }
285                }
286            };
287        }
288        let mut x: Option<HashedRange<'_, Hasher>> = None;
289        let mut y: Option<HashedRange<'_, Hasher>> = None;
290        loop {
291            if x.is_none_or(|x| x.is_empty()) {
292                x = next!(a);
293            }
294            if y.is_none_or(|y| y.is_empty()) {
295                y = next!(b);
296            }
297            if let (Some(x), Some(y)) = (&mut x, &mut y) {
298                let k = x.longest_common_prefix(y);
299                if x.len() > k && y.len() > k {
300                    let x = x.hash_range(k..=k);
301                    let y = y.hash_range(k..=k);
302                    break x.hash.partial_cmp(&y.hash);
303                };
304                *x = x.range(k..);
305                *y = y.range(k..);
306            } else {
307                break x.is_some().partial_cmp(&y.is_some());
308            }
309        }
310    }
311}
312
313impl<Hasher> Ord for HashedRangeChained<'_, Hasher>
314where
315    Hasher: RollingHasher<Hash: Ord> + ?Sized,
316{
317    fn cmp(&self, other: &Self) -> Ordering {
318        let mut a = self.chained.iter().cloned();
319        let mut b = other.chained.iter().cloned();
320        macro_rules! next {
321            ($iter:expr) => {
322                loop {
323                    if let Some(x) = $iter.next() {
324                        if x.len() > 0 {
325                            break Some(x);
326                        }
327                    } else {
328                        break None;
329                    }
330                }
331            };
332        }
333        let mut x: Option<HashedRange<'_, Hasher>> = None;
334        let mut y: Option<HashedRange<'_, Hasher>> = None;
335        loop {
336            if x.is_none_or(|x| x.is_empty()) {
337                x = next!(a);
338            }
339            if y.is_none_or(|y| y.is_empty()) {
340                y = next!(b);
341            }
342            if let (Some(x), Some(y)) = (&mut x, &mut y) {
343                let k = x.longest_common_prefix(y);
344                if x.len() > k && y.len() > k {
345                    let x = x.hash_range(k..=k);
346                    let y = y.hash_range(k..=k);
347                    break x.hash.cmp(&y.hash);
348                };
349                *x = x.range(k..);
350                *y = y.range(k..);
351            } else {
352                break x.is_some().cmp(&y.is_some());
353            }
354        }
355    }
356}
357
358impl<'a, Hasher> From<Vec<HashedRange<'a, Hasher>>> for HashedRangeChained<'a, Hasher>
359where
360    Hasher: RollingHasher + ?Sized,
361{
362    fn from(hashed: Vec<HashedRange<'a, Hasher>>) -> Self {
363        Self {
364            chained: hashed,
365            _marker: PhantomData,
366        }
367    }
368}
369
370impl<'a, Hasher> HashedRangeChained<'a, Hasher>
371where
372    Hasher: RollingHasher + ?Sized,
373{
374    pub fn chain(mut self, x: HashedRange<'a, Hasher>) -> Self {
375        self.chained.push(x);
376        self
377    }
378    pub fn push(&mut self, x: HashedRange<'a, Hasher>) {
379        self.chained.push(x);
380    }
381}
382
/// Converts `range`, taken relative to a sequence of length `ub`, into the inclusive window
/// `l..=r` of prefix-hash indices that bounds it. The prefix hashes at `l` and `r` enclose
/// the elements `l..r`.
///
/// # Panics
///
/// Panics if the range is reversed (`start > end`) or extends past `ub`.
fn to_range<R>(range: R, ub: usize) -> RangeInclusive<usize>
where
    R: RangeBounds<usize>,
{
    // Saturating `+ 1`: a bound at `usize::MAX` is always out of range. Saturating sends
    // it to the assertion below instead of overflowing (or wrapping to 0 in release).
    let l = match range.start_bound() {
        Bound::Included(&l) => l,
        Bound::Excluded(&l) => l.saturating_add(1),
        Bound::Unbounded => 0,
    };
    let r = match range.end_bound() {
        Bound::Included(&r) => r.saturating_add(1),
        Bound::Excluded(&r) => r,
        Bound::Unbounded => ub,
    };
    // A reversed range with `l == r + 1` would slice an empty window of prefix hashes.
    // That produces a `HashedRange` whose `len()` underflows, so reject it here.
    assert!(
        l <= r && r <= ub,
        "range {l}..{r} is invalid for a sequence of length {ub}"
    );
    l..=r
}
399
400#[derive(Debug)]
401pub struct Hashed<Hasher>
402where
403    Hasher: RollingHasher + ?Sized,
404{
405    len: usize,
406    hash: Hasher::Hash,
407    _marker: PhantomData<fn() -> Hasher>,
408}
409
410impl<Hasher> std::hash::Hash for Hashed<Hasher>
411where
412    Hasher: RollingHasher<Hash: std::hash::Hash> + ?Sized,
413{
414    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
415        self.len.hash(state);
416        self.hash.hash(state);
417        self._marker.hash(state);
418    }
419}
420
421impl<Hasher> Hashed<Hasher>
422where
423    Hasher: RollingHasher + ?Sized,
424{
425    fn new(len: usize, hash: Hasher::Hash) -> Self {
426        Self {
427            len,
428            hash,
429            _marker: PhantomData,
430        }
431    }
432    pub fn concat(&self, other: &Self) -> Self {
433        Hasher::concat_hash(self, other)
434    }
435    pub fn pow(&self, n: usize) -> Self {
436        let mut res = Hasher::empty_hash();
437        let mut x = *self;
438        let mut n = n;
439        while n > 0 {
440            if n & 1 == 1 {
441                res = res.concat(&x);
442            }
443            x = x.concat(&x);
444            n >>= 1;
445        }
446        res
447    }
448}
449
450impl<Hasher> Clone for Hashed<Hasher>
451where
452    Hasher: RollingHasher + ?Sized,
453{
454    fn clone(&self) -> Self {
455        *self
456    }
457}
458
459impl<Hasher> Copy for Hashed<Hasher> where Hasher: RollingHasher + ?Sized {}
460
461impl<Hasher> PartialEq for Hashed<Hasher>
462where
463    Hasher: RollingHasher + ?Sized,
464{
465    fn eq(&self, other: &Self) -> bool {
466        self.len == other.len && self.hash == other.hash
467    }
468}
469
470impl<Hasher> Eq for Hashed<Hasher> where Hasher: RollingHasher + ?Sized {}
471
472#[derive(Debug)]
473struct RollingHashPrecalc<R>
474where
475    R: SemiRing,
476{
477    base: R::T,
478    pow: Vec<R::T>,
479}
480
481impl<R> Default for RollingHashPrecalc<R>
482where
483    R: SemiRing<T: Default>,
484{
485    fn default() -> Self {
486        Self {
487            base: Default::default(),
488            pow: Default::default(),
489        }
490    }
491}
492
493impl<R> RollingHashPrecalc<R>
494where
495    R: SemiRing<Additive: Invertible>,
496{
497    fn new(base: R::T) -> Self {
498        Self {
499            base,
500            pow: vec![R::one()],
501        }
502    }
503    fn ensure_pow(&mut self, len: usize) {
504        if self.pow.len() <= len {
505            self.pow.reserve(len - self.pow.len() + 1);
506            if self.pow.is_empty() {
507                self.pow.push(R::one());
508            }
509            for _ in 0..=len - self.pow.len() {
510                self.pow.push(R::mul(self.pow.last().unwrap(), &self.base));
511            }
512        }
513    }
514    fn mul1_add(&self, x: &R::T, y: &R::T) -> R::T {
515        R::add(&R::mul(x, &self.base), y)
516    }
517    fn muln_add(&mut self, x: &R::T, y: &R::T, n: usize) -> R::T {
518        if let Some(pow) = self.pow.get(n) {
519            R::add(&R::mul(x, pow), y)
520        } else {
521            let pow = <R::Multiplicative as Monoid>::pow(self.base.clone(), n);
522            R::add(&R::mul(x, &pow), y)
523        }
524    }
525    fn muln_sub(&mut self, l: &R::T, r: &R::T, n: usize) -> R::T {
526        if let Some(pow) = self.pow.get(n) {
527            R::sub(r, &R::mul(l, pow))
528        } else {
529            let pow = <R::Multiplicative as Monoid>::pow(self.base.clone(), n);
530            R::sub(r, &R::mul(l, &pow))
531        }
532    }
533}
534
// Generates a zero-sized hasher type `$T` implementing `RollingHasher` with several
// independent hashes over the semiring `$R`.
//
// Usage: `impl_rolling_hasher!(Name, Ring, [_ _ ...])`, with one `_` per independent hash.
// The `@inner` arms consume the `_` tokens one at a time. Each step:
//   - moves the next index from the list `0 1 2 ... 9` into `[$($i)*]`;
//   - adds one `()` placeholder to `[$($s)*]`. The placeholders only drive `$( ... )+`
//     repetitions that do not otherwise mention an index.
// When the `_` list is empty, the next unused index is the number of hashes, bound to
// `$len`. The index list therefore caps the count at 9 hashes.
macro_rules! impl_rolling_hasher {
    // Consume one `_`, moving the next index `$k` into the collected index list.
    (@inner $T:ident, $R:ty, [$($i:tt)*] [$($s:tt)*] [$a:tt $($tt:tt)*] [$k:tt $($j:tt)*]) => {
        impl_rolling_hasher!(@inner $T, $R, [$($i)* $k] [$($s)* ()] [$($tt)*] [$($j)*]);
    };
    // Every `_` consumed: emit the hasher type and its `RollingHasher` impl.
    (@inner $T:ident, $R:ty, [$($i:tt)+] [$($s:tt)+] [] [$len:tt $($j:tt)*]) => {
        #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
        pub enum $T {}

        impl $T {
            // Per-thread parameters, one `RollingHashPrecalc` per independent hash. They
            // start defaulted. Because they are `thread_local!`, each thread draws its own
            // bases, so hashes computed on different threads are not comparable.
            fn __rolling_hash_local_key() -> &'static ::std::thread::LocalKey<::std::cell::Cell<[RollingHashPrecalc<$R>; $len]>> {
                ::std::thread_local!(
                    static __LOCAL_KEY: ::std::cell::Cell<[RollingHashPrecalc<$R>; $len]> = ::std::cell::Cell::new(Default::default())
                );
                &__LOCAL_KEY
            }
        }

        impl RollingHasher for $T {
            type T = <$R as SemiRing>::T;

            // One hash value per independent base.
            type Hash = [<$R as SemiRing>::T; $len];

            fn init_with_rng(len: usize, rng: &mut Xorshift) {
                Self::__rolling_hash_local_key().with(|cell| {
                    // Bases are drawn only while every base is still 0 (the defaulted
                    // state). Later calls keep the existing bases and just extend powers.
                    if unsafe{ (&*cell.as_ptr()).iter().all(|p| p.base == 0) } {
                        // `$s;` is a no-op placeholder that repeats this once per hash.
                        cell.set([$({ $s; RollingHashPrecalc::new(rng.rand(<$R>::MOD)) },)+]);
                    }
                });
                Self::ensure(len);
            }

            fn ensure(len: usize) {
                Self::__rolling_hash_local_key().with(|cell| {
                    // NOTE(review): soundness relies on no other reference into the cell
                    // being alive while this `&mut` exists (no reentrant use).
                    unsafe {
                        let arr = &mut *cell.as_ptr();
                        $(arr[$i].ensure_pow(len);)+
                    }
                })
            }

            fn hash_sequence<I>(iter: I) -> HashedSequence<Self>
            where
                I: IntoIterator<Item = Self::T>,
            {
                let iter = iter.into_iter();
                let (lb, _) = iter.size_hint();
                let mut hashed = Vec::with_capacity(lb + 1);
                // Hash of the empty prefix.
                hashed.push([$({ $s; <$R>::zero() },)+]);
                unsafe {
                    Self::__rolling_hash_local_key().with(|cell| {
                        // NOTE(review): `arr` stays borrowed while the caller's iterator
                        // runs. If that iterator calls back into `ensure`/`init` for this
                        // hasher, a `&mut` would alias this `&`.
                        let arr = &*cell.as_ptr();
                        for item in iter {
                            // Next prefix hash = previous * base + item, per hash.
                            let last = hashed.last().unwrap();
                            let h = [$(arr[$i].mul1_add(&last[$i], &item),)+];
                            hashed.push(h);
                        }
                    })
                };
                HashedSequence::new(hashed)
            }

            fn hash_substr(hashed: &[Self::Hash]) -> Hashed<Self> {
                let len = hashed.len() - 1;
                let h = unsafe {
                    Self::__rolling_hash_local_key().with(|cell| {
                        let arr = &mut *cell.as_ptr();
                        // hash = last_prefix - first_prefix * base^len, per hash.
                        [$(arr[$i].muln_sub(&hashed[0][$i], &hashed[len][$i], len),)+]
                    })
                };
                Hashed::new(len, h)
            }

            fn concat_hash(x: &Hashed<Self>, y: &Hashed<Self>) -> Hashed<Self> {
                let len = y.len;
                let hash = unsafe {
                    Self::__rolling_hash_local_key().with(|cell| {
                        let arr = &mut *cell.as_ptr();
                        // hash = x * base^{y.len} + y, per hash.
                        [$(arr[$i].muln_add(&x.hash[$i], &y.hash[$i], len),)+]
                    })
                };
                Hashed::new(x.len + y.len, hash)
            }

            fn empty_hash() -> Hashed<Self> {
                Hashed::new(0, [$({ $s; <$R>::zero() },)+])
            }
        }
    };
    // Public entry point: start munching with empty accumulators and indices 0..=9.
    ($T:ident, $R:ty, [$($tt:tt)+]) => {
        impl_rolling_hasher!(@inner $T, $R, [] [] [$($tt)+] [0 1 2 3 4 5 6 7 8 9]);
    };
}
627
// Ready-made hashers: `<Ring>x<k>` combines `k` independent hashes over `Ring`.
impl_rolling_hasher!(Mersenne61x1, Mersenne61, [_]);
impl_rolling_hasher!(Mersenne61x2, Mersenne61, [_ _]);
impl_rolling_hasher!(Mersenne61x3, Mersenne61, [_ _ _]);
impl_rolling_hasher!(Gf2_63x1, Gf2_63, [_]);
impl_rolling_hasher!(Gf2_63x2, Gf2_63, [_ _]);
impl_rolling_hasher!(Gf2_63x3, Gf2_63, [_ _ _]);
634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    #[test]
    fn test_rolling_hash() {
        const N: usize = 200;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..10u64).take(N).collect();
        Mersenne61x3::init(N);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        for k in 1..=N {
            for l1 in 0..=N - k {
                for l2 in 0..=N - k {
                    assert_eq!(
                        a[l1..l1 + k] == a[l2..l2 + k],
                        h.range(l1..l1 + k) == h.range(l2..l2 + k),
                        "a1: {:?}, a2: {:?}",
                        &a[l1..l1 + k],
                        &a[l2..l2 + k]
                    );
                }
            }
        }
    }

    #[test]
    fn test_rolling_hash_limited_precalc() {
        const N: usize = 200;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..10u64).take(N).collect();
        Mersenne61x3::init(0);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        for k in 1..=N {
            for l1 in 0..=N - k {
                for l2 in 0..=N - k {
                    assert_eq!(
                        a[l1..l1 + k] == a[l2..l2 + k],
                        h.range(l1..l1 + k) == h.range(l2..l2 + k)
                    );
                }
            }
        }
    }

    #[test]
    fn test_rolling_hash_pow() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..10u64).take(N).collect();
        Mersenne61x3::init(N);
        for k in 0..=N {
            for l in 0..=N - k {
                let a = &a[l..l + k];
                for n in 0..=N {
                    let b = a.repeat(n);
                    let ha = Mersenne61x3::hash_sequence(a.iter().copied())
                        .hash_range(..)
                        .pow(n);
                    let hb = Mersenne61x3::hash_sequence(b.iter().copied()).hash_range(..);
                    assert_eq!(ha, hb);
                }
            }
        }
    }

    /// Chains must compare by concatenated content regardless of how they are split
    /// (including empty pieces). Ordering must be consistent with equality and antisymmetric.
    #[test]
    fn test_rolling_hash_chained() {
        const N: usize = 12;
        let mut rng = Xorshift::default();
        // A small alphabet makes equal substrings common.
        let a: Vec<_> = rng.random_iter(0..3u64).take(N).collect();
        Mersenne61x3::init(N);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        // Splits `l..r` into three pieces, the middle one empty.
        let split = |l: usize, r: usize| {
            let m = l + (r - l) / 3;
            h.range(l..m).chainable().chain(h.range(m..m)).chain(h.range(m..r))
        };
        for l1 in 0..=N {
            for r1 in l1..=N {
                for l2 in 0..=N {
                    for r2 in l2..=N {
                        let x = split(l1, r1);
                        let y = h.range(l2..r2).chainable();
                        let expected = a[l1..r1] == a[l2..r2];
                        assert_eq!(x == y, expected);
                        assert_eq!(x.cmp(&y) == std::cmp::Ordering::Equal, expected);
                        assert_eq!(x.cmp(&y), y.cmp(&x).reverse());
                        assert_eq!(x.partial_cmp(&y), Some(x.cmp(&y)));
                    }
                }
            }
        }
    }
}