算法--二叉搜索树

二叉搜索树

概念

TreeMap的类代码

可以看到有类似于map的 KV键值对

cpp 复制代码
// 大写 K 为键的类型,大写 V 为值的类型
template <typename K, typename V>
class TreeNode {
public:
    K key;
    V value;
    TreeNode<K, V>* left;
    TreeNode<K, V>* right;
    TreeNode(K key, V value) : key(key), value(value), left(nullptr), right(nullptr) {}
};

TreeMap有如下接口

cpp 复制代码
#include <iostream>
#include <vector>
#include <stdexcept>

// 增强 TreeNode,增加 size 字段
template <typename K, typename V>
class TreeNode {
public:
    K key;
    V value;
    TreeNode<K, V>* left;
    TreeNode<K, V>* right;
    int size; // 记录以该节点为根的子树节点数

    TreeNode(K key, V value, int size) 
        : key(key), value(value), left(nullptr), right(nullptr), size(size) {}
};

template <typename K, typename V>
class MyTreeMap {
private:
    TreeNode<K, V>* root;

    // --- 辅助函数:获取节点 size,处理 nullptr 情况 ---
    int size(TreeNode<K, V>* node) {
        if (node == nullptr) return 0;
        return node->size;
    }

    // --- 辅助函数:递归 Put ---
    TreeNode<K, V>* put(TreeNode<K, V>* node, K key, V value) {
        if (node == nullptr) {
            return new TreeNode<K, V>(key, value, 1);
        }

        if (key < node->key) {
            node->left = put(node->left, key, value);
        } else if (key > node->key) {
            node->right = put(node->right, key, value);
        } else {
            node->value = value; // 更新值
        }

        // 更新 size
        node->size = 1 + size(node->left) + size(node->right);
        return node;
    }

    // --- 辅助函数:递归 Get ---
    TreeNode<K, V>* get(TreeNode<K, V>* node, K key) {
        if (node == nullptr) return nullptr;
        if (key < node->key) return get(node->left, key);
        if (key > node->key) return get(node->right, key);
        return node;
    }

    // --- 辅助函数:查找最小节点 ---
    TreeNode<K, V>* min(TreeNode<K, V>* node) {
        if (node->left == nullptr) return node;
        return min(node->left);
    }

    // --- 辅助函数:查找最大节点 ---
    TreeNode<K, V>* max(TreeNode<K, V>* node) {
        if (node->right == nullptr) return node;
        return max(node->right);
    }

    // --- 辅助函数:删除最小节点 (用于辅助 remove) ---
    TreeNode<K, V>* deleteMin(TreeNode<K, V>* node) {
        if (node->left == nullptr) {
            TreeNode<K, V>* rightNode = node->right;
            delete node;
            return rightNode;
        }
        node->left = deleteMin(node->left);
        node->size = 1 + size(node->left) + size(node->right);
        return node;
    }

    // --- 辅助函数:递归 Remove ---
    TreeNode<K, V>* remove(TreeNode<K, V>* node, K key) {
        if (node == nullptr) return nullptr;

        if (key < node->key) {
            node->left = remove(node->left, key);
        } else if (key > node->key) {
            node->right = remove(node->right, key);
        } else {
            // 找到目标节点
            if (node->right == nullptr) {
                TreeNode<K, V>* leftNode = node->left;
                delete node;
                return leftNode;
            }
            if (node->left == nullptr) {
                TreeNode<K, V>* rightNode = node->right;
                delete node;
                return rightNode;
            }

            // 左右子节点都存在:用右子树的最小节点替代当前节点
            TreeNode<K, V>* t = node;
            node = min(t->right); // 这里仅仅是浅拷贝指针,实际应用中通常新建节点或拷贝值
            // 注意:为了内存安全,这里通常做法是拷贝 key/value,然后 deleteMin
            // 为了简化演示,我们假设构建新节点逻辑:
            TreeNode<K, V>* successor = new TreeNode<K, V>(node->key, node->value, 1);
            successor->right = deleteMin(t->right);
            successor->left = t->left;
            delete t; // 删除旧节点
            node = successor;
        }
        node->size = 1 + size(node->left) + size(node->right);
        return node;
    }

