【leetcode练习·二叉树拓展】归并排序详解及应用

本文参考labuladong算法笔记[拓展:归并排序详解及应用 | labuladong 的算法笔记]

"归并排序就是二叉树的后序遍历"------labuladong

就说归并排序吧,如果给你看代码,让你脑补一下归并排序的过程,你脑子里会出现什么场景?

这是一个数组排序算法,所以你脑补一个数组的 GIF,在那一个个交换元素?如果是这样的话,那格局就低了。

但如果你脑海中浮现出的是一棵二叉树,甚至浮现出二叉树后序遍历的场景,那格局就高了,大概率掌握了我经常强调的 框架思维,用这种抽象能力学习算法就省劲多了。

那么,归并排序明明就是一个数组算法,和二叉树有什么关系?接下来我就具体讲讲。

1、算法思路

就这么说吧,所有递归的算法,你甭管它是干什么的,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码,你要写递归算法,本质上就是要告诉每个节点需要做什么

你看归并排序的代码框架:

python 复制代码
# 定义:排序 nums[lo..hi]
def sort(nums: List[int], lo: int, hi: int) -> None:
    if lo == hi:
        return
    mid = (lo + hi) // 2
    # 利用定义,排序 nums[lo..mid]
    sort(nums, lo, mid)
    # 利用定义,排序 nums[mid+1..hi]
    sort(nums, mid + 1, hi)

    # ***** 后序位置 *****
    # 此时两部分子数组已经被排好序
    # 合并两个有序数组,使 nums[lo..hi] 有序
    merge(nums, lo, mid, hi)

# 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
# 合并为有序数组 nums[lo..hi]
def merge(nums: List[int], lo: int, mid: int, hi: int) -> None:
    pass

看这个框架,也就明白那句经典的总结:归并排序就是先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。

上述代码和二叉树的后序遍历很像:

python 复制代码
# 二叉树遍历框架 
def traverse(root: TreeNode) -> None: 
    if root is None:
        return
    traverse(root.left)
    traverse(root.right)
    # 后序位置
    print(root.val)

再进一步,你联想一下求二叉树的最大深度的算法代码:

python 复制代码
# 定义:输入根节点,返回这棵二叉树的最大深度
def maxDepth(root: TreeNode) -> int:
	if not root:
		return 0
	# 利用定义,计算左右子树的最大深度
	leftMax = maxDepth(root.left)
	rightMax = maxDepth(root.right)
	# 整棵树的最大深度等于左右子树的最大深度取最大值,
	# 然后再加上根节点自己
	res = max(leftMax, rightMax) + 1

	return res

本质上就是二叉树的后序遍历------先拿到子树,再合并子树

前文 手把手刷二叉树(纲领篇) 说二叉树问题可以分为两类思路,一类是遍历一遍二叉树的思路,另一类是分解问题的思路,根据上述类比,显然归并排序利用的是分解问题的思路(分治算法)。

归并排序的过程可以在逻辑上抽象成一棵二叉树,树上的每个节点的值可以认为是 nums[lo..hi],叶子节点的值就是数组中的单个元素

然后,在每个节点的后序位置(左右子节点已经被排好序)的时候执行 merge 函数,合并两个子节点上的子数组:

这个 merge 操作会在二叉树的每个节点上都执行一遍,执行顺序是二叉树后序遍历的顺序。

后序遍历二叉树大家应该已经烂熟于心了,就是下图这个遍历顺序:

结合上述基本分析,我们把 nums[lo..hi] 理解成二叉树的节点,sort 函数理解成二叉树的遍历函数,整个归并排序的执行过程就是以下 GIF 描述的这样:

这样,归并排序的核心思路就分析完了,接下来只要把思路翻译成代码就行。

2、代码实现

只要拥有了正确的思维方式,理解算法思路是不困难的,但把思路实现成代码,也很考验一个人的编程能力

毕竟算法的时间复杂度只是一个理论上的衡量标准,而算法的实际运行效率要考虑的因素更多,比如应该避免内存的频繁分配释放,代码逻辑应尽可能简洁等等。

