【算法笔记】有序表——SB树

目录

《【算法笔记】有序表------AVL树》
《【算法笔记】有序表------SB树》
《【算法笔记】有序表------跳表》
《【算法笔记】有序表------相关题目》


1、算法概述

  • SB树:SBT(节点大小平衡树)是一种通过维护子树节点数大小信息来实现平衡的自平衡二叉搜索树。

  • 它由陈启峰在2006年提出,核心思想是利用节点的子树大小关系来指导平衡调整,同时天然支持高效的顺序统计操作。

  • SBT的平衡性质:SBT是根据节点树的大小size来判断平衡性的。

  • SBT的平衡定义为:任意节点的 size不小于其兄弟节点的任意子节点(即其"侄子"节点)的size。

  • 当这组性质被破坏时,SBT会通过旋转操作进行调整以恢复平衡。

  • 和AVL树一样,SBT同样存在LL、LR、RL、RR四种不同的不平衡的情况,

  • 可以通过判断顺,来处理LL、LL和LR同时存在的情况,RL和RR同时存在的情况也是一样的

  • 当出现不平衡情况的时候,也是和AVL树一样做对应的左旋和右旋操作即可。

  • BST的平衡性要求没有AVL那样严格,同时天然支持高效的顺序统计操作。

  • 同时,一般BST树在删除节点的时候不去维护平衡操作,在加入节点的时候集中维护平衡操作。

  • 时间复杂度:O(logn)

2、利用SB树实现的自定义map