    // --- 辅助函数:Keys (中序遍历) ---
    void traverse(TreeNode<K, V>* node, std::vector<K>& keys) {
        if (node == nullptr) return;
        traverse(node->left, keys);
        keys.push_back(node->key);
        traverse(node->right, keys);
    }
    
    // --- 辅助函数:Floor ---
    TreeNode<K, V>* floor(TreeNode<K, V>* node, K key) {
        if (node == nullptr) return nullptr;
        if (key == node->key) return node;
        if (key < node->key) return floor(node->left, key);
        
        // key > node->key,可能在右子树,也可能就是当前节点
        TreeNode<K, V>* t = floor(node->right, key);
        if (t != nullptr) return t;
        return node;
    }

    // --- 辅助函数:Ceiling ---
    TreeNode<K, V>* ceiling(TreeNode<K, V>* node, K key) {
        if (node == nullptr) return nullptr;
        if (key == node->key) return node;
        if (key > node->key) return ceiling(node->right, key);

        TreeNode<K, V>* t = ceiling(node->left, key);
        if (t != nullptr) return t;
        return node;
    }

    // --- 辅助函数:Select ---
    TreeNode<K, V>* select(TreeNode<K, V>* node, int k) {
        if (node == nullptr) return nullptr;
        int t = size(node->left); // 左子树的节点数就是当前节点的排名(从0开始)
        if (t > k) return select(node->left, k);
        else if (t < k) return select(node->right, k - t - 1);
        else return node;
    }

    // --- 辅助函数:Rank ---
    int rank(TreeNode<K, V>* node, K key) {
        if (node == nullptr) return 0;
        if (key < node->key) return rank(node->left, key);
        else if (key > node->key) return 1 + size(node->left) + rank(node->right, key);
        else return size(node->left);
    }

    // --- 辅助函数:Range ---
    void range(TreeNode<K, V>* node, std::vector<K>& list, K low, K high) {
        if (node == nullptr) return;
        if (low < node->key) range(node->left, list, low, high);
        if (low <= node->key && high >= node->key) list.push_back(node->key);
        if (high > node->key) range(node->right, list, low, high);
    }

public:
    MyTreeMap() : root(nullptr) {}

    // 1. 增/改
    void put(K key, V value) {
        root = put(root, key, value);
    }

    // 2. 查
    V get(K key) {
        TreeNode<K, V>* node = get(root, key);
        if (node == nullptr) throw std::runtime_error("Key not found");
        return node->value;
    }

    // 3. 删
    void remove(K key) {
        root = remove(root, key);
    }

    // 4. 是否包含
    bool containsKey(K key) {
        return get(root, key) != nullptr;
    }

    // 5. 所有键
    std::vector<K> keys() {
        std::vector<K> list;
        traverse(root, list);
        return list;
    }

    // 6. 最小键
    K firstKey() {
        if (root == nullptr) throw std::runtime_error("Empty map");
        return min(root)->key;
    }

    // 7. 最大键
    K lastKey() {
        if (root == nullptr) throw std::runtime_error("Empty map");
        return max(root)->key;
    }

    // 8. Floor (<= key 的最大键)
    K floorKey(K key) {
        TreeNode<K, V>* node = floor(root, key);
        if (node == nullptr) throw std::runtime_error("No floor key");
        return node->key;
    }

    // 9. Ceiling (>= key 的最小键)
    K ceilingKey(K key) {
        TreeNode<K, V>* node = ceiling(root, key);
        if (node == nullptr) throw std::runtime_error("No ceiling key");
        return node->key;
    }

    // 10. Select (排名为 k 的键, 0-indexed)
    K selectKey(int k) {
        if (k < 0 || k >= size(root)) throw std::runtime_error("Index out of bounds");
        return select(root, k)->key;
    }

