线段树是一种支持区间修改和区间查询的数据结构, 详细介绍可以参考算法学习笔记(14): 线段树, 下面记录几种使用过的模板:
- 区间加+查询区间求和
- 区间更新+查询区间极小值
- 区间加+查询区间极小值
注意模板线段树函数中的下标都是从1开始, 用于初始化的数组下标从0开始
区间加+查询区间求和
cpp
/**
 * Segment tree with lazy propagation: range add + range sum query.
 *
 * Positions are 1-based ([1, n]); position i is initialized from
 * init_list[i - 1]. Node k stores its children at 2k and 2k + 1.
 */
class SegmentTree {
public:
    using ll = long long;

    /**
     * Builds the tree over [1, init_list.size()].
     *
     * Takes a const reference so const vectors and temporaries are accepted.
     * An empty list yields a root covering the empty interval [1, 0];
     * add() is then a no-op and query() returns 0.
     */
    SegmentTree(const std::vector<int> &init_list) {
        st = std::vector<SegmentTreeNode>(init_list.size() * 4 + 10);
        build(init_list, 1, static_cast<ll>(init_list.size()));
    }

    /** Hands the pending addition of node `index` to both children and clears it. */
    void push_down(ll index) {
        SegmentTreeNode &parent = st[index];
        apply_add(index << 1, parent.mark);
        apply_add(index << 1 | 1, parent.mark);
        parent.lazy = false;
        parent.mark = 0;
    }

    /** Recomputes the sum of node `index` from its two children. */
    void push_up(ll index) {
        st[index].s = st[index << 1].s + st[index << 1 | 1].s;
    }

    /**
     * Recursively initializes the subtree rooted at `index` to cover [l, r].
     *
     * @param init_list 0-based source values; leaf [i, i] reads init_list[i - 1].
     * @param l         left bound of the node interval (1-based, inclusive).
     * @param r         right bound of the node interval (1-based, inclusive).
     * @param index     node slot inside the backing vector.
     */
    void build(const std::vector<int> &init_list, ll l, ll r, ll index = 1) {
        SegmentTreeNode &node = st[index];
        node.tl = l;
        node.tr = r;
        node.lazy = false;
        node.mark = 0;
        if (l > r) {
            // Empty interval: stop here, otherwise the midpoint split would
            // reproduce the same interval forever.
            node.s = 0;
            return;
        }
        if (l == r) {
            node.s = init_list[l - 1];
            return;
        }
        const ll mid = (l + r) >> 1;
        build(init_list, l, mid, index << 1);
        build(init_list, mid + 1, r, index << 1 | 1);
        push_up(index);
    }

    /** Adds `d` to every position in [l, r]; positions outside the tree are ignored. */
    void add(ll l, ll r, ll d, ll index = 1) {
        const SegmentTreeNode &node = st[index];
        if (disjoint(node, l, r)) {
            return;
        }
        if (l <= node.tl && node.tr <= r) {
            apply_add(index, d);
            return;
        }
        if (node.lazy) {
            push_down(index);
        }
        add(l, r, d, index << 1);
        add(l, r, d, index << 1 | 1);
        push_up(index);
    }

    /**
     * Returns the sum of the positions in [l, r] that lie inside the tree.
     * A range that overlaps nothing (including l > r) yields 0.
     */
    ll query(ll l, ll r, ll index = 1) {
        const SegmentTreeNode &node = st[index];
        if (disjoint(node, l, r)) {
            return 0;
        }
        if (l <= node.tl && node.tr <= r) {
            return node.s;
        }
        if (node.lazy) {
            push_down(index);
        }
        return query(l, r, index << 1) + query(l, r, index << 1 | 1);
    }

private:
    struct SegmentTreeNode {
        ll tl;      // left bound of the covered interval
        ll tr;      // right bound of the covered interval
        ll s;       // sum over [tl, tr]
        ll mark;    // addition not yet applied to the children
        bool lazy;  // true while `mark` still has to be pushed down
    };

    /** True when [l, r] shares no position with `node` (an empty node shares none). */
    static bool disjoint(const SegmentTreeNode &node, ll l, ll r) {
        return node.tl > node.tr || l > node.tr || r < node.tl;
    }

    /** Applies "+d on the whole interval" to one node and records it as pending. */
    void apply_add(ll index, ll d) {
        SegmentTreeNode &node = st[index];
        node.s += (node.tr - node.tl + 1) * d;
        node.mark += d;
        node.lazy = true;
    }

