CCPC2023秦皇岛

C.题意:给定一个字符串,多次区间询问给定一个子串。问有多少种方案删掉最短的区间使得子串变成回文串

Sol:考虑每次询问可以独立处理,当然并不是指把子串单独提出来,而是按顺序正常回答。

  • 一个关键的点是,考虑如果第一个字符和最后一个字符不一样,那一定要删一个,依次推理下去,我们可以得到应该保留最长回文前缀或最长回文后缀。

  • 再考虑如果第一个字符和最后一个字符一样,那我们类双指针缩小删除区间,直到两边不一样。此时又回归到第一种情况。

根据上面这个观察,我们得到获得最小删除代价的做法:

  • 对于一个区间[l, r] 我们首先可以先让首尾相同字符不断相
    消, 最后剩余区间[x, y], 如果这个区间为空则为回文串。这一部分等价于求原串后缀l 和反串后缀n − r + 1 的最长公共前缀( LCP ), 再对区间长度取min。
  • 如果不为回文串, 那么$str[x] = str[y]$, x, y 其中之一会被删除。
    删除最少长度, 等价于保留最多长度, 即保留x 开头的一个
    回文串或y 结尾的一个回文串。这是经典的区间最长回文子串问题,可通过在回文自动机(PAM ) 上倍增跳fail 实现

这样我们就可以得到删除最小的区间长度,下面计算方案数。

  • 假设删除[x, x + t] 是一个最优解, 在求方案数时, 我们还可以在保持剩余字符串不变的前提下, 将此区间向左进行滑动。具体地, str[x − 1] = str[x + t] 则左移一次,str[x − 2] = str[x + t − 1] 则左移两次…
  • 这里有一个思考的点:首先这里一定是左移,为什么不能右移?如果右移,表示s[x]=s[x+t+1],又因为我们保留的是最长回文后缀,所以s[y]=s[x+t-1],则得到s[x]=x[y],这与前面的算法流程矛盾,我们一定是一直找到左右端点不相等才开始找回文前后缀的。

可以发现这里的左移次数就是求两个反串后缀的LCP, 由于L 的存在, 我们还要对x− L 取min。类似地, [y − t, y] 也可以进行右移, 且和左边的区间不交, 所
以可以直接相加。可以通过画图简单地证明只有这种形式的串才能满足最短的
要求。

  • 求LCP 可以用后缀数组,加上PAM 倍增的预处理, 所以时空复杂度均为O(n log n).

实现细节与debug:

1.注意区间位下标到回文自动机点的编号的映射

2.后缀数组的求lcp需要封装一下,要牢记lc数组是排名为i的和i-1的lcp,所以区间查询是左开右闭的。区间长度为1需要特判。

3.回文树上倍增的时候我们要讨论如果当前节点已经满足条件了。不然我们是倍增到最后一个不满足条件的点,父节点就是我们要找的点。

3.统计方案数的时候,我们需要注意判断反串的上下界,正串的上界,反串的坐标社id函数映射的。

void solve() {
    int n;
    cin >> n;
    string s;
    cin >> s;

    string rs = s;
    reverse(rs.begin(), rs.end());

    string tmp = s + "&" + rs;
    SA sa(tmp);
    SparseTable<int> qmx(sa.lc, [](int i, int j) { return min(i, j); });

    auto pre = [&](PAM& pp, string tt) {
        deb(tt);
        pp.work(tt);
        int tot = pp.t.size() - 1;
        vector<vector<int>> e(tot + 1);
        for (int i = 2; i <= tot; i++) e[pp.fail(i)].push_back(i);
        int jie = __lg(tot);
        vector st(jie + 1, vector<int>(tot + 1));

        auto dfs = [&](auto self, int u) -> void {
            for (auto v : e[u]) {
                // deb(u, v);
                st[0][v] = u;
                self(self, v);
            }
        };

        dfs(dfs, 1);
        for (int j = 1; j <= jie; j++) {
            for (int i = 1; i <= tot; i++) {
                st[j][i] = st[j - 1][st[j - 1][i]];
            }
        }
        return st;
    };
    auto getlcp = [&](int pos1, int pos2) {
        int c1 = min(sa.rk[pos1], sa.rk[pos2]);
        int c2 = max(sa.rk[pos1], sa.rk[pos2]);
        assert(c1 < c2);
        return qmx.get(c1 + 1, c2);
    };
    auto st1 = pre(pam1, s);
    auto st2 = pre(pam2, rs);
    int tot1 = pam1.size() - 1, tot2 = pam2.size() - 1;
    int jie1 = __lg(tot1), jie2 = __lg(tot2);
    deb(s);

    int m;
    cin >> m;
    auto id = [&](int x) {
        return 2 * n + 2 - x;
    };
    for (int i = 1; i <= m; i++) {
        int l, r;
        cin >> l >> r;
        int lcp1 = getlcp(l, id(r));
        int len = r - l + 1;
        if (len == 1) {
            cout << 0 << " " << 0 << endl;
            continue;
        }
        lcp1 = min(lcp1, len);
        deb(lcp1);
        if (lcp1 == len) {
            cout << 0 << " " << 0 << endl;
            continue;
        }

        int ql = l + lcp1, qr = r - lcp1;
        int cur = pam1.idpos[qr], cul = pam2.idpos[n + 1 - ql];
        deb(ql, qr);
        // deb(cur, cul);
        deb(pam1.len(cur), pam2.len(cul));
        int maxpre, maxsuf;
        if (pam1.len(cur) <= qr - ql + 1) {
            maxsuf = pam1.len(cur);
        } else {
            for (int j = jie1; j >= 0; j--) {
                if (pam1.len(st1[j][cur]) > qr - ql + 1) {
                    cur = st1[j][cur];
                }
            }
            maxsuf = max(pam1.len(st1[0][cur]), 1);
        }
        //--------------------
        if (pam2.len(cul) <= qr - ql + 1) {
            maxpre = pam2.len(cul);
        } else {
            for (int j = jie2; j >= 0; j--) {
                if (pam2.len(st2[j][cul]) > qr - ql + 1) {
                    cul = st2[j][cul];
                }
            }

            maxpre = max(1, pam2.len(st2[0][cul]));
        }
        // deb(cur, cul);
        deb(maxpre, maxsuf);
        int ans1 = qr - ql + 1 - max(maxsuf, maxpre);
        int ans2 = 0;
        assert(ans1 < qr - ql + 1);
        deb("aaa");
        if (qr - ql + 1 - maxsuf == ans1) {
            ans2++;
            int fl = ql, fr = fl + ans1 - 1;
            deb("maxsuf", fl, fr);
            if (id(fl - 1) > n + 1 && id(fl - 1) < 2 * n + 2) {
                int tmpcp = getlcp(id(fl - 1), id(fr));
                tmpcp = min(tmpcp, fl - 1 - l + 1);
                ans2 += tmpcp;
            }
        }
        deb("aaa");
        if (qr - ql + 1 - maxpre == ans1) {
            ans2++;
            int fr = qr, fl = fr - ans1 + 1;
            if ((fr + 1) < n + 1) {
                int tmpcp = getlcp(fr + 1, fl);
                tmpcp = min(tmpcp, r - (fr + 1) + 1);
                ans2 += tmpcp;
            }
        }
        cout << ans1 << " " << ans2 << endl;
    }
}