    // 11. Rank (键 key 的排名)
    int rank(K key) {
        return rank(root, key);
    }

    // 12. Range (区间查找)
    std::vector<K> rangeKeys(K low, K high) {
        std::vector<K> list;
        range(root, list, low, high);
        return list;
    }

    // 获取总大小
    int size() {
        return size(root);
    }
};

二叉搜索树中第 K 小的元素

二叉搜索树中第 K 小的元素

方法一:迭代法

思路是初始化一个栈,利用先进后出的特性将二叉树从根开始一路想左压入栈中

然后依次pop 顶部元素(最小的,次小的,pop k个)找到第k小的元素

cpp 复制代码
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        stack<TreeNode*> st;
        TreeNode* cur = root;
        while(cur != nullptr || !st.empty()){
            while(cur != nullptr){
                st.push(cur);
                cur = cur->left;
            }
            cur = st.top();
            st.pop();
            k--;
            if(k == 0) return cur->val;
            cur  = cur->right;
        }
        return -1;
    }
};

方法2:递归法

cpp 复制代码
class Solution {
public:
    int res = 0;
    int count = 0;
    void traverse(TreeNode* node,int k){
        if(node == nullptr) return;
        traverse(node->left,k);
        if (count == k) return;
        count++;
        if(count == k) {
            res = node->val;
            return;
        }
        traverse(node->right,k);
    }
    int kthSmallest(TreeNode* root, int k) {
    count = 0;
    traverse(root,k);
    return res;
     }
};

思考进阶如果二叉搜索树经常被修改(插入/删除)怎么办

我们必须对数据结构进行增强(Augmenting)。 具体来说,就是在每个 TreeNode 中多维护一个字段 size。

size 的定义:以当前节点为根的子树中节点的总个数。

叶子节点的 size = 1。

空节点的 size = 0。

非叶节点 size = left.size + right.size + 1。

有了这个字段,我们就可以像二分查找一样,迅速定位第 k k k 小的元素,而不需要一个个去数。

cpp 复制代码
struct TreeNode {
    int val;
    int size; // <--- 新增字段:子树节点数
    TreeNode *left;
    TreeNode *right;
    
    TreeNode(int x) : val(x), size(1), left(nullptr), right(nullptr) {}
};

查询算法:Select 操作 ( O ( log ⁡ N ) O(\log N) O(logN))

假设我们要找第 k k k 小。当前节点是 root。

左子树有 leftSize 个人:这意味着当前 root 节点自己在整棵树里排第 leftSize + 1 名。令 leftSize 为左子树的节点数。

判断逻辑:如果 k = = l e f t S i z e + 1 k == leftSize + 1 k==leftSize+1:恭喜,当前节点就是第 k k k 小!直接返回。如果 k ≤ l e f t S i z e k \le leftSize k≤leftSize:说明目标在左子树里,递归去左边找第 k k k 小。如果 k > l e f t S i z e + 1 k > leftSize + 1 k>leftSize+1:说明目标在右子树里。关键点:在右子树里,它不再是第 k k k 小了,而是第 k − ( l e f t S i z e + 1 ) k - (leftSize + 1) k−(leftSize+1) 小。

cpp 复制代码
class OrderStatisticTree {
    // 辅助函数:安全获取 size,处理空指针
    int getSize(TreeNode* node) {
        return node ? node->size : 0;
    }

public:
    // 查找第 k 小元素,复杂度 O(log N)
    int kthSmallest(TreeNode* root, int k) {
        TreeNode* curr = root;
        while (curr != nullptr) {
            int leftSize = getSize(curr->left);
            
            if (leftSize + 1 == k) {
                return curr->val; // 找到了
            } 
            else if (k <= leftSize) {
                curr = curr->left; // 在左边,k 不变
            } 
            else {
                // 在右边,k 要减去左边的人数和根节点自己
                k = k - (leftSize + 1);
                curr = curr->right;
            }
        }
        return -1; // 理论上不应到达这里
    }
};
cpp 复制代码
// 插入新节点,同时维护 size
    TreeNode* insert(TreeNode* node, int val) {
        if (node == nullptr) {
            return new TreeNode(val); // 新节点 size 默认为 1
        }
        
        if (val < node->val) {
            node->left = insert(node->left, val);
        } else if (val > node->val) {
            node->right = insert(node->right, val);
        }
        
        // --- 核心:回溯时更新当前节点的 size ---
        node->size = 1 + getSize(node->left) + getSize(node->right);
        
        return node;
    }