java 复制代码
     /**
     * 利用SB树实现的自定义map
     */
    public static class SizeBalancedTreeMap<K extends Comparable<K>, V> {

        private Node<K, V> root;

        public int size() {
            return root != null ? root.size : 0;
        }

        public boolean containsKey(K key) {
            if (key == null) {
                return false;
            }
            Node<K, V> node = findNode(key);
            return node != null;
        }

        public void put(K key, V value) {
            if (key == null) {
                return;
            }
            // 查找是否存在
            Node<K, V> node = findNode(key);
            if (node != null) {
                // 存在更新
                node.value = value;
                return;
            }
            // 不存在就添加
            this.root = add(this.root, key, value);
        }

        public void remove(K key) {
            if (key == null) {
                return;
            }
            if (containsKey(key)) {
                this.root = delete(this.root, key);
            }
        }

        public V get(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> node = findNode(key);
            return node != null ? node.value : null;
        }

        /**
         * 获取第index(从0开始)个元素的key
         */
        public K getIndexKey(int index) {
            if (index < 0 || index >= this.size()) {
                return null;
            }
            return getWithIndex(this.root, index + 1).key;
        }

        /**
         * 获取第index(从0开始)个元素的value
         */
        public V getIndexValue(int index) {
            if (index < 0 || index >= this.size()) {
                return null;
            }
            return getWithIndex(this.root, index + 1).value;
        }

        /**
         * 获得最小的 key
         */
        public K firstKey() {
            if (root == null) {
                return null;
            }
            Node<K, V> cur = root;
            while (cur.left != null) {
                cur = cur.left;
            }
            return cur.key;
        }

        /**
         * 获得最大的 key
         */
        public K lastKey() {
            if (root == null) {
                return null;
            }
            Node<K, V> cur = root;
            while (cur.right != null) {
                cur = cur.right;
            }
            return cur.key;
        }

        /**
         * 获得小于等于 key的最大的数
         */
        public K floorKey(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> lastNoBigNode = findLastNoBig(key);
            return lastNoBigNode == null ? null : lastNoBigNode.key;
        }

        /**
         * 获得大于等于 key的最小的数
         */
        public K ceilingKey(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> lastNoSmallNode = findLastNoSmall(key);
            return lastNoSmallNode == null ? null : lastNoSmallNode.key;
        }

        /**
         * 对cur节点进行左旋操作
         */
        private Node<K, V> leftRotate(Node<K, V> cur) {
            if (cur == null || cur.right == null) {
                return cur;
            }
            Node<K, V> right = cur.right;
            cur.right = right.left;
            right.left = cur;
            // 调整cur和right的节点数
            //right.size = cur.size;
            cur.size = (cur.left != null ? cur.left.size : 0) + (cur.right != null ? cur.right.size : 0) + 1;
            right.size = right.left.size + (right.right != null ? right.right.size : 0) + 1;
            return right;
        }

        /**
         * 对cur节点进行右旋操作
         */
        private Node<K, V> rightRotate(Node<K, V> cur) {
            if (cur == null || cur.left == null) {
                return cur;
            }
            Node<K, V> left = cur.left;
            cur.left = left.right;
            left.right = cur;
            // 调整cur和left的节点数
            //left.size = cur.size;
            cur.size = (cur.left != null ? cur.left.size : 0) + (cur.right != null ? cur.right.size : 0) + 1;
            left.size = (left.left != null ? left.left.size : 0) + left.right.size + 1;
            return left;
        }


        /**
         * 调整指定节点cur的平衡性
         */
        private Node<K, V> maintain(Node<K, V> cur) {
            if (cur == null) {
                return null;
            }

            // 左节点的大小
            int leftSize = cur.left != null ? cur.left.size : 0;
            // 左节点的左节点的大小
            int leftLeftSize = cur.left != null && cur.left.left != null ? cur.left.left.size : 0;
            // 左节点的右节点大小
            int leftRightSize = cur.left != null && cur.left.right != null ? cur.left.right.size : 0;
            // 右节点的大小
            int rightSize = cur.right != null ? cur.right.size : 0;
            // 右节点的左节点的大小
            int rightLeftSize = cur.right != null && cur.right.left != null ? cur.right.left.size : 0;
            // 右节点的右节点的大小
            int rightRightSize = cur.right != null && cur.right.right != null ? cur.right.right.size : 0;

            // 根据情况处理
            if (leftLeftSize > rightSize) {
                // LL : 右旋,一定要先处理LL,在处理LR
                cur = rightRotate(cur);
                // 影响的节点是转到右边的right和新的cur,需要调整这两个
                cur.right = maintain(cur.right);
                cur = maintain(cur);

            } else if (leftRightSize > rightSize) {
                // LR : 先将left左旋,然后在当前节点右旋
                cur.left = leftRotate(cur.left);
                cur = rightRotate(cur);
                // left左旋时影响left和left的left,当前节点右旋影响的是right和cur,
                // 因为left变为了新的cur,所以影响到的是left、right和cur三个节点调整
                cur.left = maintain(cur.left);
                cur.right = maintain(cur.right);
                cur = maintain(cur);
            } else if (rightRightSize > leftSize) {
                // RR : 左旋,一定要先处理RR,在处理RL
                cur = leftRotate(cur);
                cur.left = maintain(cur.left);
                cur = maintain(cur);
            } else if (rightLeftSize > leftSize) {
                // RL : 先将right右旋,然后cur左旋,同样影响left、right和cur三个节点
                cur.right = rightRotate(cur.right);
                cur = leftRotate(cur);
                cur.left = maintain(cur.left);
                cur.right = maintain(cur.right);
                cur = maintain(cur);
            }
            return cur;
        }

        /**
         * 根据指定的key查找节点
         * 如果存在就返回那个节点,不存在就返回null
         */
        private Node<K, V> findNode(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> cur = this.root;
            while (cur != null) {
                int cmt = key.compareTo(cur.key);
                if (cmt == 0) {
                    // 值相等,直接返回
                    return cur;
                }
                if (cmt < 0) {
                    // key比当前的值小,需要在左侧进行查找
                    cur = cur.left;
                } else {
                    // key比当前值大,需要在右侧进行查找
                    cur = cur.right;
                }
            }
            // 到这里,说明挑出了while循环,此时cur == null,说明没找到
            return null;
        }

        /**
         * 查找最后一个不小于指定key的节点
         * 不小于指定key指的是离key最近的大于等于key的值的节点
         * 方法:
         * 按照BST的方式进行查找,相等的时候直接返回,
         * 往左侧查找的时候,先用一个变量将当前节点记录下来,因为当前节点比要查找的大,
         * 这样挑出循环,说明没找到相等的,变量记录的就是比key大的最近的节点
         */
        private Node<K, V> findLastNoSmall(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> ans = null;
            Node<K, V> cur = root;
            while (cur != null) {
                int cmt = key.compareTo(cur.key);
                if (cmt == 0) {
                    // 相等,直接返回当前节点
                    return cur;
                }
                if (cmt < 0) {
                    // key 比当前值小,先记录当前值,再从左侧查找
                    ans = cur;
                    cur = cur.left;
                } else {
                    cur = cur.right;
                }
            }
            // 到这里,说明没有找到相等的,返回记录的值
            return ans;
        }

        /**
         * 查找最后一个不大于指定key的节点
         * 不大于指定key的意思是小于等于key
         * 方法:
         * 按照BST的方法查找,
         * 找到相等的,直接返回
         * 没有相等时,在往右侧查找的时候,用一个变量记录当前的值,当前值小于key
         * 这样到最后,返回变量记录的节点,就是小于key最近的节点
         */
        private Node<K, V> findLastNoBig(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> ans = null;
            Node<K, V> cur = root;
            while (cur != null) {
                int cmt = key.compareTo(cur.key);
                if (cmt == 0) {
                    // 相等,直接返回当前节点
                    return cur;
                }
                if (cmt < 0) {
                    cur = cur.left;
                } else {
                    // key 比当前值大,先记录当前值,再从右侧查找
                    ans = cur;
                    cur = cur.right;
                }
            }
            // 到这里,说明没有找到相等的,返回记录的值
            return ans;
        }

        /**
         * 给指定的子树cur添加记录
         */
        private Node<K, V> add(Node<K, V> cur, K key, V value) {
            if (cur == null) {
                // 要添加的子树为null,说明当前值就是添加到这个位置,返回一个新节点
                return new Node<>(key, value);
            }
            int cmt = key.compareTo(cur.key);
            if (cmt == 0) {
                // 只更新值
                cur.value = value;
                return cur;
            }
            if (cmt < 0) {
                // key比cur的值小,添加到左子树
                cur.left = add(cur.left, key, value);
            } else {
                // key比cur的值大,添加到右子树
                cur.right = add(cur.right, key, value);
            }
            cur.size++;
            // 调整当前节点
            return maintain(cur);
        }

        /**
         * 在指定子树cur上删除节点
         */
        private Node<K, V> delete(Node<K, V> cur, K key) {
            if (cur == null || key == null) {
                return cur;
            }
            cur.size--;
            int cmt = key.compareTo(cur.key);
            if (cmt < 0) {
                // key比当前值小,从左侧子树删除
                cur.left = delete(cur.left, key);
            } else if (cmt > 0) {
                // key比当前值大,从右侧子树删除
                cur.right = delete(cur.right, key);
            } else {
                // cur就是要删除的节点,根据删除逻辑删除即可
                if (cur.left == null && cur.right == null) {
                    // 无子节点,直接删除
                    cur = null;
                } else if (cur.left == null) {
                    // 有右节点,用右节点代替
                    cur = cur.right;
                } else if (cur.right == null) {
                    // 有左节点,用左节点代替
                    cur = cur.left;
                } else {
                    // 两个节点都有,用右子树中的最小值来替代,然后递归删除那个节点
                    // 因为在删除的过程中要调整整个路径上的size,所以需要一个pre指向前一个节点
                    Node<K, V> pre = null;
                    Node<K, V> des = cur.right;
                    des.size--;
                    while (des.left != null) {
                        // 找到右子树的最左侧节点,就是右树中的最小值
                        pre = des;
                        des = des.left;
                        des.size--;
                    }

                    if (pre != null) {
                        // 删掉des的left,因为pre肯定没有left节点了
                        pre.left = des.right;
                        // 这一句要放到这里,因为如果pre为null,代表cur.right没有左节点,直接接cur.left即可
                        des.right = cur.right;
                    }
                    // 用des替换cur
                    des.left = cur.left;
                    des.size = des.left.size + (des.right == null ? 0 : des.right.size) + 1;
                    cur = des;
                }

            }
            return cur;
        }

        /**
         * 从子树cur中获得第index个的元素
         * 根据这个方法,可以很方便的找到一个范围上的元素值
         */
        private Node<K, V> getWithIndex(Node<K, V> cur, int index) {
            if (cur == null || index < 0) {
                return null;
            }
            // 获得左侧子元素的个数
            int leftSize = cur.left != null ? cur.left.size : 0;
            // 如果左侧子元素的个数+1和要的index相等,则根据BST的性质,cur就是需要找的节点
            if (index == leftSize + 1) {
                return cur;
            }
            // 如果index小于或者等于左侧子元素的个数,就直接去左侧找
            if (index <= leftSize) {
                return getWithIndex(cur.left, index);
            }
            // 到这里,说明index大于leftSize + 1的值,需要去右侧找,去右侧的时候,要把左侧的个数减去
            return getWithIndex(cur.right, index - leftSize - 1);
        }

        /**
         * 节点类
         */
        static class Node<K extends Comparable<K>, V> {
            private final K key;
            private V value;
            private Node<K, V> left;
            private Node<K, V> right;
            // 子树的节点数量
            public int size;

            public Node(K key, V value) {
                this.key = key;
                this.value = value;
                this.size = 1;
            }
        }
    }

