Skip to main content

competitive/string/
string_search.rs

1use super::{RangeMinimumQuery, SuffixArray};
2use std::{cmp::Ordering, ops::Range};
3
/// One element of a concatenated text: either a boundary marker or a real value.
///
/// The variant order is significant: the derived `Ord` places every
/// `Separator` before every `Value`, and separators compare by their tag.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum Delimited<T> {
    /// Boundary marker carrying a `usize` tag.
    Separator(usize),
    /// An element taken from one of the underlying texts.
    Value(T),
}
9
/// A search pattern that can be matched element-by-element against text
/// elements of type `T`.
trait Pattern<T> {
    /// Number of elements in the pattern.
    fn len(&self) -> usize;

    /// Whether the text element `text` equals the pattern element at `index`.
    fn eq_text(&self, index: usize, text: &T) -> bool;

    /// Ordering of the text element `text` relative to the pattern element at
    /// `index` (i.e. `Less` means the text element sorts first).
    fn cmp_text(&self, index: usize, text: &T) -> Ordering;
}
15
16impl<T> Pattern<T> for [T]
17where
18    T: Ord,
19{
20    fn len(&self) -> usize {
21        self.len()
22    }
23
24    fn eq_text(&self, index: usize, text: &T) -> bool {
25        text == &self[index]
26    }
27
28    fn cmp_text(&self, index: usize, text: &T) -> Ordering {
29        text.cmp(&self[index])
30    }
31}
32
/// Borrowed pattern of plain `T` elements, wrapped so it can be matched
/// against a text whose elements are `Delimited<T>`.
struct DelimitedPattern<'a, T> {
    /// The undecorated pattern elements.
    pattern: &'a [T],
}
36
37impl<T> Pattern<Delimited<T>> for DelimitedPattern<'_, T>
38where
39    T: Ord,
40{
41    fn len(&self) -> usize {
42        self.pattern.len()
43    }
44
45    fn eq_text(&self, index: usize, text: &Delimited<T>) -> bool {
46        matches!(text, Delimited::Value(value) if value == &self.pattern[index])
47    }
48
49    fn cmp_text(&self, index: usize, text: &Delimited<T>) -> Ordering {
50        match text {
51            Delimited::Separator(_) => Ordering::Less,
52            Delimited::Value(value) => value.cmp(&self.pattern[index]),
53        }
54    }
55}
56
57#[derive(Debug)]
58pub struct StringSearch<T> {
59    text: Vec<T>,
60    suffix_array: SuffixArray,
61    lcp_array: Vec<usize>,
62    rank: Vec<usize>,
63    rmq: RangeMinimumQuery<usize>,
64}
65
66impl<T> StringSearch<T>
67where
68    T: Ord,
69{
70    pub fn new(text: Vec<T>) -> Self {
71        let suffix_array = SuffixArray::new(&text);
72
73        let (lcp_array, rank) = suffix_array.lcp_array_with_rank(&text);
74        let rmq = RangeMinimumQuery::new(lcp_array.clone());
75
76        Self {
77            text,
78            suffix_array,
79            lcp_array,
80            rank,
81            rmq,
82        }
83    }
84
85    pub fn text(&self) -> &[T] {
86        &self.text
87    }
88
89    pub fn suffix_array(&self) -> &SuffixArray {
90        &self.suffix_array
91    }
92
93    pub fn lcp_array(&self) -> &[usize] {
94        &self.lcp_array
95    }
96
97    pub fn rank(&self) -> &[usize] {
98        &self.rank
99    }
100
101    pub fn longest_common_prefix(&self, a: Range<usize>, b: Range<usize>) -> usize {
102        debug_assert!(a.start <= a.end && a.end <= self.text.len());
103        debug_assert!(b.start <= b.end && b.end <= self.text.len());
104        let len = (a.end - a.start).min(b.end - b.start);
105        self.lcp_suffix(a.start, b.start).min(len)
106    }
107
108    pub fn compare(&self, a: Range<usize>, b: Range<usize>) -> Ordering {
109        debug_assert!(a.start <= a.end && a.end <= self.text.len());
110        debug_assert!(b.start <= b.end && b.end <= self.text.len());
111        let len_a = a.end - a.start;
112        let len_b = b.end - b.start;
113        let len = len_a.min(len_b);
114        let lcp = self.lcp_suffix(a.start, b.start).min(len);
115        if lcp == len {
116            return len_a.cmp(&len_b);
117        }
118        self.text[a.start + lcp].cmp(&self.text[b.start + lcp])
119    }
120
121    pub fn range(&self, pattern: &[T]) -> Range<usize> {
122        let left = self.bound_prefix(pattern, false);
123        let right = self.bound_prefix(pattern, true);
124        left..right
125    }
126
127    pub fn positions(&self, range: Range<usize>) -> impl DoubleEndedIterator<Item = usize> + '_ {
128        debug_assert!(range.start <= range.end);
129        debug_assert!(range.end <= self.text.len() + 1);
130        range.map(move |i| self.suffix_array[i])
131    }
132
133    pub fn kth_substrings(&self) -> KthSubstrings<'_, T> {
134        KthSubstrings::new(self)
135    }
136
137    fn lcp_suffix(&self, a: usize, b: usize) -> usize {
138        self.lcp_sa(self.rank[a], self.rank[b])
139    }
140
141    fn lcp_sa(&self, a: usize, b: usize) -> usize {
142        if a == b {
143            return self.text.len() - self.suffix_array[a];
144        }
145        let (l, r) = if a < b { (a, b) } else { (b, a) };
146        self.rmq.fold(l, r)
147    }
148
149    fn compare_suffix_pattern<P>(
150        &self,
151        suffix_start: usize,
152        pattern: &P,
153        start: usize,
154    ) -> (Ordering, usize)
155    where
156        P: Pattern<T> + ?Sized,
157    {
158        let n = self.text.len();
159        let m = pattern.len();
160        let mut i = start;
161        while i < m && suffix_start + i < n && pattern.eq_text(i, &self.text[suffix_start + i]) {
162            i += 1;
163        }
164        let ord = if i == m {
165            Ordering::Equal
166        } else if suffix_start + i == n {
167            Ordering::Less
168        } else {
169            pattern.cmp_text(i, &self.text[suffix_start + i])
170        };
171        (ord, i)
172    }
173
174    fn bound_prefix<P>(&self, pattern: &P, upper: bool) -> usize
175    where
176        P: Pattern<T> + ?Sized,
177    {
178        if pattern.len() == 0 {
179            return if upper { self.text.len() + 1 } else { 0 };
180        }
181        let pred = |ord: Ordering| {
182            if upper {
183                ord == Ordering::Greater
184            } else {
185                ord != Ordering::Less
186            }
187        };
188        let (cmp_last, lcp_last) =
189            self.compare_suffix_pattern(self.suffix_array[self.text.len()], pattern, 0);
190        if !pred(cmp_last) {
191            return self.text.len() + 1;
192        }
193        let mut l = 0usize;
194        let mut r = self.text.len();
195        let mut lcp_l = 0usize;
196        let mut lcp_r = lcp_last;
197        while r - l > 1 {
198            let m = (l + r) >> 1;
199            let start = match lcp_l.cmp(&lcp_r) {
200                Ordering::Less => lcp_l.min(self.lcp_sa(l, m)),
201                Ordering::Greater => lcp_r.min(self.lcp_sa(m, r)),
202                Ordering::Equal => lcp_l,
203            };
204            let (cmp_m, lcp_m) = self.compare_suffix_pattern(self.suffix_array[m], pattern, start);
205            if pred(cmp_m) {
206                r = m;
207                lcp_r = lcp_m;
208            } else {
209                l = m;
210                lcp_l = lcp_m;
211            }
212        }
213        r
214    }
215
216    fn geq_suffix(&self, range: Range<usize>) -> usize {
217        let n = self.text.len();
218        debug_assert!(range.start <= range.end && range.end <= n);
219        let mut l = 0usize;
220        let mut r = n;
221        while r - l > 1 {
222            let m = (l + r) >> 1;
223            let ord = self.compare(self.suffix_array[m]..n, range.start..range.end);
224            if matches!(ord, Ordering::Less) {
225                l = m;
226            } else {
227                r = m;
228            }
229        }
230        r
231    }
232}
233
234#[derive(Debug)]
235pub struct KthSubstrings<'a, T> {
236    search: &'a StringSearch<T>,
237    prefix: Vec<u64>,
238}
239
240impl<'a, T> KthSubstrings<'a, T>
241where
242    T: Ord,
243{
244    fn new(search: &'a StringSearch<T>) -> Self {
245        let n = search.text.len();
246        let mut prefix = Vec::with_capacity(n + 1);
247        prefix.push(0);
248        let mut total = 0u64;
249        for i in 1..=n {
250            total += (n - search.suffix_array[i] - search.lcp_array[i - 1]) as u64;
251            prefix.push(total);
252        }
253        Self { search, prefix }
254    }
255
256    pub fn kth_distinct_substring(&self, k: u64) -> Option<Range<usize>> {
257        let idx = self.prefix.partition_point(|&x| x <= k);
258        if idx == self.prefix.len() {
259            return None;
260        }
261        debug_assert!(idx > 0);
262        let start = self.search.suffix_array[idx];
263        let len = self.search.lcp_array[idx - 1] + (k - self.prefix[idx - 1]) as usize + 1;
264        Some(start..start + len)
265    }
266
267    pub fn index_of_distinct_substring(&self, range: Range<usize>) -> u64 {
268        debug_assert!(range.start < range.end && range.end <= self.search.text.len());
269        let m = range.len();
270        let idx = self.search.geq_suffix(range);
271        self.prefix[idx - 1] + (m - self.search.lcp_array[idx - 1] - 1) as u64
272    }
273}
274
275#[derive(Debug)]
276pub struct MultipleStringSearch<T> {
277    texts: Vec<Vec<T>>,
278    offsets: Vec<usize>,
279    position_map: Vec<(usize, usize)>,
280    search: StringSearch<Delimited<T>>,
281}
282
283impl<T> MultipleStringSearch<T>
284where
285    T: Ord + Clone,
286{
287    pub fn new(texts: Vec<Vec<T>>) -> Self {
288        assert!(!texts.is_empty());
289        let total_len: usize = texts.iter().map(|text| text.len() + 1).sum();
290        let mut concat = Vec::with_capacity(total_len - 1);
291        let mut offsets = Vec::with_capacity(texts.len());
292        let mut position_map = Vec::with_capacity(total_len);
293        for (i, text) in texts.iter().enumerate() {
294            offsets.push(concat.len());
295            for (pos, value) in text.iter().cloned().enumerate() {
296                concat.push(Delimited::Value(value));
297                position_map.push((i, pos));
298            }
299            if i + 1 < texts.len() {
300                concat.push(Delimited::Separator(!i));
301            }
302            position_map.push((i, text.len()));
303        }
304        let search = StringSearch::new(concat);
305        Self {
306            texts,
307            offsets,
308            position_map,
309            search,
310        }
311    }
312
313    pub fn texts(&self) -> &[Vec<T>] {
314        &self.texts
315    }
316
317    pub fn longest_common_prefix(
318        &self,
319        a: (usize, Range<usize>),
320        b: (usize, Range<usize>),
321    ) -> usize {
322        let a = self.to_global_range(a);
323        let b = self.to_global_range(b);
324        self.search.longest_common_prefix(a, b)
325    }
326
327    pub fn compare(&self, a: (usize, Range<usize>), b: (usize, Range<usize>)) -> Ordering {
328        let a = self.to_global_range(a);
329        let b = self.to_global_range(b);
330        self.search.compare(a, b)
331    }
332
333    pub fn range(&self, pattern: &[T]) -> Range<usize> {
334        let pattern = DelimitedPattern { pattern };
335        let left = self.search.bound_prefix(&pattern, false);
336        let right = self.search.bound_prefix(&pattern, true);
337        left..right
338    }
339
340    pub fn positions(
341        &self,
342        range: Range<usize>,
343    ) -> impl DoubleEndedIterator<Item = (usize, usize)> + '_ {
344        debug_assert!(range.start <= range.end);
345        debug_assert!(range.end <= self.position_map.len());
346        range.map(move |i| self.position_map[self.search.suffix_array[i]])
347    }
348
349    pub fn kth_substrings(&self) -> MultipleKthSubstrings<'_, T> {
350        MultipleKthSubstrings::new(self)
351    }
352
353    fn to_global_range(&self, (index, range): (usize, Range<usize>)) -> Range<usize> {
354        debug_assert!(index < self.texts.len());
355        let len = self.texts[index].len();
356        debug_assert!(range.start <= range.end && range.end <= len);
357        let base = self.offsets[index];
358        base + range.start..base + range.end
359    }
360
361    fn suffix_len(&self, a: usize) -> usize {
362        let (text_idx, pos) = self.position_map[self.search.suffix_array[a]];
363        self.texts[text_idx].len() - pos
364    }
365}
366
367#[derive(Debug)]
368pub struct MultipleKthSubstrings<'a, T> {
369    search: &'a MultipleStringSearch<T>,
370    prefix: Vec<u64>,
371}
372
373impl<'a, T> MultipleKthSubstrings<'a, T>
374where
375    T: Ord + Clone,
376{
377    fn new(search: &'a MultipleStringSearch<T>) -> Self {
378        let n = search.search.text.len();
379        let mut prefix = Vec::with_capacity(n);
380        prefix.push(0);
381        let mut total = 0u64;
382        for i in 1..=n {
383            let len = search.suffix_len(i);
384            let prev_len = search.suffix_len(i - 1);
385            let lcp_prev = search.search.lcp_array[i - 1].min(len).min(prev_len);
386            total += (len - lcp_prev) as u64;
387            prefix.push(total);
388        }
389        Self { search, prefix }
390    }
391
392    pub fn kth_distinct_substring(&self, k: u64) -> Option<(usize, Range<usize>)> {
393        let idx = self.prefix.partition_point(|&x| x <= k);
394        if idx == self.prefix.len() {
395            return None;
396        }
397        let (text_idx, pos) = self.search.position_map[self.search.search.suffix_array[idx]];
398        let len = self.search.suffix_len(idx) - (self.prefix[idx] - k) as usize + 1;
399        Some((text_idx, pos..pos + len))
400    }
401
402    pub fn index_of_distinct_substring(&self, (text_idx, range): (usize, Range<usize>)) -> u64 {
403        debug_assert!(text_idx < self.search.texts.len());
404        debug_assert!(range.start < range.end && range.end <= self.search.texts[text_idx].len());
405        let m = range.len();
406        let range = self.search.to_global_range((text_idx, range));
407        let idx = self.search.search.geq_suffix(range);
408        let len = self.search.suffix_len(idx);
409        let prev_len = self.search.suffix_len(idx - 1);
410        let lcp_prev = self.search.search.lcp_array[idx - 1].min(len).min(prev_len);
411        self.prefix[idx - 1] + (m - lcp_prev - 1) as u64
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{WithEmptySegment as Wes, Xorshift};
    use std::collections::{BTreeMap, BTreeSet};

    /// `longest_common_prefix` and `compare` agree with slice-based brute force.
    #[test]
    fn test_longest_common_prefix_and_compare() {
        let mut rng = Xorshift::default();
        for _ in 0..500 {
            let n = rng.random(0..=80);
            let m = rng.random(1..=20);
            let s: Vec<_> = rng.random_iter(0..m).take(n).collect();
            let search = StringSearch::new(s.clone());
            if n == 0 {
                // Only the empty range exists on an empty text.
                assert_eq!(search.longest_common_prefix(0..0, 0..0), 0);
                assert_eq!(search.compare(0..0, 0..0), Ordering::Equal);
                continue;
            }
            for _ in 0..200 {
                let (al, ar) = rng.random(Wes(n));
                let (bl, br) = rng.random(Wes(n));
                let lhs = &s[al..ar];
                let rhs = &s[bl..br];
                let naive_lcp = lhs
                    .iter()
                    .zip(rhs.iter())
                    .take_while(|(x, y)| x == y)
                    .count();
                assert_eq!(search.longest_common_prefix(al..ar, bl..br), naive_lcp);
                let naive_ord = lhs.cmp(rhs);
                assert_eq!(search.compare(al..ar, bl..br), naive_ord);
            }
        }
    }

    /// `range` / `positions` agree with a linear scan over a naively sorted
    /// suffix list.
    #[test]
    fn test_range() {
        let mut rng = Xorshift::default();
        for _ in 0..500 {
            let n = rng.random(0..=80);
            let csize = rng.random(1..=20);
            let s: Vec<usize> = rng.random_iter(0..csize).take(n).collect();
            let search = StringSearch::new(s.clone());
            // Reference suffix array, including the empty suffix at `n`.
            let mut sa: Vec<_> = (0..=n).collect();
            sa.sort_unstable_by_key(|&i| &s[i..]);
            for _ in 0..200 {
                // Either a random pattern or an actual substring of `s`.
                let pattern = if n == 0 || rng.random(0..=1) == 0 {
                    let m = rng.random(0..=n + 2);
                    rng.random_iter(0..csize).take(m).collect()
                } else {
                    let (l, r) = rng.random(Wes(n));
                    s[l..r].to_vec()
                };
                // A suffix that starts with the pattern counts as equal.
                let cmp = |pos| {
                    if s[pos..].starts_with(&pattern) {
                        Ordering::Equal
                    } else {
                        s[pos..].cmp(&pattern)
                    }
                };
                let left = sa
                    .iter()
                    .position(|&pos| cmp(pos) != Ordering::Less)
                    .unwrap_or(sa.len());
                let right = sa
                    .iter()
                    .rposition(|&pos| cmp(pos) != Ordering::Greater)
                    .map_or(left, |i| i + 1);
                let found = search.range(&pattern);
                assert_eq!(found, left..right);
                let positions: Vec<_> = search.positions(found).collect();
                assert_eq!(positions, sa[left..right]);
            }
        }
    }

    /// k-th substring selection and ranking agree with an explicit sorted set
    /// of all substrings.
    #[test]
    fn test_kth_substring() {
        let mut rng = Xorshift::default();
        for _ in 0..500 {
            let n = rng.random(0..=80);
            let csize = rng.random(1..=20);
            let s: Vec<usize> = rng.random_iter(0..csize).take(n).collect();
            let search = StringSearch::new(s.clone());
            let kth = search.kth_substrings();
            let mut set = BTreeSet::new();
            for i in 0..n {
                for j in i + 1..=n {
                    set.insert(s[i..j].to_vec());
                }
            }
            let substrings: Vec<_> = set.into_iter().collect();
            // Selecting then ranking must round-trip.
            for (k, expected) in substrings.iter().enumerate() {
                let range = kth.kth_distinct_substring(k as u64).unwrap();
                assert_eq!(&s[range.clone()], expected.as_slice());
                assert_eq!(kth.index_of_distinct_substring(range), k as u64);
            }
            assert_eq!(kth.kth_distinct_substring(substrings.len() as u64), None);
            let mut index_map = BTreeMap::new();
            for (idx, substring) in substrings.iter().enumerate() {
                index_map.insert(substring.clone(), idx as _);
            }
            // Every occurrence of a substring must map to the same rank.
            for i in 0..n {
                for j in i + 1..=n {
                    let key = s[i..j].to_vec();
                    let expected = *index_map.get(&key).unwrap();
                    assert_eq!(kth.index_of_distinct_substring(i..j), expected);
                }
            }
        }
    }

    /// Multi-text `longest_common_prefix` and `compare` agree with brute force.
    #[test]
    fn test_multiple_longest_common_prefix_and_compare() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let k = rng.random(1..=6);
            let csize = rng.random(1..=20);
            let mut texts = Vec::with_capacity(k);
            for _ in 0..k {
                let n = rng.random(0..=40);
                let s: Vec<_> = rng.random_iter(0..csize).take(n).collect();
                texts.push(s);
            }
            let search = MultipleStringSearch::new(texts.clone());
            for _ in 0..200 {
                let i = rng.random(0..k);
                let j = rng.random(0..k);
                let (al, ar) = rng.random(Wes(texts[i].len()));
                let (bl, br) = rng.random(Wes(texts[j].len()));
                let lhs = &texts[i][al..ar];
                let rhs = &texts[j][bl..br];
                let naive_lcp = lhs
                    .iter()
                    .zip(rhs.iter())
                    .take_while(|(x, y)| x == y)
                    .count();
                assert_eq!(
                    search.longest_common_prefix((i, al..ar), (j, bl..br)),
                    naive_lcp
                );
                assert_eq!(search.compare((i, al..ar), (j, bl..br)), lhs.cmp(rhs));
            }
        }
    }

    /// Multi-text `range` / `positions` agree with a linear scan over a
    /// naively sorted list of `(text, position)` suffixes.
    #[test]
    fn test_multiple_range() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let k = rng.random(1..=6);
            let csize = rng.random(1..=20);
            let mut texts = Vec::with_capacity(k);
            for _ in 0..k {
                let n = rng.random(0..=40);
                let s: Vec<_> = rng.random_iter(0..csize).take(n).collect();
                texts.push(s);
            }
            let search = MultipleStringSearch::new(texts.clone());
            // Reference order: by suffix content, ties broken by `!i`.
            let mut sa: Vec<_> = (0..k)
                .flat_map(|i| (0..=texts[i].len()).map(move |pos| (i, pos)))
                .collect();
            sa.sort_unstable_by_key(|&(i, pos)| (&texts[i][pos..], !i));
            for _ in 0..200 {
                // Either a random pattern or an actual substring of some text.
                let pattern = if rng.random(0..=1) == 0 {
                    let m = rng.random(0..=50);
                    rng.random_iter(0..csize).take(m).collect()
                } else {
                    let idx = rng.random(0..k);
                    let (l, r) = rng.random(Wes(texts[idx].len()));
                    texts[idx][l..r].to_vec()
                };
                // A suffix that starts with the pattern counts as equal.
                let cmp = |i: usize, pos: usize| {
                    if texts[i][pos..].starts_with(&pattern) {
                        Ordering::Equal
                    } else {
                        texts[i][pos..].cmp(&pattern)
                    }
                };
                let left = sa
                    .iter()
                    .position(|&(i, pos)| cmp(i, pos) != Ordering::Less)
                    .unwrap_or(sa.len());
                let right = sa
                    .iter()
                    .rposition(|&(i, pos)| cmp(i, pos) != Ordering::Greater)
                    .map_or(left, |idx| idx + 1);
                let found = search.range(&pattern);
                assert_eq!(found, left..right);
                let positions: Vec<_> = search.positions(found).collect();
                assert_eq!(positions, sa[left..right]);
            }
        }
    }

    /// Multi-text k-th substring selection and ranking agree with an explicit
    /// sorted set of all substrings of all texts.
    #[test]
    fn test_multiple_kth_substring() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let k = rng.random(1..=6);
            let csize = rng.random(1..=20);
            let mut texts = Vec::with_capacity(k);
            for _ in 0..k {
                let n = rng.random(0..=40);
                let s: Vec<_> = rng.random_iter(0..csize).take(n).collect();
                texts.push(s);
            }
            let search = MultipleStringSearch::new(texts.clone());
            let kth = search.kth_substrings();
            let mut set = BTreeSet::new();
            for text in &texts {
                for i in 0..text.len() {
                    for j in i + 1..=text.len() {
                        set.insert(text[i..j].to_vec());
                    }
                }
            }
            let substrings: Vec<_> = set.into_iter().collect();
            // Selecting then ranking must round-trip.
            for (idx, expected) in substrings.iter().enumerate() {
                let (text_idx, range) = kth.kth_distinct_substring(idx as u64).unwrap();
                assert_eq!(&texts[text_idx][range.clone()], expected.as_slice());
                assert_eq!(kth.index_of_distinct_substring((text_idx, range)), idx as _);
            }
            assert_eq!(kth.kth_distinct_substring(substrings.len() as u64), None);
            let mut index_map = BTreeMap::new();
            for (idx, substring) in substrings.iter().enumerate() {
                index_map.insert(substring.clone(), idx as u64);
            }
            // Every occurrence of a substring must map to the same rank.
            for (text_idx, text) in texts.iter().enumerate() {
                for i in 0..text.len() {
                    for j in i + 1..=text.len() {
                        let key = text[i..j].to_vec();
                        let expected = *index_map.get(&key).unwrap();
                        assert_eq!(kth.index_of_distinct_substring((text_idx, i..j)), expected);
                    }
                }
            }
        }
    }
}