平衡性 (AVL / 红黑树)

如果真的在工程中使用,还有一个隐患:普通的 BST 可能会退化成链表。

如果退化成链表,即使维护了 size,树高变成了 N N N,查询复杂度还是会退化成 O ( N ) O(N) O(N)。

因此,真正的工业级实现(如 Java 的 TreeList 或某些数据库索引)会结合 平衡二叉树 (AVL / Red-Black Tree) 和 Size 增强。

把二叉搜索树转换为累加树

把二叉搜索树转换为累加树

cpp 复制代码
class Solution {
public:
    int sum = 0;
    void traverse(TreeNode* node){
        if(node == nullptr) return;
        traverse(node->right);

        sum += node->val;
        node->val = sum;
        traverse(node->left);
    }
    TreeNode* convertBST(TreeNode* root) {
        sum = 0;
        traverse(root);
        return root;
    }
};

删除二叉搜索树中的节点

删除二叉搜索树中的节点

第一步先找到,要删除的节点,和标准的BST查找一样

如果 key < root->val:去左子树删 -> root->left = deleteNode(root->left, key)

如果 key > root->val:去右子树删 -> root->right = deleteNode(root->right, key)

如果 key == root->val:找到了! 进入第二步。

找到之后开始删除

情况1:

如果是叶子节点:不影响整体结构直接删除

情况2:

如果有一个叶子结点,那么删除后这个叶子结点直接替换删除节点的位置

情况3:

如果有两个叶子结点,那么删除后一般是用右子树里面最小的那个节点(右子树递归左子树(这个值与删除的这个值最接近))

把这个值复制给删除的那个位置,然后像情况1或者情况2(有一个右叶子/右子树)那样删除叶子节点即可。

cpp 复制代码
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    TreeNode* getMin(TreeNode* node) {
        while (node->left != nullptr) {
            node = node->left;
        }
        return node;
    }
    TreeNode* deleteNode(TreeNode* root, int key) {
        if(root == nullptr) return nullptr;
        if(key<root->val){
            root->left = deleteNode(root->left,key);
            return root;
        }
        if(key>root->val){
            root->right = deleteNode(root->right,key);
            return root;
        }

        if(root->left == nullptr){
            TreeNode* rightnode = root->right;
            delete root;
            return rightnode;
        }
        if(root->right == nullptr){
            TreeNode* leftnode = root->left;
            delete root;
            return leftnode;
        }

        TreeNode* minNode = getMin(root->right);
        root->val = minNode->val;

        root->right = deleteNode(root->right, minNode->val);
        return root;
    }
};

二叉搜索树中的插入操作

二叉搜索树中的插入操作

cpp 复制代码
class Solution {
public:
    TreeNode* insertIntoBST(TreeNode* root, int val) {
        if(root == nullptr) return new TreeNode(val);

        if(val<root->val){
            root->left = insertIntoBST(root->left,val);
        }else if(val > root->val){
            root->right = insertIntoBST(root->right,val);
        }

        return root;
    }
};

验证二叉搜索树

验证二叉搜索树

思路就是用一个临时变量代表上一个前一个访问节点的值

然后利用中序遍历

一旦发现当前前一个访问节点的值大于当前节点的值就返回false

cpp 复制代码
class Solution {
    private:
    TreeNode* pre = nullptr;
public:
    bool isValidBST(TreeNode* root) {
        if(root == nullptr) return true;

        bool leftisok = isValidBST(root->left);
        if(!leftisok) return false;

        if(pre != nullptr && root->val <= pre->val) return false;

        pre = root;
        bool rightisok = isValidBST(root->right);
        return rightisok;
    }
};