整体代码和测试:

java 复制代码
import java.util.Objects;
import java.util.TreeMap;

/**
 * SB树:SBT(节点大小平衡树)是一种通过维护子树节点数大小信息来实现平衡的自平衡二叉搜索树。
 * 它由陈启峰在2006年提出,核心思想是利用节点的子树大小关系来指导平衡调整,同时天然支持高效的顺序统计操作。
 * <br>
 * SBT的平衡性质:SBT是根据节点树的大小size来判断平衡性的。
 * SBT的平衡定义为:任意节点的 size不小于其兄弟节点的任意子节点(即其"侄子"节点)的size。
 * 当这组性质被破坏时,SBT会通过旋转操作进行调整以恢复平衡。
 * 和AVL树一样,SBT同样存在LL、LR、RL、RR四种不同的不平衡的情况,
 * 可以通过判断顺,来处理LL、LL和LR同时存在的情况,RL和RR同时存在的情况也是一样的
 * 当出现不平衡情况的时候,也是和AVL树一样做对应的左旋和右旋操作即可。
 * BST的平衡性要求没有AVL那样严格,同时天然支持高效的顺序统计操作。
 * 同时,一般BST树在删除节点的时候不去维护平衡操作,在加入节点的时候集中维护平衡操作。
 * <br>
 * 时间复杂度:O(logn)
 */
