一、线段树解决的问题
有 n (n ≤ 1e5) 个数,q (1 ≤ 1e5) 次操作,每次操作为询问区间 [l,r] 的和。
有 n (n ≤ 1e5) 个数,q (1 ≤ 1e5) 次操作,操作有两种:a.查询区间 [l,r] 的和; b.将第 i 个数修改成 x。
有 n (n ≤ 1e5) 个数,q (1 ≤ 1e5) 次操作,操作有两种:a. 查询区间 [l,r] 的和; b. 将区间 [l,r] 的数全部修改成 x。
有 n(n ≤ 1e5) 个数,q(1 ≤ 1e5) 次操作,每次操作为区间 [l,r] 的最大值或者最小值。 RMQ(Range Minimum/Maximum Query)问题。
以上问题,采用暴力解法显然会超时。用一种树形数据结构 - 线段树来解决:
- 线段树是一棵二叉树,常用来维护区间信息;
- 可以在 log 级别的时间复杂度内完成:区间的单点修改,区间修改、区间查询(区间和,区间最大 / 最小值)等操作。
二、线段树的构建
线段树是基于分治思想的二叉树,**树中的每一个结点都会维护一段区间的信息。**其中叶子结点存储元素本身,非叶结点维护区间内元素的信息。
以数组 a=[5,1,3,0,2,2,7,4,5,8] 为例**,如果查询的是区间和,我们会创建出来这样一棵树来维护信息:** 
根据构建方式,可以得到以下性质:
- 线段树的每个结点都维护一个区间的信息
- 线段树中的根节点维护整个区间的信息,叶子结点维护长度为 1 的区间信息;
- 可以用结构体数组来实现线段树,类似堆的存储方式,也就是二叉树的静态存储。此时父节点的编号为 p 时,左孩子编号为 p×2,右孩子编号为 p×2+1;
- 若当前结点维护的区间为 [l,r],那么左右孩子分别维护 [l,mid] 以及 [mid+1,r] 区间的信息;
- 线段树的空间,需要开最大区间的 4 倍。
代码:
cpp
#define lc p << 1
#define rc p << 1 | 1
typedef long long LL;
const int N = 1e5 + 10;
int n, m;
LL a[N];
struct node
{
LL l, r, sum;
}tr[N << 2];
// 整合左右孩子的信息
void pushup(int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
void build(int p, int l, int r) // 建树
{
tr[p] = {l, r, 0}; // 初始化
if(l == r) // 如果是叶子结点
{
tr[p].sum = a[l]; // 更新 sum 的值
return;
}
int mid = (l + r) >> 1; // 一分为二
build(lc, l, mid); // 构建左子树
build(rc, mid + 1, r); // 构建右子树
// tr[p].sum = tr[lc].sum + tr[rc].sum; // 左右子树构建完成之后,维护 sum 信息
pushup(p);
}
时间复杂度:O(n)
三、区间查询
对于一个待查询的区间,用拆分 + 拼凑的思想,在线段树的结点中收集结果。具体流程:
- 从根节点出发,向下递归;
- 如果当前结点维护的区间信息包含在待查询的区间内,直接返回结点维护的信息;
- 如果左区间有重叠,去左子树上找结果;
- 如果右区间有重叠,去右子树上找结果。
以数组 a=[5,1,3,0,2,2,7,4,5,8] 为例,如果查询的是区间 [3,8] 的和,维护的信息如下:

代码:
cpp
// 区间查询
LL query(int p, int x, int y)
{
LL l = tr[p].l, r = tr[p].r; // 当前结点维护的信息
if(l >= x && r <= y) return tr[p].sum; // 如果是查询区间的子区间,返回结果
LL sum = 0, mid = (l + r) >> 1;
if(x <= mid) sum += query(lc, x, y);
if(y > mid) sum += query(rc, x, y);
return sum;
}
时间复杂度: 从上面的图中发现,感觉时间复杂度有可能会很高啊?因为从上往下遍历的时候,依旧会遇到很多点。实际上,查询的过程只会沿着两道线下来,每一条线最多在某个结点位置分出去一个节点。因此,整体的时间复杂度为 O(logn)。
四、单点修改
例如:将 x=6 位置上的数加上 3。(同理:对单个位置上的数执行:减去一个数,乘上一个数,除以一个数的操作)
具体流程:
- 递归找到叶子结点,并且维护修改之后的信息;
- 然后一路向上回溯,修改所有路径上的结点信息,使得维护的信息为修改之后的信息。
以数组 a=[5,1,3,0,2,2,7,4,5,8] 为例,如果将 x=6 位置上的元素增加 3,维护的信息如下:

代码:
cpp
// 单点修改
void modify(int p, int x, LL k)
{
int l = tr[p].l, r = tr[p].r;
if(l == x && r == x)
{
tr[p].sum += k;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, k);
else modify(rc, x, k);
pushup(p);
}
来做个题吧:
五、区间修改(懒标记)
其实搞懂原理很简单
例如:对区间 [4,9] 上每个元素增加 2。如果按照单点修改的方式,把所有 [4,9] 所覆盖的结点全部修改,时间复杂度将是 O (n)。
试想一下,如果某个结点维护的区间 [l, r] 被修改的区间 [x, y] 完全覆盖时,如果能够在 O (1) 时间内修改区间维护的信息,那么左右子树其实没有必要修改。可以等到下次遇到的时候,再去处理。
借助这样的思想,我们会在每一个结点中额外维护一个懒标记:
- 如果当前结点维护的区间 [l, r] 被待查询区间 [x, y] 完全覆盖时,停止递归,根据区间长度维护出增加元素之后的和;不去处理左右孩子,打上一个区间增加一个值的懒标记;
- 等到下次修改或者查询操作,遇到该节点时,再把懒标记下放给左右孩子。
这样,就可以把时间控制的与查询时间一致,都是 log (n) 级别。
以数组 a = [5, 1, 3, 0, 2, 2, 7, 4, 5, 8] 为例,如果对区间 [4, 9] 上每个元素增加 2,维护的信息如下:(叶子节点的add也会增加,这个图中没有画出来)

如果执行查询操作:查询区间 [5, 7] 上所有元素的和,维护信息如下:

简单总结一下,当涉及区间修改,加上懒标记之后,查询和修改操作递归到某个区间后:
如果当前结点维护的区间 [l, r] 包含在查询的区间 [x, y] 中:
a. 此时就可以根据区间长度计算出区间和,没有必要继续递归下去;
b. 利用区间长度计算出区间和,打上一个懒标记,就可以向上返回。
如果当前结点维护的区间 [l, r] 只有一部分在查询区间 [x, y] 中:
a. 把该节点存储的懒标记下放一层 pushdown;
b. 根据查询区间的范围,递归到左右区间;
c. 等左右区间处理完毕之后,维护当前结点的区间和信息 pushup。
注意:记得下发懒标记
cpp
struct node
{
int l, r;
LL sum, add;
}tr[N * 4];
// 接收到修改任务,修改完毕之后,把修改信息懒下来
void lazy(int p, LL add)
{
int l = tr[p].l, r = tr[p].r;
tr[p].sum += (r - l + 1) * add;
tr[p].add += add;
}
void pushup(int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
void pushdown(int p)
{
if(tr[p].add)
{
lazy(lc, tr[p].add); // 懒标记分给左孩子
lazy(rc, tr[p].add); // 懒标记分给右孩子
tr[p].add = 0;
}
}
// 创建 build
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], 0};
if (l == r)
{
tr[p].sum = a[l];
return;
}
int mid = (l + r) >> 1;
build(lc, l, mid); build(rc, mid + 1, r);
pushup(p);
}
// 区间查询
LL query(int p, int x, int y)
{
int l = tr[p].l, r = tr[p].r;
if (x <= l && r <= y) return tr[p].sum;
pushdown(p); // 懒标记下放
LL sum = 0, mid = (l + r) >> 1;
if (x <= mid) sum += query(lc, x, y);
if (y > mid) sum += query(rc, x, y);
return sum;
}
// 区间修改
void modify(int p, int x, int y, LL k)
{
int l = tr[p].l, r = tr[p].r;
if (x <= l && r <= y)
{
// 修改之后,打上标记
tr[p].sum += k * (r - l + 1);
tr[p].add += k;
return;
}
int mid = (l + r) >> 1;
pushdown(p); // 懒标记下放
if (x <= mid) modify(lc, x, y, k);
if (y > mid) modify(rc, x, y, k);
pushup(p); // 更新父节点
}
时间复杂度:
由于加上了懒标记,所有的操作与区间查询的过程一致,整体的时间复杂度为 log (n)。
做个题吧
六、维护更多类型的信息
1.区间最小值
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 5e5 + 10;
int a[N];
struct node
{
int l, r;
int mi;
}tr[N << 2];
void pushup(int p)
{
tr[p].mi = min(tr[lc].mi, tr[rc].mi);
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l]};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].mi;
}
int mid = (l + r) >> 1;
int mi = 1e18;
if(x <= mid) mi = min(query(lc, x, y), mi);
if(y > mid) mi = min(query(rc, x, y), mi);
return mi;
}
void solve()
{
int n, m; cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--)
{
int x, y; cin >> x >> y;
cout << query(1, x, y) << " ";
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}
2.懒标记:翻转次数
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 5e5 + 10;
struct node
{
int l, r;
int g, k;
int lazy;
}tr[N << 2];
void pushup(int p)
{
tr[p].g = tr[lc].g + tr[rc].g;
tr[p].k = tr[lc].k + tr[rc].k;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, 1, 0, 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void lazy(int p)
{
swap(tr[p].g, tr[p].k);
tr[p].lazy ^= 1;
}
void pushdown(int p)
{
if(tr[p].lazy)
{
lazy(lc);
lazy(rc);
tr[p].lazy = 0;
}
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].k;
}
pushdown(p);
int mid = (l + r) >> 1;
int sum = 0;
if(x <= mid) sum += query(lc, x, y);
if(y > mid) sum += query(rc, x, y);
return sum;
}
void modify(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
pushdown(p);
if(l >= x && r <= y)
{
lazy(p);
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y);
if(y > mid) modify(rc, x, y);
pushup(p);
}
void solve()
{
int n, m; cin >> n >> m;
build(1, 1, n);
while(m--)
{
int c, a, b;
cin >> c >> a >> b;
if(c == 0)
{
modify(1, a, b);
}
else
{
cout << query(1, a, b) << endl;
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}
3.线段个数
在区间 [l, r] 内,如果知道 [1, l-1] 中线段的终点数量 以及 [1, r] 中线段的起点数量,后者减前者就是 [l, r] 内的线段种类。
因此,利用线段树维护区间内线段的起点数量和终点数量即可。
cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
#define lc p << 1
#define rc p << 1 | 1
int n, m;
struct node
{
int l, r, sta, end;
}tr[N << 2];
void pushup(int p)
{
tr[p].sta = tr[lc].sta + tr[rc].sta;
tr[p].end = tr[lc].end + tr[rc].end;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, 0, 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void m1(int p, int x)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == x && r == x)
{
tr[p].sta++;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) m1(lc, x);
else m1(rc, x);
pushup(p);
}
void m2(int p, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == y && r == y)
{
tr[p].end++;
return;
}
int mid = (l + r) >> 1;
if(y <= mid) m2(lc, y);
else m2(rc, y);
pushup(p);
}
int q1(int p, int x, int y) // sta
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].sta;
}
int sum = 0;
int mid = (l + r) >> 1;
if(x <= mid) sum += q1(lc, x, y);
if(y > mid) sum += q1(rc, x, y);
return sum;
}
int q2(int p, int x, int y) // end
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].end;
}
int sum = 0;
int mid = (l + r) >> 1;
if(x <= mid) sum += q2(lc, x, y);
if(y > mid) sum += q2(rc, x, y);
return sum;
}
int main()
{
cin >> n >> m;
build(1, 1, n);
while(m--)
{
int q, l, r;
cin >> q >> l >> r;
if(q == 1)
{
m1(1, l);
m2(1, r);
}
else if(q == 2)
{
cout << q1(1, 1, r) - q2(1, 1, l - 1) << endl;
}
}
return 0;
}
4.线段树+等差数列
可以直接用等差数列求和来计算,也可以用差分来转换一下问题
cpp
// #include <bits/stdc++.h>
// using namespace std;
// const int N = 1e5 + 10;
// #define lc p << 1
// #define rc p << 1 | 1
// #define int long long
// int n, m;
// struct node
// {
// int l, r, sum, k, d;
// }tr[N << 2];
// int a[N];
// void pushup(int p)
// {
// tr[p].sum = tr[lc].sum + tr[rc].sum;
// }
// void build(int p, int l, int r)
// {
// tr[p] = {l, r, a[l], 0, 0};
// if(l == r) return;
// int mid = (l + r) >> 1;
// build(lc, l, mid);
// build(rc, mid + 1, r);
// pushup(p);
// }
// void lazy(int p, int k, int d)
// {
// int n = tr[p].r - tr[p].l + 1;
// tr[p].sum += n * k + (n * (n - 1)) / 2 * d;
// tr[p].k += k;
// tr[p].d += d;
// }
// void pushdown(int p)
// {
// int mid = (tr[p].l + tr[p].r) >> 1;
// lazy(lc, tr[p].k, tr[p].d);
// lazy(rc, tr[p].k + (mid + 1 - tr[p].l) * tr[p].d, tr[p].d);
// tr[p].k = tr[p].d = 0;
// }
// void modify(int p, int x, int y, int k, int d)
// {
// int l = tr[p].l;
// int r = tr[p].r;
// if(l >= x && r <= y)
// {
// lazy(p, k + (l - x) * d, d);
// return;
// }
// pushdown(p);
// int mid = (l + r) >> 1;
// if(x <= mid) modify(lc, x, y, k, d);
// if(y > mid) modify(rc, x, y, k, d);
// pushup(p);
// }
// int query(int p, int x)
// {
// int l = tr[p].l;
// int r = tr[p].r;
// if(x == l && x == r) return tr[p].sum;
// pushdown(p);
// int mid = (l + r) >> 1;
// if(x <= mid) return query(lc, x);
// else return query(rc, x);
// }
// signed main()
// {
// cin >> n >> m;
// for(int i = 1; i <= n; i++) cin >> a[i];
// build(1, 1, n);
// while(m--)
// {
// int q;
// cin >> q;
// if(q == 1)
// {
// int l, r, k, d;
// cin >> l >> r >> k >> d;
// modify(1, l, r, k, d);
// }
// else if(q == 2)
// {
// int x; cin >> x;
// cout << query(1, x) << endl;
// }
// }
// return 0;
// }
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
#define lc p << 1
#define rc p << 1 | 1
#define int long long
int n, m;
struct node
{
int l, r, sum, add;
}tr[N << 2];
int a[N];
void pushup(int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void lazy(int p, int k)
{
int l = tr[p].l;
int r = tr[p].r;
tr[p].sum += (r - l + 1) * k;
tr[p].add += k;
}
void pushdown(int p)
{
if(tr[p].add)
{
lazy(lc, tr[p].add);
lazy(rc, tr[p].add);
tr[p].add = 0;
}
}
void modify(int p, int x, int y, int k)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
lazy(p, k);
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y, k);
if(y > mid) modify(rc, x, y, k);
pushup(p);
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].sum;
}
pushdown(p);
int mid = (l + r) >> 1;
int sum = 0;
if(x <= mid) sum += query(lc, x, y);
if(y > mid) sum += query(rc, x, y);
return sum;
}
signed main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
{
int x; cin >> x;
a[i] += x;
a[i + 1] -= x;
}
build(1, 1, n);
while(m--)
{
int q;
cin >> q;
if(q == 1)
{
int l, r, k, d;
cin >> l >> r >> k >> d;
modify(1, l, l, k);
if(l + 1 <= r) modify(1, l + 1, r, d);
if(r + 1 <= n) modify(1, r + 1, r + 1, -(k + (r - l) * d));
}
else if(q == 2)
{
int x; cin >> x;
cout << query(1, 1, x) << endl;
}
}
return 0;
}
七、多个区间操作
1.【模板】线段树 2

cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10;
int n, m, mod;
int a[N];
struct node
{
int l, r, sum, mul, add;
}tr[N << 2];
void pushup(int p)
{
tr[p].sum = (tr[lc].sum + tr[rc].sum) % mod;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], 1, 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void lazy(int p, int mul, int add)
{
tr[p].sum = (tr[p].sum * mul % mod + add * (tr[p].r - tr[p].l + 1)) % mod;
tr[p].add = (tr[p].add * mul + add) % mod;
tr[p].mul = tr[p].mul * mul % mod;
}
void pushdown(int p)
{
lazy(lc, tr[p].mul, tr[p].add);
lazy(rc, tr[p].mul, tr[p].add);
tr[p].mul = 1;
tr[p].add = 0;
}
void modify(int p, int x, int y, int mul, int add)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
lazy(p, mul, add);
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y, mul, add);
if(y > mid ) modify(rc, x, y, mul, add);
pushup(p);
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y) return tr[p].sum;
int sum = 0;
pushdown(p);
int mid = (l + r) >> 1;
if(x <= mid) sum += query(lc, x, y);
if(y > mid ) sum += query(rc, x, y);
return sum % mod;
}
signed main()
{
cin >> n >> m >> mod;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--)
{
int op; cin >> op;
if(op == 1)
{
int x, y, k; cin >> x >> y >> k;
modify(1, x, y, k, 0);
}
else if(op == 2)
{
int x, y, k; cin >> x >> y >> k;
modify(1, x, y, 1, k);
}
else if(op == 3)
{
int x, y; cin >> x >> y;
cout << query(1, x, y) << endl;
}
}
return 0;
}
2.涉及区间重置问题

cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e6 + 10;
int n, q;
int a[N];
struct node
{
int l, r, mx;
int add;
int update;
bool st;
}tr[N << 2];
void pushup(int p)
{
tr[p].mx = max(tr[lc].mx, tr[rc].mx);
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], 0, 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void lazy(int p, bool st, int update, int add)
{
if(st)
{
tr[p].mx = update;
tr[p].st = st;
tr[p].update = update;
tr[p].add = 0;
}
tr[p].add += add;
tr[p].mx += add;
}
void pushdown(int p)
{
lazy(lc, tr[p].st, tr[p].update, tr[p].add);
lazy(rc, tr[p].st, tr[p].update, tr[p].add);
tr[p].st = false;
tr[p].update = 0;
tr[p].add = 0;
}
void modify(int p, int x, int y, bool st, int update, int add)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
lazy(p, st, update, add);
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y, st, update, add);
if(y > mid) modify(rc, x, y, st, update, add);
pushup(p);
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].mx;
}
int mx = -1e18;
pushdown(p);
int mid = (l + r) >> 1;
if(x <= mid) mx = max(mx, query(lc, x, y));
if(y > mid ) mx = max(mx, query(rc, x, y));
return mx;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> q;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(q--)
{
int op; cin >> op;
if(op == 1)
{
int l, r, x; cin >> l >> r >> x;
modify(1, l, r, 1, x, 0);
}
else if(op == 2)
{
int l, r, x; cin >> l >> r >> x;
modify(1, l, r, 0, 0, x);
}
else
{
int l, r; cin >> l >> r;
cout << query(1, l, r) << endl;
}
}
return 0;
}
八、线段树+分治
线段树本身就是基于分治思想的二叉树。那么对于很多可以通过分治解决的问题,查询起来也可以通过线段树来维护。其中,最经典的就是最大子段和问题。
线段树的区间查询是由多个小区间中的信息拼凑而成。面对上述问题,在某些情况下,查询过程中单单返回一个值是不足以拼凑出结果的,需要返回的是一个结构体。
1.最大子段和
对于最大子段和,用分治思想解决方式为下面三种情况的最大值:
- 左区间的最大子段和
max;- 右区间的最大子段和
max;- 左区间从右端点开始的最大段
rmax+ 右区间从左端点开始的最大段lmax。因此,线段树中需要维护
max,lmax,rmax。对于
pushup:
max:左区间最大值、右区间最大值以及左区间的rmax+ 右区间lmax三者的最大值;lmax:左区间的lmax、以及左区间的和sum+ 右区间的lmax两者的最大值;rmax:右区间的rmax、以及右区间的和sum+ 左区间的rmax两者的最大值;因此,还需要维护区间和
sum。对于
query:
- 查询的过程是将大区间分成若干小区间查询,那么组合成最大子段的时候,又是一个分治的过程。
- 因此,查询的结果应该返回一个结构体。通过左右孩子返回的结构体,拼凑出来当前结点的
max,lmax,rmax信息。
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 5e5 + 10;
int n, m;
int a[N];
struct node
{
int l, r;
int sum;
int max, lmax, rmax;
}tr[N << 2];
void pushup(node& a, node& b, node& c)
{
a.max = max({b.max, c.max, b.rmax + c.lmax});
a.lmax = max(b.lmax, b.sum + c.lmax);
a.rmax = max(c.rmax, c.sum + b.rmax);
a.sum = b.sum + c.sum;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], a[l], a[l], a[l]};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(tr[p], tr[lc], tr[rc]);
}
void modify(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == x && r == x)
{
tr[p].sum = tr[p].max = tr[p].lmax = tr[p].rmax = y;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y);
else modify(rc, x, y);
pushup(tr[p], tr[lc], tr[rc]);
}
node query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p];
}
int mid = (l + r) >> 1;
if(y <= mid) return query(lc, x, y);
if(x > mid) return query(rc, x, y);
node ret, L = query(lc, x, y), R = query(rc, x, y);
pushup(ret, L, R);
return ret;
}
void solve()
{
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--)
{
int k, a, b; cin >> k >> a >> b;
if(k == 1)
{
if(a > b) swap(a, b);
cout << query(1, a, b).max << endl;
}
else
{
modify(1, a, b);
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
while(_--)
{
solve();
}
return 0;
}
2.上面知识的综合体
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10;
int a[N];
struct node
{
int l, r;
int s0, l0, r0, m0;
int s1, l1, r1, m1;
int st;
int rev;
}tr[N << 2];
void pushup(node& p, node& l, node& r)
{
p.s0 = l.s0 + r.s0;
p.l0 = l.s1 == 0 ? l.s0 + r.l0 : l.l0;
p.r0 = r.s1 == 0 ? r.s0 + l.r0 : r.r0;
p.m0 = max({l.m0, r.m0, l.r0 + r.l0});
p.s1 = l.s1 + r.s1;
p.l1 = l.s0 == 0 ? l.s1 + r.l1 : l.l1;
p.r1 = r.s0 == 0 ? r.s1 + l.r1 : r.r1;
p.m1 = max({l.m1, r.m1, l.r1 + r.l1});
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l] ^ 1, a[l] ^ 1, a[l] ^ 1, a[l] ^ 1, a[l] , a[l], a[l], a[l], -1, -1};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(tr[p], tr[lc], tr[rc]);
}
void lazy(int p, int st, int rev)
{
int l = tr[p].l;
int r = tr[p].r;
if(st == 0)
{
tr[p].s1 = tr[p].l1 = tr[p].r1 = tr[p].m1 = 0;
tr[p].s0 = tr[p].l0 = tr[p].r0 = tr[p].m0 = r - l + 1;
tr[p].st = st;
tr[p].rev = -1;
// rev = -1;
}
if(st == 1)
{
tr[p].s1 = tr[p].l1 = tr[p].r1 = tr[p].m1 = r - l + 1;
tr[p].s0 = tr[p].l0 = tr[p].r0 = tr[p].m0 = 0;
tr[p].st = st;
tr[p].rev = -1;
// rev = -1;
}
if(rev == 2)
{
swap(tr[p].l0, tr[p].l1);
swap(tr[p].r0, tr[p].r1);
swap(tr[p].s0, tr[p].s1);
swap(tr[p].m0, tr[p].m1);
tr[p].rev = tr[p].rev == 2 ? -1 : 2;
}
}
void pushdown(int p)
{
lazy(lc, tr[p].st, tr[p].rev);
lazy(rc, tr[p].st, tr[p].rev);
tr[p].st = -1;
tr[p].rev = -1;
}
void modify(int p, int x, int y, int f, int rev)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
lazy(p, f, rev);
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y, f, rev);
if(y > mid) modify(rc, x, y, f, rev);
pushup(tr[p], tr[lc], tr[rc]);
}
node query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p];
}
pushdown(p);
int mid = (l + r) >> 1;
if(y <= mid) return query(lc, x, y);
if(x > mid) return query(rc, x, y);
node ret, L = query(lc, x, y), R = query(rc, x, y);
pushup(ret, L, R);
return ret;
}
void solve()
{
int n, m; cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--)
{
int op, l, r;
cin >> op >> l >> r;
l++;
r++;
if(op == 0)
{
modify(1, l, r, 0, -1);
}
else if(op == 1)
{
modify(1, l, r, 1, -1);
}
else if(op == 2)
{
modify(1, l, r, -1, 2);
}
else if(op == 3)
{
cout << query(1, l, r).s1 << endl;
}
else if(op == 4)
{
cout << query(1, l, r).m1 << endl;
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
while(_--)
{
solve();
}
return 0;
}
九、线段树+剪枝
线段树在维护区间修改操作时,有些操作是无法 "懒" 下来的,比如对整个区间的每一个数执行开根号操作。这时只能从上往下把所有的点全部修改。
但是,如果在修改的过程中发现,整个区间在修改到一定程度的时候,整个区间就无需修改。那么,就可以通过剪枝操作,优化区间的修改。
这样的线段树也叫作势能线段树。
1.开根号
P4145 上帝造题的七分钟 2 / 花神游历各国 - 洛谷
在开根号的过程中,当整个区间全部为 1 时,此时就不需要开根号了。因此,可以用线段树维护区间的最大值。当整个区间的最大值为 1 时,停止修改。
因为每一个数最多被开 6 次就会变成 1,因此修改操作最多会执行 6 * n * logn 次。时间上可行的。
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10;
int n, m;
int a[N];
struct node
{
int l, r, sum, mx;
}tr[N << 2];
void pushup(int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum;
tr[p].mx = max(tr[lc].mx, tr[rc].mx);
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], a[l]};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void modify(int p, int x, int y)
{
if(tr[p].mx == 1) return;
int l = tr[p].l;
int r = tr[p].r;
if(l == r)
{
tr[p].sum = sqrt(tr[p].sum);
tr[p].mx = sqrt(tr[p].mx);
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, y);
if(y > mid) modify(rc, x, y);
pushup(p);
return;
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].sum;
}
int sum = 0;
int mid = (l + r) >> 1;
if(x <= mid) sum += query(lc, x, y);
if(y > mid) sum += query(rc, x, y);
return sum;
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i];
cin >> m;
build(1, 1, n);
while(m--)
{
int k, l, r; cin >> k >> l >> r;
if(l > r) swap(l, r);
if(k == 0) modify(1, l, r);
else cout << query(1, l, r) << endl;
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}
2.取模
CF438D The Child and Sequence - 洛谷
取模操作与开根号操作⼀样。当某个区间的最大值小于取模的数时,整个区间所有的数都不会再取模。因此可以通过维护区间最大值,进行剪枝。
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10;
int n, m;
int a[N];
struct node
{
int l, r, sum, mx;
}tr[N << 2];
void pushup(int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum;
tr[p].mx = max(tr[lc].mx, tr[rc].mx);
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], a[l]};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void modify1(int p, int x, int y, int mod)
{
if(tr[p].mx < mod) return;
int l = tr[p].l;
int r = tr[p].r;
if(l == r)
{
tr[p].sum = tr[p].sum % mod;
tr[p].mx = tr[p].mx % mod;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify1(lc, x, y, mod);
if(y > mid) modify1(rc, x, y, mod);
pushup(p);
}
void modify2(int p, int x, int k)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == x && r == x)
{
tr[p].sum = k;
tr[p].mx = k;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify2(lc, x, k);
else modify2(rc, x, k);
pushup(p);
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].sum;
}
int sum = 0;
int mid = (l + r) >> 1;
if(x <= mid) sum += query(lc, x, y);
if(y > mid) sum += query(rc, x, y);
return sum;
}
void solve()
{
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--)
{
int op; cin >> op;
if(op == 1)
{
int l, r; cin >> l >> r;
cout << query(1, l, r) << endl;
}
else if(op == 2)
{
int l, r, x; cin >> l >> r >> x;
modify1(1, l, r, x);
}
else if(op == 3)
{
int k, x; cin >> k >> x;
modify2(1, k, x);
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}
十、权值线段树+离散化
【引入】
- 对于一组数据,查询 \([x,y]\) 之间的数,一共出现了多少次?针对这样的问题,就可以用权值线段树来解决。
【权值线段树】
相较于普通的线段树,权值线段树维护区间内的数出现的次数:
- 结点的区间信息表示:数据的值域;
- 结点的权值信息表示:这些数据一共出现的次数。
比如,有数据 a = [1, 5, 5, 2, 2, 4, 1, 1],对应的权值线段树为:

实际做题中,数据的值域⼀般很大,如果仅仅考虑数的大小而不考虑具体的值,常常会先把原始数据离散化。
逆序对
对于 i 位置,只要统计出 [1, i-1] 区间内有多少个数比 a[i] 大即可。
可以用权值线段树维护这个信息。
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 5e5 + 10;
int n;
int cnt;
int a[N];
int t[N];
int find(int x)
{
return lower_bound(t + 1, t + 1 + cnt, x) - t;
}
struct node
{
int l, r, cnt;
}tr[N << 2];
void pushup(int p)
{
tr[p].cnt = tr[lc].cnt + tr[rc].cnt;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void modify(int p, int x)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == x && r == x)
{
tr[p].cnt++;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x);
else modify(rc, x);
pushup(p);
}
int query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].cnt;
}
int cnt = 0;
int mid = (l + r) >> 1;
if(x <= mid) cnt += query(lc, x, y);
if(y > mid) cnt += query(rc, x, y);
return cnt;
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++)
{
cin >> a[i];
t[i] = a[i];
}
sort(t + 1, t + 1 + n);
cnt = unique(t + 1, t + 1 + n) - t - 1;
int ret = 0;
build(1, 1, cnt);
for(int i = 1; i <= n; i++)
{
int x = find(a[i]); // 找到新的下标
ret += query(1, x + 1, cnt);
modify(1, x);
}
cout << ret << endl;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}
十一、线段树+数学
1.区间方差
用方差的第二个公式:平方的期望-期望的平方,维护sum 和 平方和qsum
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10;
int n, m;
int b[N];
const int mod = 1e9 + 7;
struct node
{
int l, r;
int sum, qsum;
}tr[N << 2];
void pushup(node& p, node& l, node& r)
{
p.sum = (l.sum + r.sum) % mod;
p.qsum = (l.qsum + r.qsum) % mod;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, b[l], b[l] * b[l] % mod};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(tr[p], tr[lc], tr[rc]);
}
void modify(int p, int x, int k)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == r)
{
tr[p].sum = k;
tr[p].qsum = k * k % mod;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, k);
else modify(rc, x, k);
pushup(tr[p], tr[lc], tr[rc]);
}
node query(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p];
}
int mid = (l + r) >> 1;
if(y <= mid) return query(lc, x, y);
if(x > mid) return query(rc, x, y);
node ret, L = query(lc, x, y), R = query(rc, x, y);
pushup(ret, L, R);
return ret;
}
int qpow(int a, int b, int p)
{
a %= p;
int ret = 1;
while(b)
{
if(b & 1) ret = ret * a % p;
a = a * a % p;
b >>= 1;
}
return ret;
}
void solve()
{
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> b[i];
build(1, 1, n);
while(m--)
{
int c, x, y; cin >> c >> x >> y;
if(c == 1)
{
modify(1, x, y);
}
else
{
node t = query(1, x, y);
int sum = t.sum;
int qsum = t.qsum;
int inv = qpow(y - x + 1, mod - 2, mod);
int ret1 = qsum * inv % mod;
int ret2 = sum * inv % mod * sum % mod * inv % mod;
cout << ((ret1 - ret2) % mod + mod) % mod << endl;
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}
2.GCD
gcd是可以维护的,但是修改有的麻烦(时间复杂度太大),所以用GCD的一个性质,结合方差,可以将:区间修改-->单点修改

有这个结论,可以维护原序列差分序列中的最大公约数,此时区间修改就变成两次单点修改。
但是,在求差分序列 [l, r] 的最大公约数时,还需要知道原数列 a_l 的值。可以在差分序列中维护一个区间和,此时原数列 a_l 的值就是差分序列中 [1, l] 区间的和。
注意用差分解决问题时,最大公约数会出现负数的情况。注意取绝对值。
cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define lc p << 1
#define rc p << 1 | 1
const int N = 5e5 + 10;
int n, m;
int a[N];
int f[N];
struct node
{
int l, r;
int gcd;
int sum;
}tr[N << 2];
int gcd(int a, int b)
{
return b == 0 ? a : gcd(b, a % b);
}
void pushup(int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum;
tr[p].gcd = gcd(tr[lc].gcd, tr[rc].gcd);
}
void build(int p, int l, int r)
{
tr[p] = {l, r, a[l], a[l]};
if(l == r) return;
int mid = (l + r) >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(p);
}
void modify(int p, int x, int k)
{
int l = tr[p].l;
int r = tr[p].r;
if(l == r)
{
tr[p].sum += k;
tr[p].gcd += k;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) modify(lc, x, k);
else modify(rc, x, k);
pushup(p);
}
int query1(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].sum;
}
int sum = 0;
int mid = (l + r) >> 1;
if(x <= mid) sum += query1(lc, x, y);
if(y > mid) sum += query1(rc, x, y);
return sum;
}
int query2(int p, int x, int y)
{
int l = tr[p].l;
int r = tr[p].r;
if(l >= x && r <= y)
{
return tr[p].gcd;
}
int g = 0;
int mid = (l + r) >> 1;
if(x <= mid) g = gcd(query2(lc, x, y), g);
if(y > mid) g = gcd(query2(rc, x, y), g);
return g;
}
void solve()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
{
int x; cin >> x;
a[i] += x;
a[i + 1] -= x;
}
build(1, 1, n);
while(m--)
{
char op; cin >> op;
if(op == 'C')
{
int l, r, d; cin >> l >> r >> d;
modify(1, l, d);
if(r + 1 <= n) modify(1, r + 1, -d);
}
else
{
int l, r; cin >> l >> r;
cout << abs(gcd(query1(1, 1, l), query2(1, l + 1, r)))<< endl;
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int _ = 1;
// cin >> _;
while(_--)
{
solve();
}
return 0;
}