后缀数组SA

简述:$花费O(nlogn)用倍增+基数排序$构建H数组完成对后缀排序。

思想在于calabashboy,实现细节和方法看dxsf。

注意细节:传入字符串0base,内部会修改成1base,传的是引用。值域都在1-n之间

struct SA {
    int n;                   // 存储字符串的长度
    vector<int> sa, rk, lc;  // sa: 后缀数组, rk: 排名数组, lc: 最长公共前缀数组 (LCP)
    SparseTable<int> qmn;
    SA(string& s) {
        n = s.length();    // 初始化字符串的长度
        sa.resize(n + 1);  // 调整 sa 的大小为 n + 1
        lc.resize(n + 1);  // 调整 lc 的大小为 n
        rk.resize(n + 1);  // 调整 rk 的大小为 n + 1
        s = " " + s;
        iota(sa.begin(), sa.end(), 0);  // 初始化 sa 为 [1, 2, ..., n]
        sort(sa.begin() + 1, sa.end(), [&](int a, int b) {
            return s[a] < s[b];  // 按照首字符对索引进行排序
        });

        // 初始化 rk 数组
        rk[sa[1]] = 1;
        for (int i = 2; i <= n; ++i)
            rk[sa[i]] = rk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]);

        int k = 1;                    // 初始化 k 为 1,表示当前使用的字符串长度
        vector<int> tmp, cnt(n + 1);  // tmp: 临时数组, cnt: 计数排序的频率数组
        tmp.reserve(n + 1);           // 为 tmp 预留 n + 1 个元素的空间

        while (rk[sa[n]] < n) {  // 当排名最高的后缀排名小于 n时继续循环
            tmp.clear();
            tmp.push_back(0);  // 清空 tmp 数组

            for (int i = 1; i <= k; ++i)
                tmp.push_back(n - k + i);  // 越界部分默认为空字符

            for (auto i : sa)
                if (i >= k + 1)
                    tmp.push_back(i - k);  // 按第二关键字排序

            fill(cnt.begin(), cnt.end(), 0);  // 清空 cnt 数组
            for (int i = 1; i <= n; ++i)
                ++cnt[rk[i]];  // 统计每个排名出现的频率

            for (int i = 1; i <= n; ++i)
                cnt[i] += cnt[i - 1];  // 计算计数排序中的前缀和

            for (int i = n; i >= 1; --i) {
                int tmprk = cnt[rk[tmp[i]]];
                sa[tmprk] = tmp[i];
                cnt[rk[tmp[i]]] -= 1;
            }  // 根据 tmp 中的排名重建后缀数组

            std::swap(rk, tmp);  // tmp的功能变为之前的rk桶数组
            rk[sa[1]] = 1;       // 重新初始化排名数组,首先将 sa[1] 的排名设为 1

            for (int i = 2; i <= n; ++i)
                rk[sa[i]] = rk[sa[i - 1]] + (tmp[sa[i - 1]] < tmp[sa[i]] ||
                                             sa[i - 1] + k > n || tmp[sa[i - 1] + k] < tmp[sa[i] + k]);  // 基于前后部分进行比较
            k *= 2;                                                                                      // 将 k 翻倍,以便在下一个循环中比较更长的前缀
        }

        for (int i = 1, j = 0; i <= n; ++i) {
            if (rk[i] == 1) {  // 如果当前后缀是字典序最小的,不需要计算 LCP
                j = 0;
            } else {
                for (j -= j > 0; i + j <= n && sa[rk[i] - 1] + j <= n &&
                                 s[i + j] == s[sa[rk[i] - 1] + j];)
                    ++j;        // 计算与前一个后缀的最长公共前缀长度
                lc[rk[i]] = j;  // 排名为 i 的后缀与排名为 i-1 的 LCP
            }
        }
    }
    bool flag = 0;
    void work() {
        qmn = SparseTable<int>(lc, [](int i, int j) { return min(i, j); });
        flag = 1;
    }
    int getlcp(int pos1, int pos2) {
        assert(flag == 1);
        int c1 = min(rk[pos1], rk[pos2]);
        int c2 = max(rk[pos1], rk[pos2]);
        // 只要pos1和pos2相同,排名才会相同,这种情况在外面特判
        assert(c1 < c2);
        return qmn.get(c1 + 1, c2);
    }
};

类似于dx的stl实现

#include <bits/stdc++.h>

using namespace std;

#define endl '\n'
#define int long long
#define IOS                           \
    ios_base::sync_with_stdio(false); \
    cin.tie(0);                       \
    cout.tie(0)
#define rep(i, a, n) for (int i = a; i <= n; i++)
#define per(i, a, n) for (int i = n; i >= a; i--)
#define all(x) (x).begin(), (x).end()
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;

struct SA
{
    vector<int> rk, sa, cnt, lcp, oldrk, px, id;

    bool cmp(int x, int y, int w)
    {
        return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
    }