public class SizeBalancedTree {


    /**
     * 利用SB树实现的自定义map
     */
    public static class SizeBalancedTreeMap<K extends Comparable<K>, V> {

        private Node<K, V> root;

        public int size() {
            return root != null ? root.size : 0;
        }

        public boolean containsKey(K key) {
            if (key == null) {
                return false;
            }
            Node<K, V> node = findNode(key);
            return node != null;
        }

        public void put(K key, V value) {
            if (key == null) {
                return;
            }
            // 查找是否存在
            Node<K, V> node = findNode(key);
            if (node != null) {
                // 存在更新
                node.value = value;
                return;
            }
            // 不存在就添加
            this.root = add(this.root, key, value);
        }

        public void remove(K key) {
            if (key == null) {
                return;
            }
            if (containsKey(key)) {
                this.root = delete(this.root, key);
            }
        }

        public V get(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> node = findNode(key);
            return node != null ? node.value : null;
        }

        /**
         * 获取第index(从0开始)个元素的key
         */
        public K getIndexKey(int index) {
            if (index < 0 || index >= this.size()) {
                return null;
            }
            return getWithIndex(this.root, index + 1).key;
        }

        /**
         * 获取第index(从0开始)个元素的value
         */
        public V getIndexValue(int index) {
            if (index < 0 || index >= this.size()) {
                return null;
            }
            return getWithIndex(this.root, index + 1).value;
        }

        /**
         * 获得最小的 key
         */
        public K firstKey() {
            if (root == null) {
                return null;
            }
            Node<K, V> cur = root;
            while (cur.left != null) {
                cur = cur.left;
            }
            return cur.key;
        }

        /**
         * 获得最大的 key
         */
        public K lastKey() {
            if (root == null) {
                return null;
            }
            Node<K, V> cur = root;
            while (cur.right != null) {
                cur = cur.right;
            }
            return cur.key;
        }

        /**
         * 获得小于等于 key的最大的数
         */
        public K floorKey(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> lastNoBigNode = findLastNoBig(key);
            return lastNoBigNode == null ? null : lastNoBigNode.key;
        }

        /**
         * 获得大于等于 key的最小的数
         */
        public K ceilingKey(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> lastNoSmallNode = findLastNoSmall(key);
            return lastNoSmallNode == null ? null : lastNoSmallNode.key;
        }

        /**
         * 对cur节点进行左旋操作
         */
        private Node<K, V> leftRotate(Node<K, V> cur) {
            if (cur == null || cur.right == null) {
                return cur;
            }
            Node<K, V> right = cur.right;
            cur.right = right.left;
            right.left = cur;
            // 调整cur和right的节点数
            //right.size = cur.size;
            cur.size = (cur.left != null ? cur.left.size : 0) + (cur.right != null ? cur.right.size : 0) + 1;
            right.size = right.left.size + (right.right != null ? right.right.size : 0) + 1;
            return right;
        }