    std::vector<SegmentTreeNode> st;
};
区间更新+查询区间极小值
这里的区间更新等效于执行这样的操作: $li[k] = \min(li[k],\ val),\ left \le k \le right$
cpp
/**
 * Segment tree with lazy propagation: range "chmin" update + range minimum query.
 *
 * modify(l, r, val) performs li[k] = min(li[k], val) for every k in [l, r].
 * Positions are 1-based ([1, n]); position i is initialized from
 * init_list[i - 1]. Node k stores its children at 2k and 2k + 1.
 */
class SegmentTree {
public:
    using ll = long long;

    /**
     * Builds the tree over [1, init_list.size()].
     *
     * Takes a const reference so const vectors and temporaries are accepted.
     * An empty list yields a root covering the empty interval [1, 0];
     * modify() is then a no-op and query() returns INT64_MAX.
     */
    SegmentTree(const std::vector<int> &init_list) {
        st = std::vector<SegmentTreeNode>(init_list.size() * 4 + 10);
        build(init_list, 1, static_cast<ll>(init_list.size()));
    }

    /** Hands the pending upper bound of node `index` to both children and clears it. */
    void push_down(ll index) {
        SegmentTreeNode &parent = st[index];
        apply_chmin(index << 1, parent.mark);
        apply_chmin(index << 1 | 1, parent.mark);
        parent.lazy = false;
        // Re-applying an old bound would be harmless (min is idempotent), but
        // resetting keeps "mark == INT64_MAX means nothing pending" true.
        parent.mark = INT64_MAX;
    }

    /** Recomputes the minimum of node `index` from its two children. */
    void push_up(ll index) {
        st[index].s = std::min(st[index << 1].s, st[index << 1 | 1].s);
    }

    /**
     * Recursively initializes the subtree rooted at `index` to cover [l, r].
     *
     * @param init_list 0-based source values; leaf [i, i] reads init_list[i - 1].
     * @param l         left bound of the node interval (1-based, inclusive).
     * @param r         right bound of the node interval (1-based, inclusive).
     * @param index     node slot inside the backing vector.
     */
    void build(const std::vector<int> &init_list, ll l, ll r, ll index = 1) {
        SegmentTreeNode &node = st[index];
        node.tl = l;
        node.tr = r;
        node.lazy = false;
        node.mark = INT64_MAX;  // neutral element of min: no bound pending
        if (l > r) {
            // Empty interval: stop here, otherwise the midpoint split would
            // reproduce the same interval forever.
            node.s = INT64_MAX;
            return;
        }
        if (l == r) {
            node.s = init_list[l - 1];
            return;
        }
        const ll mid = (l + r) >> 1;
        build(init_list, l, mid, index << 1);
        build(init_list, mid + 1, r, index << 1 | 1);
        push_up(index);
    }

    /** Caps every position in [l, r] at `val`; positions outside the tree are ignored. */
    void modify(ll l, ll r, ll val, ll index = 1) {
        const SegmentTreeNode &node = st[index];
        if (disjoint(node, l, r)) {
            return;
        }
        if (l <= node.tl && node.tr <= r) {
            apply_chmin(index, val);
            return;
        }
        if (node.lazy) {
            push_down(index);
        }
        modify(l, r, val, index << 1);
        modify(l, r, val, index << 1 | 1);
        push_up(index);
    }

    /**
     * Returns the minimum over the positions in [l, r] that lie inside the tree.
     * A range that overlaps nothing (including l > r) yields INT64_MAX.
     */
    ll query(ll l, ll r, ll index = 1) {
        const SegmentTreeNode &node = st[index];
        if (disjoint(node, l, r)) {
            return INT64_MAX;
        }
        if (l <= node.tl && node.tr <= r) {
            return node.s;
        }
        if (node.lazy) {
            push_down(index);
        }
        const ll left_min = query(l, r, index << 1);
        const ll right_min = query(l, r, index << 1 | 1);
        return std::min(left_min, right_min);
    }

private:
    struct SegmentTreeNode {
        ll tl;      // left bound of the covered interval
        ll tr;      // right bound of the covered interval
        ll s;       // minimum over [tl, tr]
        ll mark;    // upper bound not yet applied to the children
        bool lazy;  // true while `mark` still has to be pushed down
    };

    /** True when [l, r] shares no position with `node` (an empty node shares none). */
    static bool disjoint(const SegmentTreeNode &node, ll l, ll r) {
        return node.tl > node.tr || l > node.tr || r < node.tl;
    }

    /** Applies "cap the whole interval at val" to one node and records it as pending. */
    void apply_chmin(ll index, ll val) {
        SegmentTreeNode &node = st[index];
        node.s = std::min(node.s, val);
        node.mark = std::min(node.mark, val);
        node.lazy = true;
    }