python 复制代码
class Merge:

    # 用于辅助合并有序数组
    temp = []

    @staticmethod
    def sort(nums):
        # 先给辅助数组开辟内存空间
        Merge.temp = [0] * len(nums)
        # 排序整个数组(原地修改)
        Merge._sort(nums, 0, len(nums) - 1)

    # 定义:将子数组 nums[lo..hi] 进行排序
    @staticmethod
    def _sort(nums, lo, hi):
        if lo == hi:
            # 单个元素不用排序
            return
        # 这样写是为了防止溢出,效果等同于 (hi + lo) / 2
        mid = lo + (hi - lo) // 2
        # 先对左半部分数组 nums[lo..mid] 排序
        Merge._sort(nums, lo, mid)
        # 再对右半部分数组 nums[mid+1..hi] 排序
        Merge._sort(nums, mid + 1, hi)
        # 将两部分有序数组合并成一个有序数组
        Merge.merge(nums, lo, mid, hi)

    # 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
    @staticmethod
    def merge(nums, lo, mid, hi):
        # 先把 nums[lo..hi] 复制到辅助数组中
        # 以便合并后的结果能够直接存入 nums
        for i in range(lo, hi+1):
            Merge.temp[i] = nums[i]

        # 数组双指针技巧,合并两个有序数组
        i, j = lo, mid + 1
        for p in range(lo, hi+1):
            if i == mid + 1:
                # 左半边数组已全部被合并
                nums[p] = Merge.temp[j]
                j += 1
            elif j == hi + 1:
                # 右半边数组已全部被合并
                nums[p] = Merge.temp[i]
                i += 1
            elif Merge.temp[i] > Merge.temp[j]:
                nums[p] = Merge.temp[j]
                j += 1
            else:
                nums[p] = Merge.temp[i]
                i += 1

有了之前的铺垫,这里只需要着重讲一下这个 merge 函数。

sort 函数对 nums[lo..mid]nums[mid+1..hi] 递归排序完成之后,我们没有办法原地把它俩合并,所以需要 copy 到 temp 数组里面,然后通过类似于前文 单链表的六大技巧 中合并有序链表的双指针技巧将 nums[lo..hi] 合并成一个有序数组:

注意我们不是在 merge 函数执行的时候 new 辅助数组,而是提前把 temp 辅助数组 new 出来了,这样就避免了在递归中频繁分配和释放内存可能产生的性能问题。

3、复杂度分析

再说一下归并排序的时间复杂度,虽然大伙儿应该都知道是O(NlogN),但不见得所有人都知道这个复杂度怎么算出来的。

前文 动态规划详解 说过递归算法的复杂度计算,就是子问题个数 x 解决一个子问题的复杂度。对于归并排序来说,时间复杂度显然集中在 merge 函数遍历 nums[lo..hi] 的过程,但每次 merge 输入的 lohi 都不同,所以不容易直观地看出时间复杂度。

merge 函数到底执行了多少次?每次执行的时间复杂度是多少?总的时间复杂度是多少?这就要结合之前画的这幅图来看:

执行的次数是二叉树节点的个数,每次执行的复杂度就是每个节点代表的子数组的长度,所以总的时间复杂度就是整棵树中「数组元素」的个数

所以从整体上看,这个二叉树的高度是 logN,其中每一层的元素个数就是原数组的长度 N,所以总的时间复杂度就是 O(NlogN)O(NlogN)。

4、力扣练习

912. 排序数组

给你一个整数数组 nums,请你将该数组升序排列。

你必须在 不使用任何内置函数 的情况下解决问题,时间复杂度为 O(nlog(n)),并且空间复杂度尽可能小。

示例 1:

复制代码
输入:nums = [5,2,3,1]
输出:[1,2,3,5]

示例 2:

复制代码
输入:nums = [5,1,1,2,0,0]
输出:[0,0,1,1,2,5]

提示:

  • 1 <= nums.length <= 5 * 104
  • -5 * 104 <= nums[i] <= 5 * 104

我们可以直接套用归并排序代码模板:

python 复制代码
class Solution:
    def sortArray(self, nums: List[int]) -> List[int]:
        Merge.sort(nums)
        return nums

class Merge:
    # 见上文

简化写法

python 复制代码
class Solution:
    def sortArray(self, nums: List[int]) -> List[int]:
        tmp_arr = [0] * len(nums)

        def sort(nums, lo, hi):
            if lo == hi:
                return

            mid = lo + (hi - lo) // 2
            sort(nums, lo, mid)
            sort(nums, mid + 1, hi)
            merge(nums, lo, mid, hi)

        def merge(nums, lo, mid, hi):
            for i in range(lo, hi + 1):
                tmp_arr[i] = nums[i]

            i, j = lo, mid + 1
            for p in range(lo, hi + 1):
                if i == mid + 1:
                    nums[p] = tmp_arr[j]
                    j += 1
                elif j == hi + 1:
                    nums[p] = tmp_arr[i]
                    i += 1
                elif tmp_arr[i] < tmp_arr[j]:
                    nums[p] = tmp_arr[i]
                    i += 1
                else:
                    nums[p] = tmp_arr[j]
                    j += 1

        sort(nums, 0, len(nums) - 1)
        return nums