        /**
         * 对cur节点进行右旋操作
         */
        private Node<K, V> rightRotate(Node<K, V> cur) {
            if (cur == null || cur.left == null) {
                return cur;
            }
            Node<K, V> left = cur.left;
            cur.left = left.right;
            left.right = cur;
            // 调整cur和left的节点数
            //left.size = cur.size;
            cur.size = (cur.left != null ? cur.left.size : 0) + (cur.right != null ? cur.right.size : 0) + 1;
            left.size = (left.left != null ? left.left.size : 0) + left.right.size + 1;
            return left;
        }


        /**
         * 调整指定节点cur的平衡性
         */
        private Node<K, V> maintain(Node<K, V> cur) {
            if (cur == null) {
                return null;
            }

            // 左节点的大小
            int leftSize = cur.left != null ? cur.left.size : 0;
            // 左节点的左节点的大小
            int leftLeftSize = cur.left != null && cur.left.left != null ? cur.left.left.size : 0;
            // 左节点的右节点大小
            int leftRightSize = cur.left != null && cur.left.right != null ? cur.left.right.size : 0;
            // 右节点的大小
            int rightSize = cur.right != null ? cur.right.size : 0;
            // 右节点的左节点的大小
            int rightLeftSize = cur.right != null && cur.right.left != null ? cur.right.left.size : 0;
            // 右节点的右节点的大小
            int rightRightSize = cur.right != null && cur.right.right != null ? cur.right.right.size : 0;

            // 根据情况处理
            if (leftLeftSize > rightSize) {
                // LL : 右旋,一定要先处理LL,在处理LR
                cur = rightRotate(cur);
                // 影响的节点是转到右边的right和新的cur,需要调整这两个
                cur.right = maintain(cur.right);
                cur = maintain(cur);

            } else if (leftRightSize > rightSize) {
                // LR : 先将left左旋,然后在当前节点右旋
                cur.left = leftRotate(cur.left);
                cur = rightRotate(cur);
                // left左旋时影响left和left的left,当前节点右旋影响的是right和cur,
                // 因为left变为了新的cur,所以影响到的是left、right和cur三个节点调整
                cur.left = maintain(cur.left);
                cur.right = maintain(cur.right);
                cur = maintain(cur);
            } else if (rightRightSize > leftSize) {
                // RR : 左旋,一定要先处理RR,在处理RL
                cur = leftRotate(cur);
                cur.left = maintain(cur.left);
                cur = maintain(cur);
            } else if (rightLeftSize > leftSize) {
                // RL : 先将right右旋,然后cur左旋,同样影响left、right和cur三个节点
                cur.right = rightRotate(cur.right);
                cur = leftRotate(cur);
                cur.left = maintain(cur.left);
                cur.right = maintain(cur.right);
                cur = maintain(cur);
            }
            return cur;
        }

        /**
         * 根据指定的key查找节点
         * 如果存在就返回那个节点,不存在就返回null
         */
        private Node<K, V> findNode(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> cur = this.root;
            while (cur != null) {
                int cmt = key.compareTo(cur.key);
                if (cmt == 0) {
                    // 值相等,直接返回
                    return cur;
                }
                if (cmt < 0) {
                    // key比当前的值小,需要在左侧进行查找
                    cur = cur.left;
                } else {
                    // key比当前值大,需要在右侧进行查找
                    cur = cur.right;
                }
            }
            // 到这里,说明挑出了while循环,此时cur == null,说明没找到
            return null;
        }

        /**
         * 查找最后一个不小于指定key的节点
         * 不小于指定key指的是离key最近的大于等于key的值的节点
         * 方法:
         * 按照BST的方式进行查找,相等的时候直接返回,
         * 往左侧查找的时候,先用一个变量将当前节点记录下来,因为当前节点比要查找的大,
         * 这样挑出循环,说明没找到相等的,变量记录的就是比key大的最近的节点
         */
        private Node<K, V> findLastNoSmall(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> ans = null;
            Node<K, V> cur = root;
            while (cur != null) {
                int cmt = key.compareTo(cur.key);
                if (cmt == 0) {
                    // 相等,直接返回当前节点
                    return cur;
                }
                if (cmt < 0) {
                    // key 比当前值小,先记录当前值,再从左侧查找
                    ans = cur;
                    cur = cur.left;
                } else {
                    cur = cur.right;
                }
            }
            // 到这里,说明没有找到相等的,返回记录的值
            return ans;
        }

