二叉搜索树
- 概念
- [二叉搜索树中第 K 小的元素](#二叉搜索树中第 K 小的元素)
- 把二叉搜索树转换为累加树
- 删除二叉搜索树中的节点
- 二叉搜索树中的插入操作
- 验证二叉搜索树
- 不同的二叉树
- [不同的搜索二叉树 II](#不同的搜索二叉树 II)
概念
TreeMap的类代码
可以看到有类似于map的 KV键值对
cpp
// 大写 K 为键的类型,大写 V 为值的类型
template <typename K, typename V>
class TreeNode {
public:
K key;
V value;
TreeNode<K, V>* left;
TreeNode<K, V>* right;
TreeNode(K key, V value) : key(key), value(value), left(nullptr), right(nullptr) {}
};
TreeMap有如下接口
cpp
#include <iostream>
#include <vector>
#include <stdexcept>
// 增强 TreeNode,增加 size 字段
template <typename K, typename V>
class TreeNode {
public:
K key;
V value;
TreeNode<K, V>* left;
TreeNode<K, V>* right;
int size; // 记录以该节点为根的子树节点数
TreeNode(K key, V value, int size)
: key(key), value(value), left(nullptr), right(nullptr), size(size) {}
};
template <typename K, typename V>
class MyTreeMap {
private:
TreeNode<K, V>* root;
// --- 辅助函数:获取节点 size,处理 nullptr 情况 ---
int size(TreeNode<K, V>* node) {
if (node == nullptr) return 0;
return node->size;
}
// --- 辅助函数:递归 Put ---
TreeNode<K, V>* put(TreeNode<K, V>* node, K key, V value) {
if (node == nullptr) {
return new TreeNode<K, V>(key, value, 1);
}
if (key < node->key) {
node->left = put(node->left, key, value);
} else if (key > node->key) {
node->right = put(node->right, key, value);
} else {
node->value = value; // 更新值
}
// 更新 size
node->size = 1 + size(node->left) + size(node->right);
return node;
}
// --- 辅助函数:递归 Get ---
TreeNode<K, V>* get(TreeNode<K, V>* node, K key) {
if (node == nullptr) return nullptr;
if (key < node->key) return get(node->left, key);
if (key > node->key) return get(node->right, key);
return node;
}
// --- 辅助函数:查找最小节点 ---
TreeNode<K, V>* min(TreeNode<K, V>* node) {
if (node->left == nullptr) return node;
return min(node->left);
}
// --- 辅助函数:查找最大节点 ---
TreeNode<K, V>* max(TreeNode<K, V>* node) {
if (node->right == nullptr) return node;
return max(node->right);
}
// --- 辅助函数:删除最小节点 (用于辅助 remove) ---
TreeNode<K, V>* deleteMin(TreeNode<K, V>* node) {
if (node->left == nullptr) {
TreeNode<K, V>* rightNode = node->right;
delete node;
return rightNode;
}
node->left = deleteMin(node->left);
node->size = 1 + size(node->left) + size(node->right);
return node;
}
// --- 辅助函数:递归 Remove ---
TreeNode<K, V>* remove(TreeNode<K, V>* node, K key) {
if (node == nullptr) return nullptr;
if (key < node->key) {
node->left = remove(node->left, key);
} else if (key > node->key) {
node->right = remove(node->right, key);
} else {
// 找到目标节点
if (node->right == nullptr) {
TreeNode<K, V>* leftNode = node->left;
delete node;
return leftNode;
}
if (node->left == nullptr) {
TreeNode<K, V>* rightNode = node->right;
delete node;
return rightNode;
}
// 左右子节点都存在:用右子树的最小节点替代当前节点
TreeNode<K, V>* t = node;
node = min(t->right); // 这里仅仅是浅拷贝指针,实际应用中通常新建节点或拷贝值
// 注意:为了内存安全,这里通常做法是拷贝 key/value,然后 deleteMin
// 为了简化演示,我们假设构建新节点逻辑:
TreeNode<K, V>* successor = new TreeNode<K, V>(node->key, node->value, 1);
successor->right = deleteMin(t->right);
successor->left = t->left;
delete t; // 删除旧节点
node = successor;
}
node->size = 1 + size(node->left) + size(node->right);
return node;
}
// --- 辅助函数:Keys (中序遍历) ---
void traverse(TreeNode<K, V>* node, std::vector<K>& keys) {
if (node == nullptr) return;
traverse(node->left, keys);
keys.push_back(node->key);
traverse(node->right, keys);
}
// --- 辅助函数:Floor ---
TreeNode<K, V>* floor(TreeNode<K, V>* node, K key) {
if (node == nullptr) return nullptr;
if (key == node->key) return node;
if (key < node->key) return floor(node->left, key);
// key > node->key,可能在右子树,也可能就是当前节点
TreeNode<K, V>* t = floor(node->right, key);
if (t != nullptr) return t;
return node;
}
// --- 辅助函数:Ceiling ---
TreeNode<K, V>* ceiling(TreeNode<K, V>* node, K key) {
if (node == nullptr) return nullptr;
if (key == node->key) return node;
if (key > node->key) return ceiling(node->right, key);
TreeNode<K, V>* t = ceiling(node->left, key);
if (t != nullptr) return t;
return node;
}
// --- 辅助函数:Select ---
TreeNode<K, V>* select(TreeNode<K, V>* node, int k) {
if (node == nullptr) return nullptr;
int t = size(node->left); // 左子树的节点数就是当前节点的排名(从0开始)
if (t > k) return select(node->left, k);
else if (t < k) return select(node->right, k - t - 1);
else return node;
}
// --- 辅助函数:Rank ---
int rank(TreeNode<K, V>* node, K key) {
if (node == nullptr) return 0;
if (key < node->key) return rank(node->left, key);
else if (key > node->key) return 1 + size(node->left) + rank(node->right, key);
else return size(node->left);
}
// --- 辅助函数:Range ---
void range(TreeNode<K, V>* node, std::vector<K>& list, K low, K high) {
if (node == nullptr) return;
if (low < node->key) range(node->left, list, low, high);
if (low <= node->key && high >= node->key) list.push_back(node->key);
if (high > node->key) range(node->right, list, low, high);
}
public:
MyTreeMap() : root(nullptr) {}
// 1. 增/改
void put(K key, V value) {
root = put(root, key, value);
}
// 2. 查
V get(K key) {
TreeNode<K, V>* node = get(root, key);
if (node == nullptr) throw std::runtime_error("Key not found");
return node->value;
}
// 3. 删
void remove(K key) {
root = remove(root, key);
}
// 4. 是否包含
bool containsKey(K key) {
return get(root, key) != nullptr;
}
// 5. 所有键
std::vector<K> keys() {
std::vector<K> list;
traverse(root, list);
return list;
}
// 6. 最小键
K firstKey() {
if (root == nullptr) throw std::runtime_error("Empty map");
return min(root)->key;
}
// 7. 最大键
K lastKey() {
if (root == nullptr) throw std::runtime_error("Empty map");
return max(root)->key;
}
// 8. Floor (<= key 的最大键)
K floorKey(K key) {
TreeNode<K, V>* node = floor(root, key);
if (node == nullptr) throw std::runtime_error("No floor key");
return node->key;
}
// 9. Ceiling (>= key 的最小键)
K ceilingKey(K key) {
TreeNode<K, V>* node = ceiling(root, key);
if (node == nullptr) throw std::runtime_error("No ceiling key");
return node->key;
}
// 10. Select (排名为 k 的键, 0-indexed)
K selectKey(int k) {
if (k < 0 || k >= size(root)) throw std::runtime_error("Index out of bounds");
return select(root, k)->key;
}
// 11. Rank (键 key 的排名)
int rank(K key) {
return rank(root, key);
}
// 12. Range (区间查找)
std::vector<K> rangeKeys(K low, K high) {
std::vector<K> list;
range(root, list, low, high);
return list;
}
// 获取总大小
int size() {
return size(root);
}
};
二叉搜索树中第 K 小的元素
方法一:迭代法
思路是初始化一个栈,利用先进后出的特性将二叉树从根开始一路想左压入栈中
然后依次pop 顶部元素(最小的,次小的,pop k个)找到第k小的元素
cpp
class Solution {
public:
int kthSmallest(TreeNode* root, int k) {
stack<TreeNode*> st;
TreeNode* cur = root;
while(cur != nullptr || !st.empty()){
while(cur != nullptr){
st.push(cur);
cur = cur->left;
}
cur = st.top();
st.pop();
k--;
if(k == 0) return cur->val;
cur = cur->right;
}
return -1;
}
};
方法2:递归法
cpp
class Solution {
public:
int res = 0;
int count = 0;
void traverse(TreeNode* node,int k){
if(node == nullptr) return;
traverse(node->left,k);
if (count == k) return;
count++;
if(count == k) {
res = node->val;
return;
}
traverse(node->right,k);
}
int kthSmallest(TreeNode* root, int k) {
count = 0;
traverse(root,k);
return res;
}
};
思考进阶如果二叉搜索树经常被修改(插入/删除)怎么办
我们必须对数据结构进行增强(Augmenting)。 具体来说,就是在每个 TreeNode 中多维护一个字段 size。
size 的定义:以当前节点为根的子树中节点的总个数。
叶子节点的 size = 1。
空节点的 size = 0。
非叶节点 size = left.size + right.size + 1。
有了这个字段,我们就可以像二分查找一样,迅速定位第 k k k 小的元素,而不需要一个个去数。
cpp
struct TreeNode {
int val;
int size; // <--- 新增字段:子树节点数
TreeNode *left;
TreeNode *right;
TreeNode(int x) : val(x), size(1), left(nullptr), right(nullptr) {}
};
查询算法:Select 操作 ( O ( log N ) O(\log N) O(logN))
假设我们要找第 k k k 小。当前节点是 root。
左子树有 leftSize 个人:这意味着当前 root 节点自己在整棵树里排第 leftSize + 1 名。令 leftSize 为左子树的节点数。
判断逻辑:如果 k = = l e f t S i z e + 1 k == leftSize + 1 k==leftSize+1:恭喜,当前节点就是第 k k k 小!直接返回。如果 k ≤ l e f t S i z e k \le leftSize k≤leftSize:说明目标在左子树里,递归去左边找第 k k k 小。如果 k > l e f t S i z e + 1 k > leftSize + 1 k>leftSize+1:说明目标在右子树里。关键点:在右子树里,它不再是第 k k k 小了,而是第 k − ( l e f t S i z e + 1 ) k - (leftSize + 1) k−(leftSize+1) 小。
cpp
class OrderStatisticTree {
// 辅助函数:安全获取 size,处理空指针
int getSize(TreeNode* node) {
return node ? node->size : 0;
}
public:
// 查找第 k 小元素,复杂度 O(log N)
int kthSmallest(TreeNode* root, int k) {
TreeNode* curr = root;
while (curr != nullptr) {
int leftSize = getSize(curr->left);
if (leftSize + 1 == k) {
return curr->val; // 找到了
}
else if (k <= leftSize) {
curr = curr->left; // 在左边,k 不变
}
else {
// 在右边,k 要减去左边的人数和根节点自己
k = k - (leftSize + 1);
curr = curr->right;
}
}
return -1; // 理论上不应到达这里
}
};
cpp
// 插入新节点,同时维护 size
TreeNode* insert(TreeNode* node, int val) {
if (node == nullptr) {
return new TreeNode(val); // 新节点 size 默认为 1
}
if (val < node->val) {
node->left = insert(node->left, val);
} else if (val > node->val) {
node->right = insert(node->right, val);
}
// --- 核心:回溯时更新当前节点的 size ---
node->size = 1 + getSize(node->left) + getSize(node->right);
return node;
}
平衡性 (AVL / 红黑树)
如果真的在工程中使用,还有一个隐患:普通的 BST 可能会退化成链表。
如果退化成链表,即使维护了 size,树高变成了 N N N,查询复杂度还是会退化成 O ( N ) O(N) O(N)。
因此,真正的工业级实现(如 Java 的 TreeList 或某些数据库索引)会结合 平衡二叉树 (AVL / Red-Black Tree) 和 Size 增强。
把二叉搜索树转换为累加树
cpp
class Solution {
public:
int sum = 0;
void traverse(TreeNode* node){
if(node == nullptr) return;
traverse(node->right);
sum += node->val;
node->val = sum;
traverse(node->left);
}
TreeNode* convertBST(TreeNode* root) {
sum = 0;
traverse(root);
return root;
}
};
删除二叉搜索树中的节点
第一步先找到,要删除的节点,和标准的BST查找一样
如果 key < root->val:去左子树删 -> root->left = deleteNode(root->left, key)
如果 key > root->val:去右子树删 -> root->right = deleteNode(root->right, key)
如果 key == root->val:找到了! 进入第二步。
找到之后开始删除
情况1:
如果是叶子节点:不影响整体结构直接删除
情况2:
如果有一个叶子结点,那么删除后这个叶子结点直接替换删除节点的位置
情况3:
如果有两个叶子结点,那么删除后一般是用右子树里面最小的那个节点(右子树递归左子树(这个值与删除的这个值最接近))
把这个值复制给删除的那个位置,然后像情况1或者情况2(有一个右叶子/右子树)那样删除叶子节点即可。
cpp
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
TreeNode* getMin(TreeNode* node) {
while (node->left != nullptr) {
node = node->left;
}
return node;
}
TreeNode* deleteNode(TreeNode* root, int key) {
if(root == nullptr) return nullptr;
if(key<root->val){
root->left = deleteNode(root->left,key);
return root;
}
if(key>root->val){
root->right = deleteNode(root->right,key);
return root;
}
if(root->left == nullptr){
TreeNode* rightnode = root->right;
delete root;
return rightnode;
}
if(root->right == nullptr){
TreeNode* leftnode = root->left;
delete root;
return leftnode;
}
TreeNode* minNode = getMin(root->right);
root->val = minNode->val;
root->right = deleteNode(root->right, minNode->val);
return root;
}
};
二叉搜索树中的插入操作
cpp
class Solution {
public:
TreeNode* insertIntoBST(TreeNode* root, int val) {
if(root == nullptr) return new TreeNode(val);
if(val<root->val){
root->left = insertIntoBST(root->left,val);
}else if(val > root->val){
root->right = insertIntoBST(root->right,val);
}
return root;
}
};
验证二叉搜索树
思路就是用一个临时变量代表上一个前一个访问节点的值
然后利用中序遍历
一旦发现当前前一个访问节点的值大于当前节点的值就返回false
cpp
class Solution {
private:
TreeNode* pre = nullptr;
public:
bool isValidBST(TreeNode* root) {
if(root == nullptr) return true;
bool leftisok = isValidBST(root->left);
if(!leftisok) return false;
if(pre != nullptr && root->val <= pre->val) return false;
pre = root;
bool rightisok = isValidBST(root->right);
return rightisok;
}
};
不同的二叉树
思路是用动态规划dp
加上从1-n各节点
加入j为n
那么左子树有j-1个节点,右侧有n-j各节点
由于二叉搜索树的特性,有几种二叉树之和节点个数有关和具体的数值无关,比如123和234的二叉搜索树是一样的
所以以j为根的二叉搜索树,个数为, (左子树 j − 1 j-1 j−1 个节点的种类数) × \times × (右子树 n − j n-j n−j 个节点的种类数)
状态转移方程:
定义 dp[i] 表示由 i 个节点组成的不同 BST 的数量。对于 dp[n],我们需要遍历 1 1 1 到 n n n 的每一个数字作为根节点的情况,并将它们累加:
d p [ n ] = ∑ j = 1 n ( d p [ j − 1 ] × d p [ n − j ] ) dp[n] = \sum_{j=1}^{n} (dp[j-1] \times dp[n-j]) dp[n]=j=1∑n(dp[j−1]×dp[n−j])
cpp
class Solution {
public:
int numTrees(int n) {
vector<int> dp(n+1,0);
dp[0] = 1;
dp[1] = 1;
for(int i = 2;i<=n;i++){
for(int j = 1;j<=i;j++){
dp[i] += dp[j-1] * dp[i-j];
}
}
return dp[n];
}
};
不同的搜索二叉树 II
思路:
首先明确搜索二叉树,遍历顺序左中右
如果选定i,左孩子从[1 ... i-1] 里面选,右孩子从[i+1 ... n] 里面选。
然后左右两边是独立的区间,而且左边区间 [1, i-1] 也是要构造 BST,右边 [i+1, n] 也是要构造 BST。
就用到了分治算法,把大问题分解为小问题
如果我的根是i,我要先构造左子树和右子树
然后我的左子树还有左子树和右子树,-》》递归
现在每一个左子树右子树都找好了,怎么处理构造
把左子树的一个和右子树的一个拿出来组成到一起
乘积左边的可能×右边的可能-》双层for循环'
cpp
class Solution {
public:
vector<TreeNode*> build(int start, int end){
vector<TreeNode*> allTrees;
if(start > end){
allTrees.push_back(nullptr);
return allTrees;
}
for(int i = start;i<=end;i++){
vector<TreeNode*> leftsubTrees = build(start,i-1);
vector<TreeNode*> rightsubTrees = build(i+1,end);
for(TreeNode* left : leftsubTrees){
for(TreeNode* right :rightsubTrees){
TreeNode* root = new TreeNode(i);
root->left = left;
root->right = right;
allTrees.push_back(root);
}
}
}
return allTrees;
}
vector<TreeNode*> generateTrees(int n) {
if(n == 0) return {};
return build(1,n);
}
};