前言:为什么需要线段树?
假设你正在开发一个系统,需要频繁回答以下问题:
"数组
A[0...n-1]中,区间[L, R]的和/最大值/最小值 是多少?"并且,数组中的值还会动态更新 (如
A[i] = x)。
如果每次查询都遍历区间,时间复杂度为 O(n) ;若有 q 次查询,总复杂度 O(qn) ------ 在 n, q 达到 10^5 时,程序将严重超时。
线段树(Segment Tree) 就是为解决这类"动态区间查询"问题而生的高效数据结构。它能在:
- O(log n) 时间内完成单点/区间更新
- O(log n) 时间内完成区间查询(如求和、最值、GCD 等)
什么是线段树?
线段树是一种二叉树形数据结构 ,每个节点代表一个区间(线段) ,并通过分治思想将大区间分解为小区间,实现高效查询与更新。
核心思想:分治 + 预计算
- 根节点 :代表整个区间
[1, n] - 内部节点:代表其左右子区间的合并结果(如和、最大值)
- 叶子节点 :代表单个元素
[i, i]
通过预计算并存储 每个区间的聚合值,我们可以在查询时合并若干预计算结果,避免重复遍历。
模板
说明
Segment_tree<200000,int,false> tree; //(l~r范围,类型,是否需要gcd(不写默认false))
建树:Build(1,l,r,vector_a) //vector_a为传vector(类型必须与声明的类型一致),不写则默认所有位置为0
区间加k: Add(1,l,r,k)
区间赋值为k:Set(1,l,r,k)
查区间最大/小值:Ask_mx(1,l,r) Ask_mn(1,l,r);
查区间gcd: Ask_gcd(1,l,r);
查区间和:Ask_sum(1,l,r)
cpp
template<int N,typename T,bool EnableGCD=false> struct Segment_tree {
//建树 Segment_tree<200000(l~r),int(类型),true(是否需要gcd功能)> Tree;
#define ls (p<<1)
#define rs ((p<<1)|1)
constexpr static T inf = numeric_limits<T>::max();
struct Info{
int l, r;
T mx, mn, sum;
T lz_add, lz_set;
bool has_set;
typename conditional<EnableGCD, T, bool>::type gcd;
}tr[4 * N];
void push_up(int p) {
tr[p].mx = max(tr[ls].mx, tr[rs].mx);
tr[p].mn = min(tr[ls].mn, tr[rs].mn);
tr[p].sum = tr[ls].sum + tr[rs].sum;
if constexpr (EnableGCD) tr[p].gcd = gcd(tr[ls].gcd, tr[rs].gcd);
}
void push_down(int p) {
if (tr[p].has_set) {
T tag = tr[p].lz_set;
tr[ls].lz_set = tr[rs].lz_set = tag;
tr[ls].mx=tr[rs].mx = tag;
tr[ls].mn=tr[rs].mn = tag;
tr[ls].sum=tag * (tr[ls].r - tr[ls].l + 1);
tr[rs].sum=tag * (tr[rs].r - tr[rs].l + 1);
if constexpr (EnableGCD) tr[ls].gcd = tr[rs].gcd =(tag>=0?tag:-tag);
tr[ls].lz_add=tr[rs].lz_add = T(0);
tr[ls].has_set=tr[rs].has_set = true;
tr[p].lz_set=T(0);
tr[p].has_set=false;
}
if (tr[p].lz_add != T(0)) {
T tag = tr[p].lz_add;
tr[ls].lz_add += tag;
tr[ls].mx += tag;
tr[ls].mn += tag;
tr[ls].sum += tag * (tr[ls].r - tr[ls].l + 1);
tr[rs].lz_add += tag;
tr[rs].mx += tag;
tr[rs].mn += tag;
tr[rs].sum += tag * (tr[rs].r - tr[rs].l + 1);
tr[p].lz_add = T(0);
}
}
void Build(int p, int lo, int ro, const vector<T>& init = {}) { //建树
tr[p].l = lo;
tr[p].r = ro;
tr[p].lz_add = T(0);
tr[p].lz_set = T(0);
tr[p].has_set = false;
if (lo == ro) {
T val = init.empty() ? T(0) : init[lo];
tr[p].mx = tr[p].mn = tr[p].sum = val;
if constexpr (EnableGCD) tr[p].gcd =(val>=0?val:-val); // GCD总是非负的
return;
}
int mid = (lo + ro) >> 1;
Build(ls, lo, mid, init);
Build(rs, mid + 1, ro, init);
push_up(p);
}
T Ask_mx(int p, int lo, int ro) { //查区间最大
if (lo <= tr[p].l && ro >= tr[p].r) return tr[p].mx;
push_down(p);
int mid = (tr[p].l + tr[p].r) >> 1;
T res = -inf;
if (lo <= mid) res = max(res, Ask_mx(ls, lo, ro));
if (ro > mid) res = max(res, Ask_mx(rs, lo, ro));
return res;
}
T Ask_mn(int p, int lo, int ro) { //查区间最小
if (lo <= tr[p].l && ro >= tr[p].r) return tr[p].mn;
push_down(p);
int mid = (tr[p].l + tr[p].r) >> 1;
T res = inf;
if (lo <= mid) res = min(res, Ask_mn(ls, lo, ro));
if (ro > mid) res = min(res, Ask_mn(rs, lo, ro));
return res;
}
T Ask_sum(int p, int lo, int ro) { //查区间和
if (lo <= tr[p].l && ro >= tr[p].r) return tr[p].sum;
push_down(p);
int mid = (tr[p].l + tr[p].r) >> 1;
T res = T(0);
if (lo <= mid) res += Ask_sum(ls, lo, ro);
if (ro > mid) res += Ask_sum(rs, lo, ro);
return res;
}
template <bool E = EnableGCD>
enable_if_t<E, T> Ask_gcd(int p, int lo, int ro) { //查区间gcd
if (lo <= tr[p].l && ro >= tr[p].r) return tr[p].gcd;
push_down(p);
int mid = (tr[p].l + tr[p].r) >> 1;
T res = T(0);
if (lo <= mid) res = gcd(res, Ask_gcd(ls, lo, ro));
if (ro > mid) res = gcd(res, Ask_gcd(rs, lo, ro));
return res;
}
void Add(int p, int lo, int ro, T k) { //区间加
if (lo <= tr[p].l && ro >= tr[p].r) {
tr[p].lz_add += k;
tr[p].mx += k;
tr[p].mn += k;
tr[p].sum += k * (tr[p].r - tr[p].l + 1);
return;
}
push_down(p);
int mid = (tr[p].l + tr[p].r) >> 1;
if (lo <= mid) Add(ls, lo, ro, k);
if (ro > mid) Add(rs, lo, ro, k);
push_up(p);
}
void Set(int p, int lo, int ro, T k) { //区间赋值
if (lo > ro) return;
if (lo <= tr[p].l && ro >= tr[p].r) {
tr[p].lz_set = k;
tr[p].mx = k;
tr[p].mn = k;
tr[p].sum = k * (tr[p].r - tr[p].l + 1);
if constexpr (EnableGCD) tr[p].gcd = (k>=0?k:-k);
tr[p].lz_add = T(0);
tr[p].has_set = true;
return;
}
push_down(p);
int mid = (tr[p].l + tr[p].r) >> 1;
if (lo <= mid) Set(ls, lo, ro, k);
if (ro > mid) Set(rs, lo, ro, k);
push_up(p);
}
//建树 Segment_tree<200000(l~r),int(类型),true(是否需要gcd功能,默认false)> Tree;
};
/*************************************************************************/
/*
Segment_tree<200000,int,false> tree; //(l~r范围,类型,是否需要gcd(不写默认false))
建树:Build(1,l,r,vector_a) //vector_a为传vector(类型必须与声明的类型一致),不写则默认所有位置为0
区间加k: Add(1,l,r,k)
区间赋值为k:Set(1,l,r,k)
查区间最大/小值:Ask_mx(1,l,r) Ask_mn(1,l,r);
查区间gcd: Ask_gcd(1,l,r);
查区间和:Ask_sum(1,l,r)
*/
/************************************************************************/