        /**
         * 查找最后一个不大于指定key的节点
         * 不大于指定key的意思是小于等于key
         * 方法:
         * 按照BST的方法查找,
         * 找到相等的,直接返回
         * 没有相等时,在往右侧查找的时候,用一个变量记录当前的值,当前值小于key
         * 这样到最后,返回变量记录的节点,就是小于key最近的节点
         */
        private Node<K, V> findLastNoBig(K key) {
            if (key == null) {
                return null;
            }
            Node<K, V> ans = null;
            Node<K, V> cur = root;
            while (cur != null) {
                int cmt = key.compareTo(cur.key);
                if (cmt == 0) {
                    // 相等,直接返回当前节点
                    return cur;
                }
                if (cmt < 0) {
                    cur = cur.left;
                } else {
                    // key 比当前值大,先记录当前值,再从右侧查找
                    ans = cur;
                    cur = cur.right;
                }
            }
            // 到这里,说明没有找到相等的,返回记录的值
            return ans;
        }

        /**
         * 给指定的子树cur添加记录
         */
        private Node<K, V> add(Node<K, V> cur, K key, V value) {
            if (cur == null) {
                // 要添加的子树为null,说明当前值就是添加到这个位置,返回一个新节点
                return new Node<>(key, value);
            }
            int cmt = key.compareTo(cur.key);
            if (cmt == 0) {
                // 只更新值
                cur.value = value;
                return cur;
            }
            if (cmt < 0) {
                // key比cur的值小,添加到左子树
                cur.left = add(cur.left, key, value);
            } else {
                // key比cur的值大,添加到右子树
                cur.right = add(cur.right, key, value);
            }
            cur.size++;
            // 调整当前节点
            return maintain(cur);
        }

        /**
         * 在指定子树cur上删除节点
         */
        private Node<K, V> delete(Node<K, V> cur, K key) {
            if (cur == null || key == null) {
                return cur;
            }
            cur.size--;
            int cmt = key.compareTo(cur.key);
            if (cmt < 0) {
                // key比当前值小,从左侧子树删除
                cur.left = delete(cur.left, key);
            } else if (cmt > 0) {
                // key比当前值大,从右侧子树删除
                cur.right = delete(cur.right, key);
            } else {
                // cur就是要删除的节点,根据删除逻辑删除即可
                if (cur.left == null && cur.right == null) {
                    // 无子节点,直接删除
                    cur = null;
                } else if (cur.left == null) {
                    // 有右节点,用右节点代替
                    cur = cur.right;
                } else if (cur.right == null) {
                    // 有左节点,用左节点代替
                    cur = cur.left;
                } else {
                    // 两个节点都有,用右子树中的最小值来替代,然后递归删除那个节点
                    // 因为在删除的过程中要调整整个路径上的size,所以需要一个pre指向前一个节点
                    Node<K, V> pre = null;
                    Node<K, V> des = cur.right;
                    des.size--;
                    while (des.left != null) {
                        // 找到右子树的最左侧节点,就是右树中的最小值
                        pre = des;
                        des = des.left;
                        des.size--;
                    }

                    if (pre != null) {
                        // 删掉des的left,因为pre肯定没有left节点了
                        pre.left = des.right;
                        // 这一句要放到这里,因为如果pre为null,代表cur.right没有左节点,直接接cur.left即可
                        des.right = cur.right;
                    }
                    // 用des替换cur
                    des.left = cur.left;
                    des.size = des.left.size + (des.right == null ? 0 : des.right.size) + 1;
                    cur = des;
                }

            }
            return cur;
        }

        /**
         * 从子树cur中获得第index个的元素
         * 根据这个方法,可以很方便的找到一个范围上的元素值
         */
        private Node<K, V> getWithIndex(Node<K, V> cur, int index) {
            if (cur == null || index < 0) {
                return null;
            }
            // 获得左侧子元素的个数
            int leftSize = cur.left != null ? cur.left.size : 0;
            // 如果左侧子元素的个数+1和要的index相等,则根据BST的性质,cur就是需要找的节点
            if (index == leftSize + 1) {
                return cur;
            }
            // 如果index小于或者等于左侧子元素的个数,就直接去左侧找
            if (index <= leftSize) {
                return getWithIndex(cur.left, index);
            }
            // 到这里,说明index大于leftSize + 1的值,需要去右侧找,去右侧的时候,要把左侧的个数减去
            return getWithIndex(cur.right, index - leftSize - 1);
        }

        /**
         * 节点类
         */
        static class Node<K extends Comparable<K>, V> {
            private final K key;
            private V value;
            private Node<K, V> left;
            private Node<K, V> right;
            // 子树的节点数量
            public int size;