不同的二叉树

思路是用动态规划dp

加上从1-n各节点

加入j为n

那么左子树有j-1个节点,右侧有n-j各节点

由于二叉搜索树的特性,有几种二叉树之和节点个数有关和具体的数值无关,比如123和234的二叉搜索树是一样的

所以以j为根的二叉搜索树,个数为, (左子树 j − 1 j-1 j−1 个节点的种类数) × \times × (右子树 n − j n-j n−j 个节点的种类数)

状态转移方程:

定义 dp[i] 表示由 i 个节点组成的不同 BST 的数量。对于 dp[n],我们需要遍历 1 1 1 到 n n n 的每一个数字作为根节点的情况,并将它们累加:
d p [ n ] = ∑ j = 1 n ( d p [ j − 1 ] × d p [ n − j ] ) dp[n] = \sum_{j=1}^{n} (dp[j-1] \times dp[n-j]) dp[n]=j=1∑n(dp[j−1]×dp[n−j])

cpp 复制代码
class Solution {
public:
    int numTrees(int n) {
        vector<int> dp(n+1,0);
        dp[0] = 1;
        dp[1] = 1;

        for(int i = 2;i<=n;i++){
            for(int j = 1;j<=i;j++){
                dp[i] += dp[j-1] * dp[i-j];
            }
        }
        return dp[n];
    }
};

不同的搜索二叉树 II

思路:

首先明确搜索二叉树,遍历顺序左中右

如果选定i,左孩子从[1 ... i-1] 里面选,右孩子从[i+1 ... n] 里面选。

然后左右两边是独立的区间,而且左边区间 [1, i-1] 也是要构造 BST,右边 [i+1, n] 也是要构造 BST。

就用到了分治算法,把大问题分解为小问题

如果我的根是i,我要先构造左子树和右子树
然后我的左子树还有左子树和右子树,-》》递归

现在每一个左子树右子树都找好了,怎么处理构造

把左子树的一个和右子树的一个拿出来组成到一起
乘积左边的可能×右边的可能-》双层for循环'

cpp 复制代码
class Solution {
public:
    vector<TreeNode*> build(int start, int end){
        vector<TreeNode*> allTrees;
        if(start > end){
            allTrees.push_back(nullptr);
            return allTrees;
        }

        for(int i = start;i<=end;i++){
            vector<TreeNode*> leftsubTrees = build(start,i-1);
            vector<TreeNode*> rightsubTrees = build(i+1,end);

            for(TreeNode* left : leftsubTrees){
                for(TreeNode* right :rightsubTrees){
                    TreeNode* root = new TreeNode(i);
                    root->left = left;
                    root->right = right;
                    allTrees.push_back(root);
                }
            }
        }
        return allTrees;

    }

    vector<TreeNode*> generateTrees(int n) {
        if(n == 0) return {};
        return build(1,n);
    }
};
相关推荐
liu****2 小时前
4.Qt窗口开发全解析:菜单栏、工具栏、状态栏及对话框实战
数据库·c++·qt·系统架构
近津薪荼2 小时前
优选算法——双指针6(单调性)
c++·学习·算法
向哆哆2 小时前
画栈 · 跨端画师接稿平台:基于 Flutter × OpenHarmony 的整体设计与数据结构解析
数据结构·flutter·开源·鸿蒙·openharmony·开源鸿蒙
helloworldandy2 小时前
高性能图像处理库
开发语言·c++·算法
2401_836563182 小时前
C++中的枚举类高级用法
开发语言·c++·算法
bantinghy2 小时前
Nginx基础加权轮询负载均衡算法
服务器·算法·nginx·负载均衡
chao1898443 小时前
矢量拟合算法在网络参数有理式拟合中的应用
开发语言·算法
代码无bug抓狂人3 小时前
动态规划(附带入门例题)
c语言·算法·动态规划
EmbedLinX3 小时前
C++ 面向对象
开发语言·c++