competitive/string/
suffix_array.rs

1use super::binary_search;
2use std::{cmp::Ordering, ops::Range};
3
4#[derive(Clone, Debug)]
5pub struct SuffixArray<T> {
6    pat: Vec<T>,
7    sa: Vec<usize>,
8    rank: Vec<usize>,
9}
10impl<T: Ord> SuffixArray<T> {
11    pub fn new(pat: Vec<T>) -> Self {
12        let n = pat.len();
13        let mut sa = (0..n + 1).collect::<Vec<_>>();
14        let mut rank = vec![0; n + 1];
15        let mut ford = (0..n).collect::<Vec<_>>();
16        ford.sort_by_key(|&i| &pat[i]);
17        let mut c = 1;
18        for i in 0..n {
19            rank[ford[i]] = c;
20            if i + 1 < n && pat[ford[i]] != pat[ford[i + 1]] {
21                c += 1;
22            }
23        }
24        let mut k = 1;
25        while k <= n {
26            sa.sort_by_key(|&i| (rank[i], rank.get(i + k).unwrap_or(&0)));
27            let mut tmp = vec![0; n + 1];
28            tmp[sa[0]] = 1;
29            for i in 1..n + 1 {
30                let x = sa[i - 1];
31                let y = sa[i];
32                let b = (rank[x], rank.get(x + k).unwrap_or(&0))
33                    < (rank[y], rank.get(y + k).unwrap_or(&0));
34                tmp[y] = tmp[x] + b as usize;
35            }
36            rank = tmp;
37            k *= 2;
38        }
39        Self { pat, sa, rank }
40    }
41    pub fn longest_common_prefix_array(&self) -> Vec<usize> {
42        let n = self.pat.len();
43        let mut h = 0usize;
44        let mut lcp = vec![0; n];
45        for i in 0..n {
46            let j = self[self.rank[i] - 2];
47            h = h.saturating_sub(1);
48            while j + h < n && i + h < n && self.pat[j + h] == self.pat[i + h] {
49                h += 1;
50            }
51            lcp[self.rank[i] - 2] = h;
52        }
53        lcp
54    }
55    pub fn range(&self, t: &[T], next: impl Fn(&T) -> T) -> Range<usize> {
56        let l = binary_search(
57            |&i| {
58                let mut si = self.sa[i as usize];
59                let mut ti = 0;
60                while si < self.pat.len() && ti < t.len() {
61                    match self.pat[si].cmp(&t[ti]) {
62                        Ordering::Less => return false,
63                        Ordering::Greater => return true,
64                        Ordering::Equal => {}
65                    }
66                    si += 1;
67                    ti += 1;
68                }
69                !(si >= self.pat.len() && ti < t.len())
70            },
71            self.sa.len() as isize,
72            -1,
73        ) as usize;
74        let r = binary_search(
75            |&i| {
76                let mut si = self.sa[i as usize];
77                let mut ti = 0;
78                while si < self.pat.len() && ti < t.len() {
79                    match if ti + 1 == t.len() {
80                        self.pat[si].cmp(&next(&t[ti]))
81                    } else {
82                        self.pat[si].cmp(&t[ti])
83                    } {
84                        Ordering::Less => return false,
85                        Ordering::Greater => return true,
86                        Ordering::Equal => {}
87                    }
88                    si += 1;
89                    ti += 1;
90                }
91                !(si >= self.pat.len() && ti < t.len())
92            },
93            self.sa.len() as isize,
94            l as isize - 1,
95        ) as usize;
96        l..r
97    }
98}
99impl<T> std::ops::Index<usize> for SuffixArray<T> {
100    type Output = usize;
101    fn index(&self, index: usize) -> &Self::Output {
102        &self.sa[index]
103    }
104}