            public Node(K key, V value) {
                this.key = key;
                this.value = value;
                this.size = 1;
            }
        }
    }

    public static void main(String[] args) {
        SizeBalancedTreeMap<Integer, Integer> sbt = new SizeBalancedTreeMap<>();
        for (int i = 0; i < 10; i++) {
            sbt.put(i, i);
        }
        System.out.println(sbt.size());
        sbt.remove(5);
        System.out.println(sbt.size());
        for (int i = 0; i < sbt.size(); i++) {
            System.out.println(sbt.getIndexKey(i) + " , " + sbt.getIndexValue(i));
        }
        System.out.println("======");
        functionTest("SizeBalancedTreeMap");
        System.out.println("======");
        performanceTest("SizeBalancedTreeMap");
    }


    public static void functionTest(String prefix) {
        System.out.println(prefix + " 功能测试开始");
        TreeMap<Integer, Integer> treeMap = new TreeMap<>();
        SizeBalancedTreeMap<Integer, Integer> target = new SizeBalancedTreeMap<>();
        int maxK = 500;
        int maxV = 50000;
        int testTime = 1000000;
        boolean success = true;
        for (int i = 0; i < testTime; i++) {
            int addK = (int) (Math.random() * maxK);
            int addV = (int) (Math.random() * maxV);
            treeMap.put(addK, addV);
            target.put(addK, addV);

            int removeK = (int) (Math.random() * maxK);
            treeMap.remove(removeK);
            target.remove(removeK);

            int querryK = (int) (Math.random() * maxK);
            boolean treeMapAns = treeMap.containsKey(querryK);
            boolean ans = target.containsKey(querryK);
            if (treeMapAns != ans) {
                System.out.println("containsKey 错误");
                System.out.printf("key:%d, treeMapAns:%b, ans:%b\n", querryK, treeMapAns, ans);
                success = false;
                break;
            }

            if (treeMap.containsKey(querryK)) {
                int v1 = treeMap.get(querryK);
                int v2 = target.get(querryK);
                if (v1 != v2) {
                    System.out.println("get 错误");
                    System.out.printf("key:%d, treeMapAns:%d, ans:%d\n", querryK, v1, v2);
                    success = false;
                    break;
                }
                Integer f1 = treeMap.floorKey(querryK);
                Integer f2 = target.floorKey(querryK);
                if (!Objects.equals(f1, f2)) {
                    System.out.println("floorKey 错误");
                    System.out.printf("key:%d, treeMapAns:%d, ans:%d\n", querryK, f1, f2);
                    success = false;
                    break;
                }
                f1 = treeMap.ceilingKey(querryK);
                f2 = target.ceilingKey(querryK);
                if (!Objects.equals(f1, f2)) {
                    System.out.println("ceilingKey 错误");
                    System.out.printf("key:%d, treeMapAns:%d, ans:%d\n", querryK, f1, f2);
                    success = false;
                    break;
                }
            }

            Integer f1 = treeMap.firstKey();
            Integer f2 = target.firstKey();
            if (!Objects.equals(f1, f2)) {
                System.out.println("firstKey 错误");
                System.out.printf("key:%d, treeMapAns:%d, ans:%d\n", querryK, f1, f2);
                success = false;
                break;
            }

            f1 = treeMap.lastKey();
            f2 = target.lastKey();
            if (!Objects.equals(f1, f2)) {
                System.out.println("lastKey 错误");
                System.out.printf("key:%d, treeMapAns:%d, ans:%d\n", querryK, f1, f2);
                success = false;
                break;
            }
            int treeMapSize = treeMap.size();
            int ansSize = target.size();
            if (treeMapSize != ansSize) {
                System.out.println("size 错误");
                System.out.printf("key:%d, treeMapAns:%d, ans:%d\n", querryK, treeMapSize, ansSize);
                success = false;
                break;
            }
        }
        if (!success) {
            System.out.println("测试失败");
            return;
        }
        System.out.println(prefix + " 功能测试结束");
    }