315. 计算右侧小于当前元素的个数 | 力扣 | LeetCode | 🔴

给你一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

示例 1:

输入:nums = [5,2,6,1]
输出:[2,1,1,0] 
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素

示例 2:

复制代码
输入:nums = [-1]
输出:[0]

示例 3:

复制代码
输入:nums = [-1,-1]
输出:[0,0]

提示:

  • 1 <= nums.length <= 105
  • -104 <= nums[i] <= 104

我用比较数学的语言来描述一下(方便和后续类似题目进行对比),题目让你求出一个 count 数组,使得:

python 复制代码
count[i] = COUNT(j) where j > i and nums[j] < nums[i]

拍脑袋的暴力解法就不说了,嵌套 for 循环,平方级别的复杂度。

这题和归并排序什么关系呢,主要在 merge 函数,我们在使用 merge 函数合并两个有序数组的时候,其实是可以知道一个元素 nums[i] 后边有多少个元素比 nums[i] 小的

具体来说,比如这个场景:

这时候我们应该把 temp[i] 放到 nums[p] 上,因为 temp[i] < temp[j]

但就在这个场景下,我们还可以知道一个信息:5 后面比 5 小的元素个数就是 左闭右开区间 [mid + 1, j) 中的元素个数,即 2 和 4 这两个元素:

换句话说,在对 nums[lo..hi] 合并的过程中,每当执行 nums[p] = temp[i] 时,就可以确定 temp[i] 这个元素后面比它小的元素个数为 j - mid - 1

当然,nums[lo..hi] 本身也只是一个子数组,这个子数组之后还会被执行 merge,其中元素的位置还是会改变。但这是其他递归节点需要考虑的问题,我们只要在 merge 函数中做一些手脚,叠加每次 merge 时记录的结果即可。

发现了这个规律后,我们只要在 merge 中添加两行代码即可解决这个问题,看解法代码:

python 复制代码
class Solution:
    class Pair:
        def __init__(self, val, id):
            # 记录数组的元素值
            self.val = val
            # 记录元素在数组中的原始索引
            self.id = id

    # 归并排序所用的辅助数组
    temp = []
    # 记录每个元素后面比自己小的元素个数
    count = []

    # 主函数
    def countSmaller(self, nums):
        n = len(nums)
        self.count = [0] * n
        self.temp = [self.Pair(0, 0)] * n
        arr = []
        # 记录元素原始的索引位置,以便在 count 数组中更新结果
        for i in range(n):
            arr.append(self.Pair(nums[i], i))

        # 执行归并排序,本题结果被记录在 count 数组中
        self.sort(arr, 0, n - 1)

        return self.count

    # 归并排序
    def sort(self, arr, lo, hi):
        if lo == hi:
            return
        mid = lo + (hi - lo) // 2
        self.sort(arr, lo, mid)
        self.sort(arr, mid + 1, hi)
        self.merge(arr, lo, mid, hi)

    # 合并两个有序数组
    def merge(self, arr, lo, mid, hi):
        for i in range(lo, hi + 1):
            self.temp[i] = arr[i]

        i, j = lo, mid + 1
        for p in range(lo, hi + 1):
            if i == mid + 1:
                arr[p] = self.temp[j]
                j += 1
            elif j == hi + 1:
                arr[p] = self.temp[i]
                i += 1
                # 此时j来到hi+1处,说明右侧数组所有元素均小于arr[p]
                self.count[arr[p].id] += j - mid - 1
            elif self.temp[i].val > self.temp[j].val:
                arr[p] = self.temp[j]
                j += 1
            else:
                # 此时temp[mid+1, j)的的元素都小于temp[i]
                arr[p] = self.temp[i]
                i += 1
                self.count[arr[p].id] += j - mid - 1

因为在排序过程中,每个元素的索引位置会不断改变,所以我们用一个 Pair 类封装每个元素及其在原始数组 nums 中的索引,以便 count 数组记录每个元素之后小于它的元素个数。

493. 翻转对 | 力扣 | LeetCode | 🔴

给定一个数组 nums ,如果 i < jnums[i] > 2*nums[j] 我们就将 (i, j) 称作一个重要翻转对

你需要返回给定数组中的重要翻转对的数量。

示例 1:

复制代码
输入: [1,3,2,3,1]
输出: 2

示例 2:

复制代码
输入: [2,4,3,5,1]
输出: 3

注意:

  1. 给定数组的长度不会超过50000
  2. 输入数组中的所有数字都在32位整数的表示范围内。

我把这道题换个表述方式,你注意和上一道题目对比:

请你先求出一个 count 数组,其中:

python 复制代码
count[i] = COUNT(j) where j > i and nums[i] > 2*nums[j]

然后请你求出这个 count 数组中所有元素的和。

你看,这样说其实和题目是一个意思,而且和上一道题非常类似,只不过上一题求的是 nums[i] > nums[j],这里求的是 nums[i] > 2*nums[j] 罢了。

所以解题的思路当然还是要在 merge 函数中做点手脚,当 nums[lo..mid]nums[mid+1..hi] 两个子数组完成排序后,对于 nums[lo..mid] 中的每个元素 nums[i],去 nums[mid+1..hi] 中寻找符合条件的 nums[j] 就行了。

看一下我们对上一题 merge 函数的改造:

python 复制代码
# 记录「翻转对」的个数
count = 0

# 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
def merge(nums, lo, mid, hi):
    global count
    temp = nums.copy()

    # 在合并有序数组之前,加点私货
    for i in range(lo, mid + 1):
        # 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
        for j in range(mid + 1, hi + 1):
            # nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
            if int(nums[i]) > 2 * int(nums[j]):
                count += 1

    # 数组双指针技巧,合并两个有序数组
    i, j = lo, mid + 1
    for p in range(lo, hi + 1):
        if (i == mid + 1):
            nums[p] = temp[j]
            j += 1
        elif (j == hi + 1):
            nums[p] = temp[i]
            i += 1
        elif temp[i] > temp[j]:
            nums[p] = temp[j]
            j += 1
        else:
            nums[p] = temp[i]
            i += 1

不过呢,这样修改代码会超时,毕竟额外添加了一个嵌套 for 循环。怎么进行优化呢,注意子数组 nums[lo..mid] 是排好序的,也就是 nums[i] <= nums[i+1]

所以,对于 nums[i], lo <= i <= mid,我们在找到的符合 nums[i] > 2*nums[j]nums[j], mid+1 <= j <= hi,也必然也符合 nums[i+1] > 2*nums[j]

换句话说,我们不用每次都傻乎乎地去遍历整个 nums[mid+1..hi],只要维护一个开区间边界 end,维护 nums[mid+1..end-1] 是符合条件的元素即可

看最终的解法代码:

python 复制代码
class Solution:
    def __init__(self):
        # 记录「翻转对」的个数
        self.count = 0
        self.temp = []

    def reversePairs(self, nums) -> int:
        # 执行归并排序
        self.temp = [0]*len(nums)
        self.sort(nums, 0, len(nums) - 1)
        return self.count

    def sort(self, nums, lo, hi):
        # 归并排序
        if lo >= hi:
            return

        mid = lo + (hi - lo) // 2
        self.sort(nums, lo, mid)
        self.sort(nums, mid + 1, hi)
        self.merge(nums, lo, mid, hi)

    def merge(self, nums, lo, mid, hi):
        for i in range(lo, hi+1):
            self.temp[i] = nums[i]
        
        # 维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
        # 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
        end = mid + 1
        for i in range(lo, mid+1):
            # nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
            while end <= hi and nums[i] > nums[end] * 2:
                end += 1

            self.count += end - (mid + 1)

        # 数组双指针技巧,合并两个有序数组
        i, j = lo, mid + 1
        for p in range(lo, hi+1):
            if i == mid + 1:
                nums[p] = self.temp[j]
                j += 1
            elif j == hi + 1:
                nums[p] = self.temp[i]
                i += 1
            elif self.temp[i] > self.temp[j]:
                nums[p] = self.temp[j]
                j += 1
            else:
                nums[p] = self.temp[i]
                i += 1

327. 区间和的个数 | 力扣 | LeetCode | 🔴

给你一个整数数组 nums 以及两个整数 lowerupper 。求数组中,值位于范围 [lower, upper] (包含 lowerupper)之内的 区间和的个数

区间和 S(i, j) 表示在 nums 中,位置从 ij 的元素之和,包含 ij (ij)。

示例 1:

复制代码
输入:nums = [-2,5,-1], lower = -2, upper = 2
输出:3
解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。

示例 2:

复制代码
输入:nums = [0], lower = 0, upper = 0
输出:1

提示:

  • 1 <= nums.length <= 105
  • -231 <= nums[i] <= 231 - 1
  • -105 <= lower <= upper <= 105
  • 题目数据保证答案是一个 32 位 的整数

简单说,题目让你计算元素和落在 [lower, upper] 中的所有子数组的个数。

拍脑袋的暴力解法我就不说了,依然是嵌套 for 循环,这里还是说利用归并排序实现的高效算法。

