线段树
什么是线段树
线段树是一种**[二叉搜索树]**,与[**区间树]**相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树的一个结点
[Segment Tree] is a data structure that stores data about range of elements in nodes as a tree. It is mostly used to handle range queries with updates in an efficient manner.
线段树特性
线段树具有平衡二叉树的特性,节点有范围(可以用来记录范围最值,求和,平均值等)。
-
A segment tree is a binary tree with a leaf node for each element in the array. ---- 平衡二叉树
-
Each internal node represents a segment or a range of elements in the array. --- 节点表示范围
-
The root node represents ++the entire array++ . --- 根节点表示整棵树的范围
-
Each node stores information about the segment it represents, such as the sum or minimum of the elements in the segment. -- -每个节点存储线段的最大值,最小值
编号 | 线段树特性 |
---|---|
1 | 平衡二叉树 |
2 | 节点表示范围 |
3 | 根节点表示整棵树的范围 |
4 | 每个节点存储线段的最大、最小值 |
线段树的实现-节点
创建节点SegmentNode
java
**
* 线段树节点
*/
public class SegmentNode {
public SegmentNode left;
public SegmentNode right;
public int start;
public int end;
public int sum;
public SegmentNode(int start, int end) {
this.start = start;
this.end = end;
}
@Override
public String toString() {
return "SegmentNode{" +
"left=" + left +
", right=" + right +
", start=" + start +
", end=" + end +
", sum=" + sum +
'}';
}
}
构建线段树
java
public SegmentNode buildTree(int start, int end) {
if (start > end) {
return null;
}
SegmentNode newNode = new SegmentNode(start, end);
if (start == end) {
newNode.sum = elements[start];
return newNode;
}
int mid = start + (end - start) / 2;
newNode.left = buildTree(start, mid);
newNode.right = buildTree(mid + 1, end);
return newNode;
}
查询指定线段树数据
java
public long query(SegmentNode root, int queryStart, int queryEnd){
if(queryStart > root.end || queryEnd < root.start){
return 0;
}
if(queryStart <= root.start && queryEnd >= root.end){
return root.sum;
}
long leftSum = query(root.left, queryStart, queryEnd);
long rightSum = query(root.right, queryStart,queryEnd);
return leftSum +rightSum;
}
更新线段树
在线段树中,
pushDown()
方法通常用于实现懒更新 (Lazy Propagation)策略。懒更新策略允许我们在++更新操作时避免不必要的子树遍历,只在真正需要时才将更新值下推到子节点++。在节点中添加lazy属性记录lazy值。如果我们每进行一次加的操作,就将全部线段树更改一边,时间复杂度会很高。因此,我们需要进行一个延迟加和的操作。
[思路]:如果 [left,right] 区间增加 a,在查询时,就可以把 [left,right] 区间标记的增加量推下去就可以直接求值了。
这时候,我们需要记录一个懒标记lazy,来记录这个区间增加量。
pushDown()
创建
pushDown( )
方法,如果子节点不为空,更新子节点的lazy值,即记录子节点的lazy值。
java
private void pushDown(SegmentNode node) {
if (node.left == null || node.right == null) return;
if (node.lazy != 0) {
int mid = node.start + (node.end - node.start)/2;
node.left.sum += node.lazy * (mid - node.start + 1);
node.left.lazy += node.lazy;
node.right.sum += node.lazy * (node.end -mid);
node.right.lazy += node.lazy;
node.lazy = 0;
}
}
更新节点
java
public void update(SegmentNode currNode,
int updateStart,
int updateEnd,
int updateVal) {
//不在范围内,更新失败
if (updateStart > currNode.end || updateEnd < currNode.start) return ;
//完全闭合区域
if (updateStart >= currNode.start && updateEnd <= currNode.end) {
currNode.sum += (currNode.end - currNode.start + 1) * updateVal;
currNode.lazy += updateVal;
return ;
}
pushDown(currNode);
update(currNode.left, updateStart, updateEnd, updateVal);
update(currNode.right, updateStart, updateEnd, updateVal);
currNode.sum = currNode.left.sum + currNode.right.sum;
return ;
}
打印线段树
java
public void printTree(SegmentTreeNode root, String prefix) {
System.out.println(
prefix + "Node [" + root.start + ", " + root.end + "]: value = " + root.sum + ", lazy = " + root.lazy);
if (root.left != null) {
printTree(root.left, prefix + " "); // 增加前缀空格以便形成缩进
}
if (root.right != null) {
printTree(root.right, prefix + " ");
}
}
完整代码实现
java
public class SegmentNode {
public SegmentNode left;
public SegmentNode right;
public int start;
public int end;
public int sum;
public int lazy;
public SegmentNode(int start, int end) {
this.start = start;
this.end = end;
}
@Override
public String toString() {
return "SegmentNode{" +
", start=" + start +
", end=" + end +
", sum=" + sum +
'}';
}
}
java
package com.training.segment;
public class SegmentTree {
private int[] elements;
public SegmentTree(int[] elements) {
this.elements = elements;
}
public SegmentNode buildTree(int start, int end) {
if (start > end) {
return null;
}
SegmentNode newNode = new SegmentNode(start, end);
if (start == end) {
newNode.sum = elements[start];
return newNode;
}
int mid = start + (end - start) / 2;
newNode.left = buildTree(start, mid);
newNode.right = buildTree(mid + 1, end);
newNode.sum = (newNode.left == null ? 0 : newNode.left.sum) + (newNode.right == null ? 0 : newNode.right.sum);
return newNode;
}
public long query(SegmentNode root, int queryStart, int queryEnd) {
if (queryStart > root.end || queryEnd < root.start) {
return 0;
}
if (queryStart <= root.start && queryEnd >= root.end) {
return root.sum;
}
long leftSum = query(root.left, queryStart, queryEnd);
long rightSum = query(root.right, queryStart, queryEnd);
return leftSum + rightSum;
}
private void pushDown(SegmentNode node) {
if (node.left == null || node.right == null) return;
if (node.lazy != 0) {
int mid = node.start + (node.end - node.start) / 2;
node.left.sum += node.lazy * (mid - node.start + 1);
node.left.lazy += node.lazy;
node.right.sum += node.lazy * (node.end - mid);
node.right.lazy += node.lazy;
node.lazy = 0;
}
}
public void update(SegmentNode currNode,
int updateStart,
int updateEnd,
int updateVal) {
//不在范围内,更新失败
if (updateStart > currNode.end || updateEnd < currNode.start) return;
//完全闭合区域
if (updateStart <= currNode.start && updateEnd >= currNode.end) {
currNode.sum += (currNode.end - currNode.start + 1) * updateVal;
currNode.lazy += updateVal;
return;
}
pushDown(currNode);
update(currNode.left, updateStart, updateEnd, updateVal);
update(currNode.right, updateStart, updateEnd, updateVal);
int leftSum = currNode.left == null ? 0 : currNode.left.sum;
int rightSum = currNode.right == null ? 0 : currNode.right.sum;
currNode.sum = leftSum + rightSum;
}
public void printTree(SegmentNode root, String prefix) {
System.out.println(
prefix + "Node [" + root.start + ", " + root.end + "]: value = "
+ root.sum);
if (root.left != null) {
printTree(root.left, prefix + " "); // 增加前缀空格以便形成缩进
}
if (root.right != null) {
printTree(root.right, prefix + " ");
}
}
public static void main(String[] args) {
int[] data = {1, 3, 5, 7, 9, 11};
SegmentTree tree = new SegmentTree(data);
SegmentNode root = tree.buildTree(0, data.length - 1);
tree.printTree(root, "");
tree.update(root, 0, 2, 3);
System.out.println("-------------------------------------------------------");
tree.printTree(root, "");
System.out.println(tree.query(root, 2, 2));
}
}
测试结果
java
Node [0, 5]: value = 36
Node [0, 2]: value = 9
Node [0, 1]: value = 4
Node [0, 0]: value = 1
Node [1, 1]: value = 3
Node [2, 2]: value = 5
Node [3, 5]: value = 27
Node [3, 4]: value = 16
Node [3, 3]: value = 7
Node [4, 4]: value = 9
Node [5, 5]: value = 11
-------------------------------------------------------
Node [0, 5]: value = 45
Node [0, 2]: value = 18
Node [0, 1]: value = 4
Node [0, 0]: value = 1
Node [1, 1]: value = 3
Node [2, 2]: value = 5
Node [3, 5]: value = 27
Node [3, 4]: value = 16
Node [3, 3]: value = 7
Node [4, 4]: value = 9
Node [5, 5]: value = 11
5