    public static void performanceTest(String prefix) {
        System.out.println(prefix + " 性能测试开始");
        TreeMap<Integer, Integer> treeMap = new TreeMap<>();
        SizeBalancedTreeMap<Integer, Integer> target = new SizeBalancedTreeMap<>();
        long start;
        long end;
        int max = 1000000;
        System.out.println("顺序递增加入测试,数据规模 : " + max);
        start = System.currentTimeMillis();
        for (int i = 0; i < max; i++) {
            treeMap.put(i, i);
        }
        end = System.currentTimeMillis();
        System.out.println("treeMap 运行时间 : " + (end - start) + "ms");

        start = System.currentTimeMillis();
        for (int i = 0; i < max; i++) {
            target.put(i, i);
        }
        end = System.currentTimeMillis();
        System.out.println(prefix + " 运行时间 : " + (end - start) + "ms");


        System.out.println("顺序递增删除测试,数据规模 : " + max);
        start = System.currentTimeMillis();
        for (int i = 0; i < max; i++) {
            treeMap.remove(i);
        }
        end = System.currentTimeMillis();
        System.out.println("treeMap 运行时间 : " + (end - start) + "ms");

        start = System.currentTimeMillis();
        for (int i = 0; i < max; i++) {
            target.remove(i);
        }
        end = System.currentTimeMillis();
        System.out.println(prefix + " 运行时间 : " + (end - start) + "ms");

        System.out.println("顺序递减加入测试,数据规模 : " + max);
        start = System.currentTimeMillis();
        for (int i = max; i >= 0; i--) {
            treeMap.put(i, i);
        }
        end = System.currentTimeMillis();
        System.out.println("treeMap 运行时间 : " + (end - start) + "ms");

        start = System.currentTimeMillis();
        for (int i = max; i >= 0; i--) {
            target.put(i, i);
        }
        end = System.currentTimeMillis();
        System.out.println(prefix + " 运行时间 : " + (end - start) + "ms");

        System.out.println("顺序递减删除测试,数据规模 : " + max);
        start = System.currentTimeMillis();
        for (int i = max; i >= 0; i--) {
            treeMap.remove(i);
        }
        end = System.currentTimeMillis();
        System.out.println("treeMap 运行时间 : " + (end - start) + "ms");

        start = System.currentTimeMillis();
        for (int i = max; i >= 0; i--) {
            target.remove(i);
        }
        end = System.currentTimeMillis();
        System.out.println(prefix + " 运行时间 : " + (end - start) + "ms");


        System.out.println("随机加入测试,数据规模 : " + max);
        start = System.currentTimeMillis();
        for (int i = 0; i < max; i++) {
            treeMap.put((int) (Math.random() * i), i);
        }
        end = System.currentTimeMillis();
        System.out.println("treeMap 运行时间 : " + (end - start) + "ms");

        start = System.currentTimeMillis();
        for (int i = max; i >= 0; i--) {
            target.put((int) (Math.random() * i), i);
        }
        end = System.currentTimeMillis();
        System.out.println(prefix + " 运行时间 : " + (end - start) + "ms");

        System.out.println("随机删除测试,数据规模 : " + max);
        start = System.currentTimeMillis();
        for (int i = 0; i < max; i++) {
            treeMap.remove((int) (Math.random() * i));
        }
        end = System.currentTimeMillis();
        System.out.println("treeMap 运行时间 : " + (end - start) + "ms");

        start = System.currentTimeMillis();
        for (int i = max; i >= 0; i--) {
            target.remove((int) (Math.random() * i));
        }
        end = System.currentTimeMillis();
        System.out.println(prefix + " 运行时间 : " + (end - start) + "ms");

        System.out.println(prefix + " 性能测试结束");
    }
}

后记

个人学习总结笔记,不能保证非常详细,轻喷

相关推荐
曾几何时`2 小时前
滑动窗口(十五)2962. 统计最大元素出现至少 K 次的子数组(越长越合法型)
数据结构·算法
AI视觉网奇2 小时前
NVIDIA 生成key
笔记·nvidia
代码游侠2 小时前
应用——SQLite3 C 编程学习
linux·服务器·c语言·数据库·笔记·网络协议·sqlite
究极无敌暴龙战神X2 小时前
机器学习相关
人工智能·算法·机器学习
会思考的猴子2 小时前
UE5 笔记二 GameplayAbilitySystem Attributes & Effects
笔记·ue5
走在路上的菜鸟2 小时前
Android学Dart学习笔记第二十八节 Isolates
android·笔记·学习·flutter
断剑zou天涯2 小时前
【算法笔记】有序表——跳表
笔记·算法
挖矿大亨2 小时前
c++中的函数调用运算符重载
前端·c++·算法
wadesir2 小时前
Rust语言BM算法实现(从零开始掌握Boyer-Moore字符串搜索算法)
算法·rust·.net