    std::vector<SegmentTreeNode> st;
};
区间加+查询区间极小值
cpp
/**
 * Segment tree with lazy propagation: range add + range minimum query.
 *
 * Positions are 1-based ([1, n]); position i is initialized from
 * init_list[i - 1]. Node k stores its children at 2k and 2k + 1.
 */
class SegmentTree {
public:
    using ll = long long;

    /**
     * Builds the tree over [1, init_list.size()].
     *
     * Takes a const reference so const vectors and temporaries are accepted.
     * An empty list yields a root covering the empty interval [1, 0];
     * add() is then a no-op and query() returns INT64_MAX.
     */
    SegmentTree(const std::vector<int> &init_list) {
        st = std::vector<SegmentTreeNode>(init_list.size() * 4 + 10);
        build(init_list, 1, static_cast<ll>(init_list.size()));
    }

    /** Hands the pending addition of node `index` to both children and clears it. */
    void push_down(ll index) {
        SegmentTreeNode &parent = st[index];
        apply_add(index << 1, parent.mark);
        apply_add(index << 1 | 1, parent.mark);
        parent.lazy = false;
        parent.mark = 0;
    }

    /** Recomputes the minimum of node `index` from its two children. */
    void push_up(ll index) {
        st[index].s = std::min(st[index << 1].s, st[index << 1 | 1].s);
    }

    /**
     * Recursively initializes the subtree rooted at `index` to cover [l, r].
     *
     * @param init_list 0-based source values; leaf [i, i] reads init_list[i - 1].
     * @param l         left bound of the node interval (1-based, inclusive).
     * @param r         right bound of the node interval (1-based, inclusive).
     * @param index     node slot inside the backing vector.
     */
    void build(const std::vector<int> &init_list, ll l, ll r, ll index = 1) {
        SegmentTreeNode &node = st[index];
        node.tl = l;
        node.tr = r;
        node.lazy = false;
        // The tag is additive, so its neutral value is 0. Starting from
        // INT64_MAX would overflow on the first add() and corrupt the
        // children's minima when the tag is pushed down.
        node.mark = 0;
        if (l > r) {
            // Empty interval: stop here, otherwise the midpoint split would
            // reproduce the same interval forever.
            node.s = INT64_MAX;
            return;
        }
        if (l == r) {
            node.s = init_list[l - 1];
            return;
        }
        const ll mid = (l + r) >> 1;
        build(init_list, l, mid, index << 1);
        build(init_list, mid + 1, r, index << 1 | 1);
        push_up(index);
    }

    /** Adds `d` to every position in [l, r]; positions outside the tree are ignored. */
    void add(ll l, ll r, ll d, ll index = 1) {
        const SegmentTreeNode &node = st[index];
        if (disjoint(node, l, r)) {
            return;
        }
        if (l <= node.tl && node.tr <= r) {
            apply_add(index, d);
            return;
        }
        if (node.lazy) {
            push_down(index);
        }
        add(l, r, d, index << 1);
        add(l, r, d, index << 1 | 1);
        push_up(index);
    }

    /**
     * Returns the minimum over the positions in [l, r] that lie inside the tree.
     * A range that overlaps nothing (including l > r) yields INT64_MAX.
     */
    ll query(ll l, ll r, ll index = 1) {
        const SegmentTreeNode &node = st[index];
        if (disjoint(node, l, r)) {
            return INT64_MAX;
        }
        if (l <= node.tl && node.tr <= r) {
            return node.s;
        }
        if (node.lazy) {
            push_down(index);
        }
        const ll left_min = query(l, r, index << 1);
        const ll right_min = query(l, r, index << 1 | 1);
        return std::min(left_min, right_min);
    }

private:
    struct SegmentTreeNode {
        ll tl;      // left bound of the covered interval
        ll tr;      // right bound of the covered interval
        ll s;       // minimum over [tl, tr]
        ll mark;    // addition not yet applied to the children
        bool lazy;  // true while `mark` still has to be pushed down
    };

    /** True when [l, r] shares no position with `node` (an empty node shares none). */
    static bool disjoint(const SegmentTreeNode &node, ll l, ll r) {
        return node.tl > node.tr || l > node.tr || r < node.tl;
    }

    /** Applies "+d on the whole interval" to one node and records it as pending. */
    void apply_add(ll index, ll d) {
        SegmentTreeNode &node = st[index];
        node.s += d;  // adding d to every element shifts the minimum by d
        node.mark += d;
        node.lazy = true;
    }

    std::vector<SegmentTreeNode> st;
};