// competitive/string/rolling_hash.rs

1use super::{Gf2_63, Invertible, Mersenne61, Monoid, Ring, SemiRing, Xorshift};
2use std::{
3    cmp::Ordering,
4    fmt::{self, Debug},
5    marker::PhantomData,
6    ops::{Bound, RangeBounds, RangeInclusive},
7};
8
9pub trait RollingHasher {
10    type T;
11    type Hash: Copy + Eq;
12    fn init_with_rng(len: usize, rng: &mut Xorshift);
13    fn init(len: usize) {
14        let mut rng = Xorshift::new();
15        Self::init_with_rng(len, &mut rng);
16    }
17    fn ensure(len: usize);
18    fn hash_sequence<I>(iter: I) -> HashedSequence<Self>
19    where
20        I: IntoIterator<Item = Self::T>;
21    fn hash_substr(hashed: &[Self::Hash]) -> Hashed<Self>;
22    fn concat_hash(x: &Hashed<Self>, y: &Hashed<Self>) -> Hashed<Self>;
23    fn empty_hash() -> Hashed<Self>;
24}
25
26#[derive(Debug)]
27pub struct HashedSequence<Hasher>
28where
29    Hasher: RollingHasher + ?Sized,
30{
31    hashed: Vec<Hasher::Hash>,
32    _marker: PhantomData<fn() -> Hasher>,
33}
34
35impl<Hasher> HashedSequence<Hasher>
36where
37    Hasher: RollingHasher + ?Sized,
38{
39    fn new(hashed: Vec<Hasher::Hash>) -> Self {
40        Self {
41            hashed,
42            _marker: PhantomData,
43        }
44    }
45    pub fn len(&self) -> usize {
46        self.hashed.len() - 1
47    }
48    pub fn is_empty(&self) -> bool {
49        self.len() == 0
50    }
51    pub fn range<R>(&self, range: R) -> HashedRange<'_, Hasher>
52    where
53        R: RangeBounds<usize>,
54    {
55        HashedRange::new(&self.hashed[to_range(range, self.len())])
56    }
57    pub fn hash_range<R>(&self, range: R) -> Hashed<Hasher>
58    where
59        R: RangeBounds<usize>,
60    {
61        self.range(range).hash()
62    }
63}
64
65#[derive(Debug)]
66pub struct HashedRange<'a, Hasher>
67where
68    Hasher: RollingHasher + ?Sized,
69{
70    hashed: &'a [Hasher::Hash],
71    _marker: PhantomData<fn() -> Hasher>,
72}
73
74impl<Hasher> Clone for HashedRange<'_, Hasher>
75where
76    Hasher: RollingHasher + ?Sized,
77{
78    fn clone(&self) -> Self {
79        *self
80    }
81}
82
83impl<Hasher> Copy for HashedRange<'_, Hasher> where Hasher: RollingHasher + ?Sized {}
84
85impl<Hasher> PartialEq for HashedRange<'_, Hasher>
86where
87    Hasher: RollingHasher + ?Sized,
88{
89    fn eq(&self, other: &Self) -> bool {
90        self.hash() == other.hash()
91    }
92}
93
94impl<Hasher> Eq for HashedRange<'_, Hasher> where Hasher: RollingHasher + ?Sized {}
95
96impl<Hasher> PartialOrd for HashedRange<'_, Hasher>
97where
98    Hasher: RollingHasher + ?Sized,
99    Hasher::Hash: PartialOrd,
100{
101    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
102        let n = self.longest_common_prefix(other);
103        match (self.len() > n, other.len() > n) {
104            (true, true) => {
105                let x = self.hash_range(n..=n);
106                let y = other.hash_range(n..=n);
107                x.hash.partial_cmp(&y.hash)
108            }
109            (x, y) => Some(x.cmp(&y)),
110        }
111    }
112}
113
114impl<Hasher> Ord for HashedRange<'_, Hasher>
115where
116    Hasher: RollingHasher + ?Sized,
117    Hasher::Hash: Ord,
118{
119    fn cmp(&self, other: &Self) -> Ordering {
120        let n = self.longest_common_prefix(other);
121        match (self.len() > n, other.len() > n) {
122            (true, true) => {
123                let x = self.hash_range(n..=n);
124                let y = other.hash_range(n..=n);
125                x.hash.cmp(&y.hash)
126            }
127            (x, y) => x.cmp(&y),
128        }
129    }
130}
131
132impl<'a, Hasher> HashedRange<'a, Hasher>
133where
134    Hasher: RollingHasher + ?Sized,
135{
136    fn new(hashed: &'a [Hasher::Hash]) -> Self {
137        Self {
138            hashed,
139            _marker: PhantomData,
140        }
141    }
142    pub fn len(&self) -> usize {
143        self.hashed.len() - 1
144    }
145    pub fn is_empty(&self) -> bool {
146        self.len() == 0
147    }
148    pub fn range<R>(&self, range: R) -> HashedRange<'a, Hasher>
149    where
150        R: RangeBounds<usize>,
151    {
152        HashedRange::new(&self.hashed[to_range(range, self.len())])
153    }
154    pub fn hash_range<R>(&self, range: R) -> Hashed<Hasher>
155    where
156        R: RangeBounds<usize>,
157    {
158        self.range(range).hash()
159    }
160    pub fn hash(&self) -> Hashed<Hasher> {
161        Hasher::hash_substr(self.hashed)
162    }
163    pub fn longest_common_prefix(&self, other: &Self) -> usize {
164        let n = self.len().min(other.len());
165        let mut ok = 0usize;
166        let mut err = n + 1;
167        while ok + 1 < err {
168            let mid = (ok + err) / 2;
169            if self.range(..mid).hash() == other.range(..mid).hash() {
170                ok = mid;
171            } else {
172                err = mid;
173            }
174        }
175        ok
176    }
177    pub fn chainable(self) -> HashedRangeChained<'a, Hasher> {
178        vec![self].into()
179    }
180}
181
182pub struct HashedRangeChained<'a, Hasher>
183where
184    Hasher: RollingHasher + ?Sized,
185{
186    chained: Vec<HashedRange<'a, Hasher>>,
187    _marker: PhantomData<fn() -> Hasher>,
188}
189
190impl<Hasher: Debug> Debug for HashedRangeChained<'_, Hasher>
191where
192    Hasher: RollingHasher + ?Sized,
193    Hasher::Hash: Debug,
194{
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        f.debug_struct("HashedRangeChained")
197            .field("chained", &self.chained)
198            .finish()
199    }
200}
201
202impl<Hasher: Default> Default for HashedRangeChained<'_, Hasher>
203where
204    Hasher: RollingHasher,
205{
206    fn default() -> Self {
207        Self {
208            chained: Default::default(),
209            _marker: Default::default(),
210        }
211    }
212}
213
214impl<Hasher: Clone> Clone for HashedRangeChained<'_, Hasher>
215where
216    Hasher: RollingHasher,
217{
218    fn clone(&self) -> Self {
219        Self {
220            chained: self.chained.clone(),
221            _marker: self._marker,
222        }
223    }
224}
225
226impl<Hasher> PartialEq for HashedRangeChained<'_, Hasher>
227where
228    Hasher: RollingHasher + ?Sized,
229{
230    fn eq(&self, other: &Self) -> bool {
231        let mut a = self.chained.iter().cloned();
232        let mut b = other.chained.iter().cloned();
233        macro_rules! next {
234            ($iter:expr) => {
235                loop {
236                    if let Some(x) = $iter.next() {
237                        if x.len() > 0 {
238                            break Some(x);
239                        }
240                    } else {
241                        break None;
242                    }
243                }
244            };
245        }
246        let mut x: Option<HashedRange<'_, Hasher>> = None;
247        let mut y: Option<HashedRange<'_, Hasher>> = None;
248        loop {
249            if x.map_or(true, |x| x.is_empty()) {
250                x = next!(a);
251            }
252            if y.map_or(true, |y| y.is_empty()) {
253                y = next!(b);
254            }
255            if let (Some(x), Some(y)) = (&mut x, &mut y) {
256                let k = x.len().min(y.len());
257                if x.range(..k) != y.range(..k) {
258                    return false;
259                }
260                *x = x.range(k..);
261                *y = y.range(k..);
262            } else {
263                break x.is_none() == y.is_none();
264            }
265        }
266    }
267}
268
269impl<Hasher> Eq for HashedRangeChained<'_, Hasher> where Hasher: RollingHasher + ?Sized {}
270
271impl<Hasher> PartialOrd for HashedRangeChained<'_, Hasher>
272where
273    Hasher: RollingHasher + ?Sized,
274    Hasher::Hash: PartialOrd,
275{
276    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
277        let mut a = self.chained.iter().cloned();
278        let mut b = other.chained.iter().cloned();
279        macro_rules! next {
280            ($iter:expr) => {
281                loop {
282                    if let Some(x) = $iter.next() {
283                        if x.len() > 0 {
284                            break Some(x);
285                        }
286                    } else {
287                        break None;
288                    }
289                }
290            };
291        }
292        let mut x: Option<HashedRange<'_, Hasher>> = None;
293        let mut y: Option<HashedRange<'_, Hasher>> = None;
294        loop {
295            if x.map_or(true, |x| x.is_empty()) {
296                x = next!(a);
297            }
298            if y.map_or(true, |y| y.is_empty()) {
299                y = next!(b);
300            }
301            if let (Some(x), Some(y)) = (&mut x, &mut y) {
302                let k = x.longest_common_prefix(y);
303                if x.len() > k && y.len() > k {
304                    let x = x.hash_range(k..=k);
305                    let y = y.hash_range(k..=k);
306                    break x.hash.partial_cmp(&y.hash);
307                };
308                *x = x.range(k..);
309                *y = y.range(k..);
310            } else {
311                break x.is_some().partial_cmp(&y.is_some());
312            }
313        }
314    }
315}
316
317impl<Hasher> Ord for HashedRangeChained<'_, Hasher>
318where
319    Hasher: RollingHasher + ?Sized,
320    Hasher::Hash: Ord,
321{
322    fn cmp(&self, other: &Self) -> Ordering {
323        let mut a = self.chained.iter().cloned();
324        let mut b = other.chained.iter().cloned();
325        macro_rules! next {
326            ($iter:expr) => {
327                loop {
328                    if let Some(x) = $iter.next() {
329                        if x.len() > 0 {
330                            break Some(x);
331                        }
332                    } else {
333                        break None;
334                    }
335                }
336            };
337        }
338        let mut x: Option<HashedRange<'_, Hasher>> = None;
339        let mut y: Option<HashedRange<'_, Hasher>> = None;
340        loop {
341            if x.map_or(true, |x| x.is_empty()) {
342                x = next!(a);
343            }
344            if y.map_or(true, |y| y.is_empty()) {
345                y = next!(b);
346            }
347            if let (Some(x), Some(y)) = (&mut x, &mut y) {
348                let k = x.longest_common_prefix(y);
349                if x.len() > k && y.len() > k {
350                    let x = x.hash_range(k..=k);
351                    let y = y.hash_range(k..=k);
352                    break x.hash.cmp(&y.hash);
353                };
354                *x = x.range(k..);
355                *y = y.range(k..);
356            } else {
357                break x.is_some().cmp(&y.is_some());
358            }
359        }
360    }
361}
362
363impl<'a, Hasher> From<Vec<HashedRange<'a, Hasher>>> for HashedRangeChained<'a, Hasher>
364where
365    Hasher: RollingHasher + ?Sized,
366{
367    fn from(hashed: Vec<HashedRange<'a, Hasher>>) -> Self {
368        Self {
369            chained: hashed,
370            _marker: PhantomData,
371        }
372    }
373}
374
375impl<'a, Hasher> HashedRangeChained<'a, Hasher>
376where
377    Hasher: RollingHasher + ?Sized,
378{
379    pub fn chain(mut self, x: HashedRange<'a, Hasher>) -> Self {
380        self.chained.push(x);
381        self
382    }
383    pub fn push(&mut self, x: HashedRange<'a, Hasher>) {
384        self.chained.push(x);
385    }
386}
387
/// Converts an element range over a sequence of length `ub` into the inclusive
/// range of prefix-hash indices describing it.
///
/// Elements `l..r` are described by prefix hashes `l..=r`, so the returned
/// window always contains at least one index.
///
/// # Panics
///
/// Panics if the range is decreasing (`l > r`) or ends past `ub`.
fn to_range<R>(range: R, ub: usize) -> RangeInclusive<usize>
where
    R: RangeBounds<usize>,
{
    let l = match range.start_bound() {
        Bound::Included(&l) => l,
        Bound::Excluded(&l) => l + 1,
        Bound::Unbounded => 0,
    };
    let r = match range.end_bound() {
        Bound::Included(&r) => r + 1,
        Bound::Excluded(&r) => r,
        Bound::Unbounded => ub,
    };
    // Reject bad ranges here. With `l == r + 1`, `l..=r` slices to an empty
    // window, which would otherwise only fail later when `len()` underflows.
    assert!(
        l <= r && r <= ub,
        "range {}..{} out of bounds for length {}",
        l,
        r,
        ub
    );
    l..=r
}
404
405#[derive(Debug)]
406pub struct Hashed<Hasher>
407where
408    Hasher: RollingHasher + ?Sized,
409{
410    len: usize,
411    hash: Hasher::Hash,
412    _marker: PhantomData<fn() -> Hasher>,
413}
414
415impl<Hasher> std::hash::Hash for Hashed<Hasher>
416where
417    Hasher: RollingHasher + ?Sized,
418    Hasher::Hash: std::hash::Hash,
419{
420    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
421        self.len.hash(state);
422        self.hash.hash(state);
423        self._marker.hash(state);
424    }
425}
426
427impl<Hasher> Hashed<Hasher>
428where
429    Hasher: RollingHasher + ?Sized,
430{
431    fn new(len: usize, hash: Hasher::Hash) -> Self {
432        Self {
433            len,
434            hash,
435            _marker: PhantomData,
436        }
437    }
438    pub fn concat(&self, other: &Self) -> Self {
439        Hasher::concat_hash(self, other)
440    }
441    pub fn pow(&self, n: usize) -> Self {
442        let mut res = Hasher::empty_hash();
443        let mut x = *self;
444        let mut n = n;
445        while n > 0 {
446            if n & 1 == 1 {
447                res = res.concat(&x);
448            }
449            x = x.concat(&x);
450            n >>= 1;
451        }
452        res
453    }
454}
455
456impl<Hasher> Clone for Hashed<Hasher>
457where
458    Hasher: RollingHasher + ?Sized,
459{
460    fn clone(&self) -> Self {
461        *self
462    }
463}
464
465impl<Hasher> Copy for Hashed<Hasher> where Hasher: RollingHasher + ?Sized {}
466
467impl<Hasher> PartialEq for Hashed<Hasher>
468where
469    Hasher: RollingHasher + ?Sized,
470{
471    fn eq(&self, other: &Self) -> bool {
472        self.len == other.len && self.hash == other.hash
473    }
474}
475
476impl<Hasher> Eq for Hashed<Hasher> where Hasher: RollingHasher + ?Sized {}
477
478#[derive(Debug)]
479struct RollingHashPrecalc<R>
480where
481    R: SemiRing,
482{
483    base: R::T,
484    pow: Vec<R::T>,
485}
486
487impl<R> Default for RollingHashPrecalc<R>
488where
489    R: SemiRing,
490    R::T: Default,
491{
492    fn default() -> Self {
493        Self {
494            base: Default::default(),
495            pow: Default::default(),
496        }
497    }
498}
499
500impl<R> RollingHashPrecalc<R>
501where
502    R: SemiRing,
503    R::Additive: Invertible,
504{
505    fn new(base: R::T) -> Self {
506        Self {
507            base,
508            pow: vec![R::one()],
509        }
510    }
511    fn ensure_pow(&mut self, len: usize) {
512        if self.pow.len() <= len {
513            self.pow.reserve(len - self.pow.len() + 1);
514            if self.pow.is_empty() {
515                self.pow.push(R::one());
516            }
517            for _ in 0..=len - self.pow.len() {
518                self.pow.push(R::mul(self.pow.last().unwrap(), &self.base));
519            }
520        }
521    }
522    fn mul1_add(&self, x: &R::T, y: &R::T) -> R::T {
523        R::add(&R::mul(x, &self.base), y)
524    }
525    fn muln_add(&mut self, x: &R::T, y: &R::T, n: usize) -> R::T {
526        if let Some(pow) = self.pow.get(n) {
527            R::add(&R::mul(x, pow), y)
528        } else {
529            let pow = <R::Multiplicative as Monoid>::pow(self.base.clone(), n);
530            R::add(&R::mul(x, &pow), y)
531        }
532    }
533    fn muln_sub(&mut self, l: &R::T, r: &R::T, n: usize) -> R::T {
534        if let Some(pow) = self.pow.get(n) {
535            R::sub(r, &R::mul(l, pow))
536        } else {
537            let pow = <R::Multiplicative as Monoid>::pow(self.base.clone(), n);
538            R::sub(r, &R::mul(l, &pow))
539        }
540    }
541}
542
// `impl_rolling_hasher!(Name, Ring, [_ _ ...])` defines a hasher type `Name`
// that runs one independent polynomial hash over `Ring` per `_` in the list.
// It stores the results side by side in an array hash `[<Ring as SemiRing>::T; count]`.
//
// The `@inner` arms form a token-tree muncher. Each step consumes one `_` and
// one index from `[0 1 ... 9]`. It appends that index to `[$($i)*]` (the
// per-hash indices) and a `()` placeholder to `[$($s)*]`, which exists only to
// repeat an expression once per hash. When the `_` list is exhausted, the next
// unused index equals the number of hashes and becomes `$len`. Because `$len`
// must also come from the index list, at most 9 hashes are supported.
macro_rules! impl_rolling_hasher {
    (@inner $T:ident, $R:ty, [$($i:tt)*] [$($s:tt)*] [$a:tt $($tt:tt)*] [$k:tt $($j:tt)*]) => {
        impl_rolling_hasher!(@inner $T, $R, [$($i)* $k] [$($s)* ()] [$($tt)*] [$($j)*]);
    };
    (@inner $T:ident, $R:ty, [$($i:tt)+] [$($s:tt)+] [] [$len:tt $($j:tt)*]) => {
        // Uninhabited marker type: all state lives in the thread-local below,
        // and `$T` only selects which state the trait functions use.
        #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
        pub enum $T {}

        impl $T {
            // Per-thread parameters, one `RollingHashPrecalc` per hash. They start
            // as `Default` (zero bases, empty power caches) until `init*` runs.
            fn __rolling_hash_local_key() -> &'static ::std::thread::LocalKey<::std::cell::Cell<[RollingHashPrecalc<$R>; $len]>> {
                ::std::thread_local!(
                    static __LOCAL_KEY: ::std::cell::Cell<[RollingHashPrecalc<$R>; $len]> = ::std::cell::Cell::new(Default::default())
                );
                &__LOCAL_KEY
            }
        }

        impl RollingHasher for $T {
            type T = <$R as SemiRing>::T;

            type Hash = [<$R as SemiRing>::T; $len];

            // Draws new random bases only while every base is still zero (the
            // `Default` state). Later calls keep the existing bases, so hashes
            // computed earlier on this thread remain comparable; they only grow
            // the power caches through `ensure`.
            // NOTE(review): `rng.rand(MOD)` presumably draws from `0..MOD`, so a
            // zero base is possible with negligible probability. Confirm.
            fn init_with_rng(len: usize, rng: &mut Xorshift) {
                Self::__rolling_hash_local_key().with(|cell| {
                    // SAFETY: the shared reference is temporary and is dropped
                    // before `cell.set` replaces the contents.
                    if unsafe{ (&*cell.as_ptr()).iter().all(|p| p.base == 0) } {
                        cell.set([$({ $s; RollingHashPrecalc::new(rng.rand(<$R>::MOD)) },)+]);
                    }
                });
                Self::ensure(len);
            }

            // Grows every power cache so that lengths up to `len` use cached powers.
            fn ensure(len: usize) {
                Self::__rolling_hash_local_key().with(|cell| {
                    // SAFETY: `Cell` offers no in-place mutable access, so the array
                    // is mutated through the raw pointer. The `&mut` lives only
                    // inside this closure, and no other reference into the cell is
                    // expected to be alive meanwhile.
                    unsafe {
                        let arr = &mut *cell.as_ptr();
                        $(arr[$i].ensure_pow(len);)+
                    }
                })
            }

            // Builds prefix hashes: entry 0 is all zeros (the empty prefix), and
            // each further entry is `previous * base + item`, computed per hash.
            fn hash_sequence<I>(iter: I) -> HashedSequence<Self>
            where
                I: IntoIterator<Item = Self::T>,
            {
                let iter = iter.into_iter();
                let (lb, _) = iter.size_hint();
                let mut hashed = Vec::with_capacity(lb + 1);
                hashed.push([$({ $s; <$R>::zero() },)+]);
                // SAFETY: only a shared reference to the parameters is taken.
                // NOTE(review): it stays alive while the caller-supplied iterator
                // runs. If that iterator called back into this hasher's
                // `init`/`ensure` (which create a `&mut` to the same array), the
                // references would alias. Confirm callers never do this.
                unsafe {
                    Self::__rolling_hash_local_key().with(|cell| {
                        let arr = &*cell.as_ptr();
                        for item in iter {
                            let last = hashed.last().unwrap();
                            let h = [$(arr[$i].mul1_add(&last[$i], &item),)+];
                            hashed.push(h);
                        }
                    })
                };
                HashedSequence::new(hashed)
            }

            // Hash of the window: `hashed[len] - hashed[0] * base^len` per hash,
            // where `len = hashed.len() - 1` is the number of elements.
            fn hash_substr(hashed: &[Self::Hash]) -> Hashed<Self> {
                let len = hashed.len() - 1;
                // SAFETY: the `&mut` into the cell is confined to this closure, and
                // no other reference into the cell is expected to be alive.
                let h = unsafe {
                    Self::__rolling_hash_local_key().with(|cell| {
                        let arr = &mut *cell.as_ptr();
                        [$(arr[$i].muln_sub(&hashed[0][$i], &hashed[len][$i], len),)+]
                    })
                };
                Hashed::new(len, h)
            }

            // Hash of `x ++ y`: `x.hash * base^(y.len) + y.hash` per hash.
            fn concat_hash(x: &Hashed<Self>, y: &Hashed<Self>) -> Hashed<Self> {
                let len = y.len;
                // SAFETY: the `&mut` into the cell is confined to this closure, and
                // no other reference into the cell is expected to be alive.
                let hash = unsafe {
                    Self::__rolling_hash_local_key().with(|cell| {
                        let arr = &mut *cell.as_ptr();
                        [$(arr[$i].muln_add(&x.hash[$i], &y.hash[$i], len),)+]
                    })
                };
                Hashed::new(x.len + y.len, hash)
            }

            // Length 0, all-zero hash: the same value `hash_sequence` stores for
            // the empty prefix.
            fn empty_hash() -> Hashed<Self> {
                Hashed::new(0, [$({ $s; <$R>::zero() },)+])
            }
        }
    };
    // Public entry point: start the muncher with empty accumulators and the
    // index list 0..=9.
    ($T:ident, $R:ty, [$($tt:tt)+]) => {
        impl_rolling_hasher!(@inner $T, $R, [] [] [$($tt)+] [0 1 2 3 4 5 6 7 8 9]);
    };
}
635
636impl_rolling_hasher!(Mersenne61x1, Mersenne61, [_]);
637impl_rolling_hasher!(Mersenne61x2, Mersenne61, [_ _]);
638impl_rolling_hasher!(Mersenne61x3, Mersenne61, [_ _ _]);
639impl_rolling_hasher!(Gf2_63x1, Gf2_63, [_]);
640impl_rolling_hasher!(Gf2_63x2, Gf2_63, [_ _]);
641impl_rolling_hasher!(Gf2_63x3, Gf2_63, [_ _ _]);
642
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Range equality must agree with slice equality for every pair of equal-length ranges.
    #[test]
    fn test_rolling_hash() {
        const N: usize = 200;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..10u64).take(N).collect();
        Mersenne61x3::init(N);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        for k in 1..=N {
            for l1 in 0..=N - k {
                for l2 in 0..=N - k {
                    assert_eq!(
                        a[l1..l1 + k] == a[l2..l2 + k],
                        h.range(l1..l1 + k) == h.range(l2..l2 + k),
                        "a1: {:?}, a2: {:?}",
                        &a[l1..l1 + k],
                        &a[l2..l2 + k]
                    );
                }
            }
        }
    }

    /// Same check with no cached powers, exercising the on-the-fly power path.
    #[test]
    fn test_rolling_hash_limited_precalc() {
        const N: usize = 200;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..10u64).take(N).collect();
        Mersenne61x3::init(0);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        for k in 1..=N {
            for l1 in 0..=N - k {
                for l2 in 0..=N - k {
                    assert_eq!(
                        a[l1..l1 + k] == a[l2..l2 + k],
                        h.range(l1..l1 + k) == h.range(l2..l2 + k)
                    );
                }
            }
        }
    }

    /// `Hashed::pow(n)` must match hashing the slice repeated `n` times.
    #[test]
    fn test_rolling_hash_pow() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..10u64).take(N).collect();
        Mersenne61x3::init(N);
        for k in 0..=N {
            for l in 0..=N - k {
                let a = &a[l..l + k];
                for n in 0..=N {
                    let b = a.repeat(n);
                    let ha = Mersenne61x3::hash_sequence(a.iter().copied())
                        .hash_range(..)
                        .pow(n);
                    let hb = Mersenne61x3::hash_sequence(b.iter().copied()).hash_range(..);
                    assert_eq!(ha, hb);
                }
            }
        }
    }

    /// `HashedRange` ordering must agree with slice ordering, including empty
    /// ranges and ranges that are prefixes of one another.
    #[test]
    fn test_rolling_hash_ord() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        // A small alphabet makes long common prefixes and ties frequent.
        let a: Vec<_> = rng.random_iter(0..3u64).take(N).collect();
        Mersenne61x3::init(N);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        for l1 in 0..=N {
            for r1 in l1..=N {
                for l2 in 0..=N {
                    for r2 in l2..=N {
                        assert_eq!(
                            a[l1..r1].cmp(&a[l2..r2]),
                            h.range(l1..r1).cmp(&h.range(l2..r2)),
                            "a1: {:?}, a2: {:?}",
                            &a[l1..r1],
                            &a[l2..r2]
                        );
                    }
                }
            }
        }
    }

    /// A two-piece chain must compare like the concatenated slice, independent
    /// of where the pieces are split (empty pieces included).
    #[test]
    fn test_rolling_hash_chained() {
        const N: usize = 7;
        let mut rng = Xorshift::default();
        let a: Vec<_> = rng.random_iter(0..3u64).take(N).collect();
        Mersenne61x3::init(N);
        let h = Mersenne61x3::hash_sequence(a.iter().copied());
        let ranges: Vec<(usize, usize)> = (0..=N)
            .flat_map(|l| (l..=N).map(move |r| (l, r)))
            .collect();
        for &(l1, r1) in &ranges {
            for &(l2, r2) in &ranges {
                let joined = [&a[l1..r1], &a[l2..r2]].concat();
                let x = h.range(l1..r1).chainable().chain(h.range(l2..r2));
                for &(l3, r3) in &ranges {
                    let single = &a[l3..r3];
                    let y = h.range(l3..r3).chainable();
                    assert_eq!(joined.as_slice().cmp(single), x.cmp(&y));
                    assert_eq!(joined.as_slice() == single, x == y);
                }
            }
        }
    }
}
709}