题目链接
https://www.luogu.com.cn/problem/P3435
思路
我们先用扩展 KMP 算法（Z 函数）对字符串 $s$ 进行预处理，求出 $z$ 数组，其中 $z[i]$ 表示 $s$ 与后缀 $s[i..n-1]$ 的最长公共前缀长度。
一个前缀的最大周期等于它的长度减去它最短的非空 border（若没有非空 border，则贡献为 $0$）。对于字符串 $s$ 的第 $i$ 个字符（我们假设下标从 $0$ 开始），如果 $z[i] \ge 1$，则对任意 $1 \le k \le z[i]$ 都有 $s[i..i+k-1] = s[0..k-1]$，即长度为 $i+k$ 的前缀有长为 $k$ 的 border、周期为 $i$。因此，区间 $[i, i+z[i]-1]$ 上每个位置所对应前缀的候选最大周期都将变成 $i$。
因此我们可以在统计答案的过程中用线段树维护每个位置的最大值。由于 $i$ 从小到大枚举，后写入的值一定更大，所以只需做区间赋值；处理完第 $i$ 个字符后，单点查询该位置的值并累加进答案，这样计算出的答案就是最优的。
时间复杂度：$O(n \log_2 n)$。
代码
cpp
#include <bits/stdc++.h>
using namespace std;
// Template toggle: uncomment to make every `int` 64-bit.
// #define int long long
#define double long double
typedef long long i64;
typedef unsigned long long u64;
typedef std::pair<int, int> pii;
const int N = 1e6 + 5, M = 2e2 + 5;
const int mod = 1e9 + 7;
// "Infinity" built from 0x3f bytes, sized to the current width of `int`.
// The unconditional 64-bit literal 0x3f3f3f3f3f3f3f3f does not fit in a
// 32-bit int (out-of-range conversion), so pick the pattern that matches
// whether the `#define int long long` toggle above is enabled.
const int inf = sizeof(int) >= sizeof(long long) ? 0x3f3f3f3f3f3f3f3fLL : 0x3f3f3f3f;
std::mt19937 rnd(time(0));
// Input: declared word length and the word itself (read by solve()).
int n;
std::string s;
// Z-function of s: z[i] is the length of the longest common prefix of s and
// its suffix s[i:]. By convention z[0] = |s|. Runs in O(|s|).
// Returns an empty vector for an empty string.
vector<int> zfunc(const string& s) {
    int n = s.size();
    vector<int> z(n);
    if (n == 0) return z;  // nothing to compute; z[0] below would be out of bounds
    // [l, r] is the match window with the largest right end found so far:
    // s[l..r] equals s[0..r-l].
    for (int i = 1, l = 0, r = 0; i < n; i++) {
        // Inside the window, reuse the value at the mirrored position, capped at the window end.
        if (i <= r) z[i] = min(z[i - l], r - i + 1);
        // Extend the match by direct character comparison.
        while (i + z[i] < n && s[i + z[i]] == s[z[i]]) z[i]++;
        if (i + z[i] - 1 > r) l = i, r = i + z[i] - 1;
    }
    z[0] = n;
    return z;
}
// Segment tree over positions [l, r] (built by build(1, 1, n)) supporting
// range assignment (modify) and range-maximum queries (query), with lazy
// propagation of pending assignments. Node u has children 2u and 2u+1;
// every public operation is started at the root, u = 1.
struct segmenttree
{
    struct node
    {
        int l, r, maxx, tag;
        // True when `tag` holds an assignment not yet pushed to the children.
        // A separate flag is required because 0 is itself a legal value to
        // assign; using tag == 0 as "no tag" silently dropped such assignments.
        bool has_tag;
    };
    vector<node>tree;
    segmenttree(): tree(1) {}
    // 4n+1 slots cover every node index of a tree built on [1, n] from root 1.
    segmenttree(int n): tree(n * 4 + 1) {}
    // Recompute u's maximum from its two children.
    void pushup(int u)
    {
        auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
        root.maxx = max(left.maxx, right.maxx);
    }
    // Hand u's pending assignment (if any) down to both children.
    void pushdown(int u)
    {
        auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
        if (root.has_tag)
        {
            left.tag = root.tag;
            right.tag = root.tag;
            left.maxx = root.tag;
            right.maxx = root.tag;
            left.has_tag = true;
            right.has_tag = true;
            root.tag = 0;
            root.has_tag = false;
        }
    }
    // Build the subtree rooted at u over [l, r]; every position starts at 0.
    void build(int u, int l, int r)
    {
        auto &root = tree[u];
        root = node{l, r, 0, 0, false};
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
    // Assign val to every position in [l, r].
    void modify(int u, int l, int r, int val)
    {
        auto &root = tree[u];
        if (root.l >= l && root.r <= r)
        {
            // Fully covered: record the assignment lazily.
            root.maxx = val;
            root.tag = val;
            root.has_tag = true;
            return;
        }
        pushdown(u);
        int mid = (root.l + root.r) >> 1;
        if (l <= mid) modify(u << 1, l, r, val);
        if (r > mid) modify(u << 1 | 1, l, r, val);
        pushup(u);
    }
    // Maximum over positions [l, r]. If [l, r] misses this subtree entirely,
    // the identity element numeric_limits<int>::min() is returned.
    int query(int u, int l, int r)
    {
        auto &root = tree[u];
        if (root.l >= l && root.r <= r)
        {
            return root.maxx;
        }
        pushdown(u);
        int mid = (root.l + root.r) >> 1;
        int res = numeric_limits<int>::min();
        if (l <= mid) res = query(u << 1, l, r);
        if (r > mid) res = max(res, query(u << 1 | 1, l, r));
        return res;
    }
};
void solve(int test_case)
{
cin >> n >> s;
vector<int>z = zfunc(s);
segmenttree smt(n);
smt.build(1, 1, n);
i64 ans = 0;
for (int i = 0; i < n; i++)
{
int low = i + 1, high = min(i + z[i], n);
if (low <= high)
{
smt.modify(1, low, high, i);
}
int maxx = smt.query(1, i + 1, i + 1);
ans += maxx;
}
cout << ans << endl;
}
// Entry point: configures fast stream I/O and runs the solver once per test.
signed main()
{
    // All I/O goes through cin/cout, so detach from C stdio and untie streams.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int tests = 1;
    // cin >> tests;   // enable for multi-test input
    for (int id = 1; id <= tests; ++id) solve(id);
    return 0;
}