【ULR #1】打击复读 (SAM, DAG链剖分)

好牛的题。 DAG链剖分好牛的 trick。

题意

给定一个字符集大小为 4 4 4,长度为 n n n 的字符串 S S S,同时给定两个长度为 n n n 的数组 { w l i } , { w r i } \{wl_i\}, \{wr_i\} {wli},{wri}。

定义一个字符串 T T T 的左权值为
v l ( T ) = ∑ S i , i + ∣ T ∣ − 1 = T w l i vl(T) = \sum\limits_{S_{i, i + |T| - 1} = T} wl_i vl(T)=Si,i+∣T∣−1=T∑wli

右权值为
v r ( T ) = ∑ S i − ∣ T ∣ + 1 , i = T w r i vr(T) = \sum\limits_{S_{i - |T| + 1, i} = T} wr_i vr(T)=Si−∣T∣+1,i=T∑wri

你需要求 ∑ l ≤ r v l ( S l , r ) × v r ( S l , r ) \sum\limits_{l \leq r} vl(S_{l, r}) \times vr(S_{l, r}) l≤r∑vl(Sl,r)×vr(Sl,r)。同时还有 q q q 次修改,每次将一个 w l u wl_u wlu 改成 v v v。你也需要回答出每次修改后的答案,答案对 2 64 2^{64} 264 取模。

1 ≤ n , q ≤ 5 × 10 5 1 \leq n,q \leq 5 \times 10^5 1≤n,q≤5×105

分析:

首先因为我们要快速求修改某个 w l i wl_i wli 后的答案,所以考虑求出每个 w l i wl_i wli 的系数 f l i fl_i fli,这样可以 O ( 1 ) O(1) O(1) 修改答案。

那么对于一个 w l i wl_i wli,它什么时候会被用到呢?

首先很容易分析出本质相同的字符串对答案的贡献相同,因此我们只考虑本质不同的字符串。

那么当且仅当计算 左端点为 i i i,右端点为 [ i , n ] [i, n] [i,n] 这 n − i + 1 n - i + 1 n−i+1 个本质不同字符串的答案时,会用到 w l i wl_i wli。

考虑一个暴力的想法:枚举 j ∈ [ i , n ] j \in [i, n] j∈[i,n],考虑计算 S i , j S_{i, j} Si,j 对 w l i wl_i wli 系数的贡献。我们建出 S A M SAM SAM,假设 S i , j S_{i, j} Si,j 定位到了 p p p 节点,那么 ∣ e n d p o s ( p ) ∣ |endpos(p)| ∣endpos(p)∣ 就是 S i , j S_{i, j} Si,j 会被计算的次数, ∑ x ∈ e n d p o s ( p ) w r x \sum\limits_{x \in endpos(p)} wr_x x∈endpos(p)∑wrx 就是每次对 w l i wl_i wli 系数的贡献。

因此设 g p = ∣ e n d p o s ( p ) ∣ × ∑ x ∈ e n d p o s ( p ) w r x g_p = |endpos(p)| \times \sum\limits_{x \in endpos(p)}wr_x gp=∣endpos(p)∣×x∈endpos(p)∑wrx,那么 g p g_p gp 是非常容易求的,并且我们能看出 f l i = ∑ p ∈ P ( i ) g p fl_i = \sum\limits_{p \in P(i)} g_p fli=p∈P(i)∑gp,其中 P ( i ) P(i) P(i) 表示左端点为 i i i,右端点在 [ i , n ] [i, n] [i,n] 这 n − i + 1 n - i + 1 n−i+1 个字符串对应的节点集合。

实际上,根据 S A M SAM SAM 自动机的性质 ,可以直到我们从 S A M SAM SAM 的起点 s s s 出发,按照 S i , n S_{i, n} Si,n 中的字符顺序每次走一条转移边就能到达 P ( i ) P(i) P(i) 中的所有节点。

这样处理每个 f l i fl_i fli 的复杂度 O ( n ) O(n) O(n),总复杂度 O ( n 2 ) O(n^2) O(n2)。考虑优化。

我们想要加速在 D A G DAG DAG 上游走过程,考虑 D A G DAG DAG 链剖分

对于一张 D A G DAG DAG,如果它的 0 0 0 度点个数大于 1 1 1,那么新建虚拟源点向这些 0 0 0 度点连边,使得 0 0 0 度点个数为 1 1 1,称这个 0 0 0 度点为源点。