首先,解决这道题需要快速计算子数组的和,所以你需要阅读前文 前缀和数组技巧,创建一个前缀和数组 preSum 来辅助我们迅速计算区间和。

我继续用比较数学的语言来表述下这道题,题目让你通过 preSum 数组求一个 count 数组,使得:

python 复制代码
count[i] = COUNT(j) where lower <= preSum[j] - preSum[i] <= upper

然后请你求出这个 count 数组中所有元素的和。

你看,这是不是和题目描述一样?preSum 中的两个元素之差其实就是区间和。

有了之前两道题的铺垫,我直接给出这道题的解法代码吧,思路见注释:

python 复制代码
class Solution:
    def __init__(self):
        # 定义实例变量
        self.lower = 0
        self.upper = 0
        self.count = 0
        self.temp = []

    def countRangeSum(self, nums, lower, upper):
        # 赋值实例变量
        self.lower = lower
        self.upper = upper
        
        # 构建前缀和数组,注意 int 可能溢出,用 long 存储
        preSum = [0] * (len(nums) + 1)
        for i in range(len(nums)):
            preSum[i + 1] = nums[i] + preSum[i]
            
        # 对前缀和数组进行归并排序
        self.temp = [0] * len(preSum)
        self.sort(preSum, 0, len(preSum) - 1)
        return self.count
        
    def sort(self, nums, lo, hi):
        if lo == hi:
            return
        mid = lo + (hi - lo) // 2
        self.sort(nums, lo, mid)
        self.sort(nums, mid + 1, hi)
        self.merge(nums, lo, mid, hi)

    def merge(self, nums, lo, mid, hi):
        # 填充 temp 数组
        for i in range(lo, hi + 1):
            self.temp[i] = nums[i]
        
        # 在合并有序数组之前加点私货(这段代码会超时)
        # for i in range(lo, mid + 1):
        #     for j in range(mid + 1, hi + 1):
        # 寻找符合条件的 nums[j]
        #
        #         delta = nums[j] - nums[i]
        #         if delta <= self.upper and delta >= self.lower:
        #             self.count += 1

        # 进行效率优化
        # 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
        start = mid + 1
        end = mid + 1
        for i in range(lo, mid + 1):
            # 如果 nums[i] 对应的区间是 [start, end),
            # 那么 nums[i+1] 对应的区间一定会整体右移,类似滑动窗口
            while start <= hi and nums[start] - nums[i] < self.lower:
                start += 1
            while end <= hi and nums[end] - nums[i] <= self.upper:
                end += 1
            self.count += end - start

        # 数组双指针技巧,合并两个有序数组
        i, j = lo, mid + 1
        for p in range(lo, hi + 1):
            if i == mid + 1:
                nums[p] = self.temp[j]
                j += 1
            elif j == hi + 1:
                nums[p] = self.temp[i]
                i += 1
            elif self.temp[i] > self.temp[j]:
                nums[p] = self.temp[j]
                j += 1
            else:
                nums[p] = self.temp[i]
                i += 1

我们依然在 merge 函数合并有序数组之前加了一些逻辑,如果看过前文 滑动窗口核心框架,这个效率优化有点类似维护一个滑动窗口,让窗口中的元素和 nums[i] 的差落在 [lower, upper] 中。

归并排序相关的题目到这里就讲完了,你现在回头体会下我在本文开头说那句话:

所有递归的算法,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码。你要写递归算法,本质上就是要告诉每个节点需要做什么

比如本文讲的归并排序算法,递归的 sort 函数就是二叉树的遍历函数,而 merge 函数就是在每个节点上做的事情,有没有品出点味道?

相关推荐
某个默默无闻奋斗的人39 分钟前
三傻排序的比较(选择,冒泡,插入)
java·数据结构·算法
黄雪超2 小时前
算法基础——一致性
大数据·算法·一致性
敲上瘾2 小时前
BFS(广度优先搜索)——搜索算法
数据结构·c++·算法·搜索引擎·宽度优先·图搜索算法
查理零世3 小时前
【算法】回溯算法专题② ——组合型回溯 + 剪枝 python
python·算法·剪枝
酒酿泡芙12173 小时前
前端力扣刷题 | 6:hot100之 矩阵
前端·leetcode·矩阵
JackieZhang.3 小时前
求水仙花数,提取算好,打表法。或者暴力解出来。
数据结构·算法
LUCIAZZZ4 小时前
Hot100之贪心算法
数据结构·算法·leetcode·贪心算法
h^hh4 小时前
堆的模拟实现(详解)c++
数据结构·c++·算法
tt5555555555554 小时前
每日一题——小根堆实现堆排序算法
c语言·数据结构·算法·面试·排序算法·八股文