    SA(string& s)
    {
        int n = s.length(), m = 300;
        oldrk.resize(max(m + 1, 2 * n + 1));
        sa.resize(max(m + 1, n + 1));
        rk.resize(max(m + 1, n + 1));
        cnt.resize(max(m + 1, n + 1));
        lcp.resize(max(m + 1, n + 1));
        px.resize(max(m + 1, n + 1));
        id.resize(max(m + 1, n + 1));
        s = " " + s;
        for (int i = 1; i <= n; ++i)
            ++cnt[rk[i] = s[i]];
        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i)
            sa[cnt[rk[i]]--] = i;
        for (int w = 1, p;; w <<= 1, m = p)
        {
            p = 0;
            for (int i = n; i > n - w; --i)
                id[++p] = i;
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id[++p] = sa[i] - w;
            fill(cnt.begin(), cnt.end(), 0);
            for (int i = 1; i <= n; ++i)
                ++cnt[px[i] = rk[id[i]]];
            for (int i = 1; i <= m; ++i)
                cnt[i] += cnt[i - 1];
            for (int i = n; i >= 1; --i)
                sa[cnt[px[i]]--] = id[i];
            copy(rk.begin(), rk.end(), oldrk.begin());
            p = 0;
            for (int i = 1; i <= n; ++i)
                rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
            if (p == n)
            {
                for (int i = 1; i <= n; ++i)
                    sa[rk[i]] = i;
                break;
            }
        }
        for (int i = 1, k = 0; i <= n; ++i)
        {
            if (rk[i] == 0)
                continue;
            if (k)
                --k;
            while (s[i + k] == s[sa[rk[i] - 1] + k])
                ++k;
            lcp[rk[i]] = k;
        }
    }
};

void solve()
{
    string s;
    cin >> s;
    int len = s.size();
    SA sa(s);
    for (int i = 1; i <= len; i++)
        cout << sa.sa[i] << ' ';
    cout << endl;
    for (int i = 1; i <= len; i++)
        cout << sa.lcp[i] << ' ';
    cout << endl;
}

signed main()
{
    IOS;
    int t = 1;
    cin >> t;
    while (t--)
        solve();
    return 0;
}

刷题:

1.利用$RMQ$配合$hegiht$数组进行求任意两个子串的$lcp$

Sol:考虑找到两个子串的排名,然后在排名区间内查询Height最大值。

  • 注意左端点需要右偏移1.
  • 注意特判左端点重合的情况

Problem - 4691 (hdu.edu.cn)

题意:给定字符串,给定若干子串,求相邻子串的最长公共前缀。此外以压缩字符串为背景,算大小的时候这里的数字是按字符算的.需要开longlong

void solve() {
    string s;

    while (cin >> s) {
        deb(s);
        n = s.size();
        SA sa(s);
        auto q = sa.lc;
        SparseTable<int> qmn(q, [](int i, int j) { return min(i, j); });
        cin >> m;
        int lstl, lstr;
        int sum = 0;
        int ans = 2 * m;
        deb(s);
        for (int i = 1; i <= m; i++) {
            int l2, r2;
            cin >> l2 >> r2;
            l2++;
            sum += 1 + r2 - l2 + 1;
            if (i == 1) {
                int res = 0;
                deb(i, res);
              //  deb(s.substr(l2, max(0LL, r2 - l2 + 1 - res)));
                ans += (r2 - l2 + 1) + cal(res);
                lstl = l2, lstr = r2;
                continue;
            }
            int res = min(r2 - l2 + 1, lstr - lstl + 1);
            int pos1 = sa.rk[l2], pos2 = sa.rk[lstl];
            if (pos1 != pos2) {
                res = min(res, qmn.get(min(pos1, pos2) + 1, max(pos1, pos2)));
            }

            deb(i, res);
            //deb(s.substr(l2 + res, max(0, r2 - l2 + 1 - res)));

            ans += max(0LL, r2 - l2 + 1 - res) + cal(res);
            lstl = l2;
            lstr = r2;
        }
        cout << sum << " " << ans << endl;
    }
}

2.最长公共子串https://www.luogu.com.cn/problem/SP1811

Sol:考虑将两个串用无关字符连接在一起,仔细思考会发现$\text{LCP}(s_{sa_i},s_{sa_j})=\min\limits_{k=i+1}^{j}h_k$

  • 容易想到只检查排名相邻的字符串,这一定是不劣的。
  • 真的显然吗?给出证明:考虑如果在排序后的后缀数组中,如果存在i和j相邻,这一定是答案,因为是不断取min的,距离越长lcp越小。再考虑不存在相邻部分的情况,则只可能是#把两部分隔离,那只能说明连第一个字母都存在相同的,答案一定是0,也符合结论。
void solve() {
    string s, t;
    cin >> s >> t;
    string tmp = s + "#" + t;
    SA sa(tmp);
    int len1 = s.size(), len2 = t.size();
    int ans = 0;
    auto check = [&](int i) {
        if (sa.rk[i] == 1)
            return true;
        int pos = sa.sa[sa.rk[i] - 1];
        if (i <= len1 && pos > len1 + 1)
            return true;
        if (i > len1 + 1 && pos <= len1)
            return true;
        return false;
    };
    for (int i = 1; i <= len1 + len2 + 1; i++) {
        if (check(i))
            ans = max(ans, sa.lc[sa.rk[i]]);
    }
    cout << ans << endl;
}