设 f i f_i fi 表示 源点到 i i i 号点的路径数量 , g i g_i gi 表示 i i i 为起点,终点任意的路径数量

那么定义 D A G DAG DAG 上一条边 ( u , v ) (u, v) (u,v) 为 重边 当且仅当 u u u 是 v v v 所有入点中 f f f 最大的 并且 v v v 是 u u u 的所有出点中 g g g 最大的

重边以外的所有边称作 轻边

不难发现,重边将 D A G DAG DAG 剖分成了若干条链,不同链之间靠轻边连接。

在 D A G DAG DAG 上游走,每移动一次 f f f 都会增大, g g g 都会减小。并且每走过一条轻边,要么 f f f 翻倍,要么 g g g 除以 2 2 2,因此 任意一条路径经过的轻边数量不超过 log ⁡ V \log V logV 。换言之,它只会和 log ⁡ V \log V logV 条重链有交。其中 V V V 是 D A G DAG DAG 上的路径总条数。

也由此我们能够看出:D A G DAG DAG 链剖分的适用条件是路径总数不多 。而 S A M SAM SAM 由于每条路径都对应了一个本质不同子串,因此它的总路径条数是 O ( n 2 ) O(n^2) O(n2) 量级的,所以 S A M SAM SAM 与 D A G DAG DAG 链剖分结合是很自然的。

接着回到原来我们要优化的问题上,我们要加速求解 D A G DAG DAG 上一条路径的信息,那么发现这个东西和刚才说的 D A G DAG DAG 链剖分是恰好相适的:预处理每条链上 g g g 的前缀和,然后每次加上一条重链上一段的信息,再暴力跳到下一段即可。单次的复杂度就是 O ( log ⁡ 2 V ) O(\log_2 V) O(log2V)。

还有一个问题是怎么快速求路径和一段重链交的长度:我们维护指针 j j j,表示 [ i , n ] [i, n] [i,n] 的字符串还剩下 [ j , n ] [j, n] [j,n]。那么只需要求当前节点到所在重链底部将重边上的字母依次拼接形成的字符串和 [ j , n ] [j, n] [j,n] 的 最长公共前缀 l c p lcp lcp 的长度即可。发现一条重链也对应了一个子串:设链底节点的一个 e d p edp edp 为 p p p,链长为 l l l,那么它对应了 S [ p − l + 1 , l ] S_{[p - l + 1, l]} S[p−l+1,l]。只需要求出 S A SA SA 的 h e i g h t height height 数组就可以用 s t st st 表 O ( 1 ) O(1) O(1) 查询 l c p lcp lcp。

总复杂度 O ( n log ⁡ 2 V ) = O ( n log ⁡ 2 ( n 2 ) ) = O ( n log ⁡ 2 n ) O(n \log_2 V) = O(n \log_2(n^2)) = O(n \log_2 n) O(nlog2V)=O(nlog2(n2))=O(nlog2n)。

CODE:

cpp 复制代码
// DAG 链剖分: 适用于路径条数 V 比较小的DAG。 常与SAM结合 
// 使用 DAG 链剖分可以快速求原本需要在 DAG(SAM) 上 O(n) 游走来求解的信息 
// 单次复杂度 logV
#include<bits/stdc++.h>
#define pb emplace_back
using namespace std;
typedef unsigned long long ull;
const int N = 5e5 + 10; 
int n, m, idx[300], str[N];
char S[N]; 
ull fl[N], wl[N], wr[N];
struct SA {
	int m, sa[N], rk[N], height[N], x[N * 2], y[N * 2], c[N];
	int mn[20][N];
	inline void get_sa() {
		m = 3;
		for(int i = 1; i <= n; i ++ ) c[x[i] = str[i]] ++;
		for(int i = 1; i <= m; i ++ ) c[i] += c[i - 1];
		for(int i = n; i >= 1; i -- ) sa[c[x[i]] --] = i;
		for(int k = 1; k <= n; k <<= 1 ) { // 按照 k 排好序了, 现在排 2 * k 
			int num = 0;
			for(int i = n - k + 1; i <= n; i ++ ) y[++ num] = i;
			for(int i = 1; i <= n; i ++ ) 
				if(sa[i] > k) y[++ num] = sa[i] - k;
			for(int i = 0; i <= m; i ++ ) c[i] = 0;
			for(int i = 1; i <= n; i ++ ) c[x[i]] ++;
			for(int i = 1; i <= m; i ++ ) c[i] += c[i - 1];
			for(int i = n; i >= 1; i -- ) sa[c[x[y[i]]] --] = y[i], y[i] = 0;
			swap(x, y);
			x[sa[1]] = 1, num = 1;
			for(int i = 2; i <= n; i ++ ) 
				x[sa[i]] = (y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k]) ? num : ++ num;
			if(num == n) break;
			m = num;
		}
	}
	inline void get_height() {
		for(int i = 1; i <= n; i ++ ) rk[sa[i]] = i;
		for(int i = 1, k = 0; i <= n; i ++ ) { // 依次确定 [1, n] 后缀在 sa 数组上的 height 
			if(rk[i] == 1) continue;
			if(k) k --;
			int j = sa[rk[i] - 1];
			while(i + k <= n && j + k <= n && str[i + k] == str[j + k]) k ++;
			height[rk[i]] = k;
		}
	}
	inline void build_st() {
		for(int i = 1; i <= n; i ++ ) mn[0][i] = height[i];
		for(int i = 1; (1 << i) <= n; i ++ ) 
			for(int j = 1; j + (1 << i) - 1 <= n; j ++ ) 
				mn[i][j] = min(mn[i - 1][j], mn[i - 1][j + (1 << i - 1)]);
	}
	int query(int l, int r) {
		int k = log2(r - l + 1);
		return min(mn[k][l], mn[k][r - (1 << k) + 1]);
	}
	inline int lcp(int l1, int r1, int l2, int r2) {
		int u = rk[l1], v = rk[l2];
		if(u == v) return min(r1 - l1 + 1, r2 - l2 + 1);
		if(u > v) swap(u, v);
		return min({query(u + 1, v), r1 - l1 + 1, r2 - l2 + 1});
	}
	inline void build() {
		get_sa(); get_height();
		build_st();
	}
} sa;
struct SAM {
	struct Node {
		int fa, len;
		int ch[4];
	} node[N * 2];
	int tot = 1, last = 1;
	ull f[N * 2], g[N * 2], cnt[N * 2], sw[N * 2]; // f: 1 -> i; g: i -> ?
	vector< ull > sum[N * 2];
	vector< int > chain[N * 2];
	int edp[N * 2], bel[N * 2], bot[N * 2], pos[N * 2], Len[N * 2], in[N * 2], mxin[N * 2], mxou[N * 2];
	bool vis[N * 2];
	vector< int > E[N * 2];
	vector< int > G[N * 2];
	inline void extend(int c, int pos) {
		int p = last, np = last = ++ tot;
		node[np].len = node[p].len + 1; cnt[np] ++; sw[np] += wr[pos]; edp[np] = pos;
		for(; !node[p].ch[c] && p; p = node[p].fa) node[p].ch[c] = np;
		if(!p) node[np].fa = 1;
		else {
			int q = node[p].ch[c];
			if(node[q].len == node[p].len + 1) node[np].fa = q;
			else {
				int nq = ++ tot;
				node[nq] = node[q]; node[nq].len = node[p].len + 1;
				for(; node[p].ch[c] == q && p; p = node[p].fa) node[p].ch[c] = nq;
				node[q].fa = node[np].fa = nq;
			}
		}
	}
	void dfs(int x) {
		if(vis[x]) return ;
		vis[x] = 1; g[x] = 1;
		for(int i = 0; i < 4; i ++ ) 
			if(node[x].ch[i]) dfs(node[x].ch[i]), g[x] += g[node[x].ch[i]];
	}
	void bfs(int s) {
		for(int i = 1; i <= tot; i ++ )
			for(int j = 0; j < 4; j ++ ) 
				if(node[i].ch[j]) in[node[i].ch[j]] ++;
		queue< int > q; q.push(s); f[s] = 1;
		while(!q.empty()) {
			int u = q.front(); q.pop();
			for(int i = 0; i < 4; i ++ ) 
				if(node[u].ch[i]) {
					in[node[u].ch[i]] --; f[node[u].ch[i]] += f[u];
					if(!in[node[u].ch[i]]) q.push(node[u].ch[i]);
				}
		}
	}
	void Dfs(int x) {
		for(auto v : G[x]) {
			Dfs(v); cnt[x] += cnt[v]; sw[x] += sw[v]; edp[x] = max(edp[x], edp[v]);
		}
	}
	void DFS(int x, int b, int p) {
		bel[x] = b; bot[b] = x; Len[b] = p; pos[x] = p;
		if(p == 1) sum[b].pb(0), chain[b].pb(0);
		sum[b].pb(cnt[x] * sw[x]); chain[b].pb(x);
		for(auto v : E[x]) DFS(v, b, p + 1);
	}
	inline void build() {
		for(int i = 1; i <= n; i ++ ) extend(str[i], i);
		dfs(1); bfs(1);
		for(int i = 1; i <= n * 2; i ++ ) {
			for(int j = 0; j < 4; j ++ ) {
				int u = node[i].ch[j];
				if(!u) continue;
				if(g[u] > g[mxou[i]]) mxou[i] = u;
				if(f[i] > f[mxin[u]]) mxin[u] = i;
			}
		}
		for(int i = 1; i <= tot; i ++ ) 
			if(mxin[mxou[i]] == i) E[i].pb(mxou[i]), in[mxou[i]] ++;
		for(int i = 2; i <= tot; i ++ ) G[node[i].fa].pb(i);
		Dfs(1);
		for(int i = 1; i <= tot; i ++ ) 
			if(!in[i]) {
				DFS(i, i, 1);
				for(int j = 1; j <= Len[i]; j ++ ) sum[i][j] += sum[i][j - 1];
			}
	}
} sam;
inline ull ask(int l, int r) {
	int u = 1; ull ret = 0; int cnt = 0;
	while(l <= r) {
		int len = sa.lcp(l, r, sam.edp[sam.bot[sam.bel[u]]] - (sam.Len[sam.bel[u]] - sam.pos[u]) + 1, sam.edp[sam.bot[sam.bel[u]]]);
		ret += sam.sum[sam.bel[u]][sam.pos[u] + len] - sam.sum[sam.bel[u]][sam.pos[u]]; // 减去链头 
		l += len; cnt ++;
		if(l <= r) {
			u = sam.node[sam.chain[sam.bel[u]][sam.pos[u] + len]].ch[str[l]];
			ret += sam.cnt[u] * sam.sw[u]; l ++;
		}
	}
	return ret;
}
int main() {
	idx['A'] = 0, idx['T'] = 1, idx['G'] = 2, idx['C'] = 3; 
	scanf("%d%d", &n, &m); scanf("%s", S + 1); 
	for(int i = 1; i <= n; i ++ ) str[i] = idx[S[i]];
	for(int i = 1; i <= n; i ++ ) scanf("%llu", &wl[i]);
	for(int i = 1; i <= n; i ++ ) scanf("%llu", &wr[i]);
	sam.build(); 
	sa.build();
	ull ret = 0;
	for(int i = 1; i <= n; i ++ ) {
		fl[i] = ask(i, n); ret += fl[i] * wl[i];
	}
	for(int i = 0; i <= m; i ++ ) {
		if(i == 0) printf("%llu\n", ret);
		else {
			int u; ull v; scanf("%d%llu", &u, &v);
			ret -= fl[u] * wl[u];
			wl[u] = v; ret += fl[u] * wl[u];
			printf("%llu\n", ret);
		}
	}
	return 0;
}
相关推荐
MATLAB代码顾问2 分钟前
混合粒子群-模拟退火算法(HPSO-SA)求解作业车间调度问题——附MATLAB代码
算法·matlab·模拟退火算法
辞旧 lekkk5 分钟前
【Qt】初识(上)
开发语言·数据库·qt·学习·萌新
Felven6 分钟前
C. Prefix Min and Suffix Max
算法
加农炮手Jinx7 分钟前
LeetCode 26. Remove Duplicates from Sorted Array 题解
算法·leetcode·力扣
加农炮手Jinx7 分钟前
LeetCode 88. Merge Sorted Array 题解
算法·leetcode·力扣
Hhy_11077 分钟前
【从零开始学习数据结构 ④】:栈 ——后进先出的艺术
c语言·数据结构·学习·visual studio
格林威7 分钟前
线阵工业相机:如何计算线阵相机的行频(Line Rate)?公式+实例
开发语言·人工智能·数码相机·算法·计算机视觉·工业相机·线阵相机
2501_9271682910 分钟前
手机号测吉凶:尾数722手机号吉凶
笔记
yueyue54310 分钟前
透过现象看本质:以fast_lio架构的整套算法的局部避障改为TEB算法为例深度探讨——如何成为一个合格的算法架构师?
算法·架构
梨花爱跨境10 分钟前
红人视频×A10算法:亚马逊转化率与流量闭环实战
算法