不负责任的安利我的博客:https://blog.sengxian.com/solutions/bzoj-4650 为了防止流量太太,图我就不外链到 UOJ 来了QAQ,可以在上面的链接里面看到图。
描述
分析
算法一:枚举 $\mathrm{AABB}$ 串的中心点,则如果记 $\mathrm{pre}(i)$ 为在 $i$ 前面,有多少个以 $i$ 结尾的 $AA$ 串;$\mathrm{post}(i)$ 为在 $i$ 后面,有多少个以 $i$ 开头的 $AA$ 串,则我们的答案为:
$$\sum_{0\le i\le n - 2}\mathrm{pre}(i)\times\mathrm{post}(i + 1)$$
其中求 $\mathrm{pre}$ 以及 $\mathrm{post}$ 的方法很多,我们只需要一个工具判断两段串是否相等即可。 比较简单的做法是记 $L_{i, j}$ 为后缀 $i, j$ 的 LCP(Longest Common Prefix,最长公共前缀),则 $L_{i, j} = [s_i = s_j](L_{i + 1, j + 1} + 1)$,判断两个串相等,只需要判断两个串的起始点代表的后缀的 LCP 是否大于串的长度即可。 复杂度 $O(n^2)$,期望得分:95 分。
算法二:延续上一个算法的思路,算法一的瓶颈在于求 $\mathrm{pre}$ 以及 $\mathrm{post}$,我们换一个思路,找出所有形如 $\mathrm{AA}$ 的子串。 枚举串 $\mathrm{AA}$ 的一半长度 $len$(也就是只考虑长度为 $2 * len$ 的 $AA$ 串),我们原串每隔长度 $len$ 设置一个关键点,则所有串 $\mathrm{AA}$ 必定覆盖两个关键点,而且这两个关键点位于 $\mathrm{A}$ 的同一个位置,如下图:
 接下来我们只考虑求 $\mathrm{pre}$ 数组,因为求 $\mathrm{post}$ 的方法是大同小异的。
枚举相邻的两个关键点 $i, i + 1$,则这一次枚举,将会影响 $[(i + 1) * len, (i + 2) * len)$ 内的 $\mathrm{pre}$ 值,因为如果某个 $\mathrm{AA}$ 串覆盖关键点 $i, i + 1$ 的话,其串末尾只可能落在 $[(i + 1) * len, (i + 2) * len)$ 里面,如下图: 
设后缀 $i$ 与后缀 $i + 1$ 的 LCP 为 $x$,前缀 $i$ 与前缀 $i + 1$ 的 LCS(Longest Common Suffix,最长公共后缀) 为 $y$。 若 $x + y < len$,不存在一个长度为 $2 * len$ 的 $\mathrm{AA}$ 串覆盖关键点 $i, i + 1$。 若 $x + y = len$,存在且仅存在一个长度为 $2 * len$ 的 $\mathrm{AA}$ 串覆盖关键点 $i, i + 1$,这个串的末尾的坐标是 $(i + 1) * len + x - 1$,如下图: 
若 $x + y > len$,也就是说区间重叠了,这时可能有多个长度为 $2 * len$ 的 $\mathrm{AA}$ 串,如下图,淡绿色区间的点都是长度为 $2 * len$ 的 $\mathrm{AA}$ 串的末尾点,这个区间为 $[(i + 1) * len - x + len, (i + 1) * len + y)$,如下图:

我们来证一下,显然,我们证明区间的端点是 $\mathrm{AA}$ 串的末尾点即可: 对于点 $(i + 1) * len - x + len$ 作为末尾,这个 $\mathrm{AA}$ 串的开头就是最前面,显然成立,如下图,红色的部分相等:

而对于点 $(i + 1) * len + y - 1$,好像结论不是很显然了,我们要证红色的部分相等:

把串剥离出来考虑,可以发现,重叠的部分会导致两个串的首尾一段相等:

从而两个串都是灰色部分 + 绿色部分,相等!
也就是说,每次枚举会导致 $[(i + 1) * len, (i + 2) * len)$ 的一个子区间的 $\mathrm{pre}$ + 1,我们差分,将 $\mathrm{pre}(i)$ 变为 $\mathrm{pre}(i) - \mathrm{pre}(i - 1)$,这样区间加变为两个单点修改,最后求一次前缀和即可。求 $\mathrm{post}$ 无非是找到 $\mathrm{AA}$ 的开头的一段区间 + 1,容易类比求 $\mathrm{pre}$ 的过程求出。 LCP + LCS 采用后缀数组 + ST 表实现,枚举到关键点以后单次计算 $O(1)$,枚举长度为 L,共有 $\frac n L$ 个关键点,枚举的总复杂度是 $\frac n 1 + \frac n 2 + \frac n 3 + \cdots = O(n\log n)$,所以总复杂度 $O(Tn\log n)$。
代码
// Created by Sengxian on 8/4/16.
// Copyright (c) 2016年 Sengxian. All rights reserved.
// BZOJ 4650 NOI 2016 D1T1 后缀数组
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 30000 + 3;
int logs[maxn], pre[maxn], post[maxn], n;
struct SuffixArray {
static const int maxNode = maxn;
int sa[maxNode], rank[maxNode], minHeight[15][maxNode], n;
char str[maxNode];
inline void build_sa(int m = 'z' + 3) {
static int tmpSA[maxNode], rank1[maxNode], rank2[maxNode], cnt[maxNode];
register int i;
n = strlen(str) + 1, str[n] = 0;
memset(cnt, 0, sizeof (int) * m);
for (i = 0; i < n; ++i) cnt[(int)str[i]]++;
for (i = 1; i < m; ++i) cnt[i] += cnt[i - 1];
for (i = 0; i < n; ++i) rank[i] = cnt[(int)str[i]] - 1;
for (int l = 1; l < n; l <<= 1) {
for (i = 0; i < n; ++i)
rank1[i] = rank[i], rank2[i] = i + l < n ? rank[i + l] : 0;
memset(cnt, 0, sizeof (int) * n);
for (i = 0; i < n; ++i) cnt[rank2[i]]++;
for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1];
for (i = n - 1; ~i; --i) tmpSA[--cnt[rank2[i]]] = i;
memset(cnt, 0, sizeof (int) * n);
for (i = 0; i < n; ++i) cnt[rank1[i]]++;
for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1];
for (i = n - 1; ~i; --i) sa[--cnt[rank1[tmpSA[i]]]] = tmpSA[i];
bool unique = true;
rank[sa[0]] = 0;
for (i = 1; i < n; ++i) {
rank[sa[i]] = rank[sa[i - 1]];
if (rank1[sa[i]] == rank1[sa[i - 1]] && rank2[sa[i]] == rank2[sa[i - 1]]) unique = false;
else rank[sa[i]]++;
}
if (unique) break;
}
}
inline void getHeight() {
minHeight[0][0] = 0;
for (int i = 0, j = 0, k = 0; i < n - 1; ++i) {
if (k) --k;
j = sa[rank[i] - 1];
while (str[i + k] == str[j + k]) k++;
minHeight[0][rank[i]] = k;
}
for (int w = 1; (1 << w) <= n; ++w)
for (int i = 0; i + (1 << w) <= n; ++i)
minHeight[w][i] = min(minHeight[w - 1][i], minHeight[w - 1][i + (1 << (w - 1))]);
}
inline int query(int l, int r) {
static int bit;
bit = logs[r - l];
return min(minHeight[bit][l], minHeight[bit][r - (1 << bit)]);
}
inline int LCP(int l, int r) {
l = rank[l], r = rank[r];
if (l > r) swap(l, r);
return query(l + 1, r + 1);
}
}SA, rSA;
inline int LCP(int i, int j) {
return SA.LCP(i, j);
}
inline int LCS(int i, int j) {
return rSA.LCP(n - i - 1, n - j - 1);
}
ll solve() {
ll ans = 0;
memset(pre, 0, sizeof (int) * (n + 1));
memset(post, 0, sizeof (int) * (n + 1));
for (int len = 1, x, y, l, r; (len << 1) <= n; ++len)
for (int i = 0, j = len; j < n; i += len, j += len) if (SA.str[i] == SA.str[j]) {
x = LCS(i, j), y = LCP(i, j), l = max(i - x + len, i), r = min(i + y, j);
if (r - l >= 1) {
pre[l + len]++, pre[r + len]--;
post[l - len + 1]++, post[r - len + 1]--;
}
}
for (int i = 1; i < n; ++i) pre[i] += pre[i - 1], post[i] += post[i - 1];
for (int i = 0; i < n - 1; ++i)
ans += (ll)pre[i] * post[i + 1];
return ans;
}
int main() {
logs[0] = logs[1] = 0;
for (int i = 2; i < maxn; ++i) logs[i] = logs[i >> 1] + 1;
int caseNum; scanf("%d", &caseNum);
while (caseNum--) {
scanf("%s", SA.str), n = strlen(SA.str);
memcpy(rSA.str, SA.str, sizeof SA.str);
reverse(rSA.str, rSA.str + n);
SA.build_sa(), rSA.build_sa(), SA.getHeight(), rSA.getHeight();
printf("%lld\n", solve());
}
return 0;
}