本文参考labuladong算法笔记[拓展:归并排序详解及应用 | labuladong 的算法笔记]
"归并排序就是二叉树的后序遍历"------labuladong
就说归并排序吧,如果给你看代码,让你脑补一下归并排序的过程,你脑子里会出现什么场景?
这是一个数组排序算法,所以你脑补一个数组的 GIF,在那一个个交换元素?如果是这样的话,那格局就低了。
但如果你脑海中浮现出的是一棵二叉树,甚至浮现出二叉树后序遍历的场景,那格局就高了,大概率掌握了我经常强调的 框架思维,用这种抽象能力学习算法就省劲多了。
那么,归并排序明明就是一个数组算法,和二叉树有什么关系?接下来我就具体讲讲。
1、算法思路
就这么说吧,所有递归的算法,你甭管它是干什么的,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码,你要写递归算法,本质上就是要告诉每个节点需要做什么。
你看归并排序的代码框架:
python
# 定义:排序 nums[lo..hi]
def sort(nums: List[int], lo: int, hi: int) -> None:
if lo == hi:
return
mid = (lo + hi) // 2
# 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid)
# 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi)
# ***** 后序位置 *****
# 此时两部分子数组已经被排好序
# 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi)
# 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
# 合并为有序数组 nums[lo..hi]
def merge(nums: List[int], lo: int, mid: int, hi: int) -> None:
pass
看这个框架,也就明白那句经典的总结:归并排序就是先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。
上述代码和二叉树的后序遍历很像:
python
# 二叉树遍历框架
def traverse(root: TreeNode) -> None:
if root is None:
return
traverse(root.left)
traverse(root.right)
# 后序位置
print(root.val)
再进一步,你联想一下求二叉树的最大深度的算法代码:
python
# 定义:输入根节点,返回这棵二叉树的最大深度
def maxDepth(root: TreeNode) -> int:
if not root:
return 0
# 利用定义,计算左右子树的最大深度
leftMax = maxDepth(root.left)
rightMax = maxDepth(root.right)
# 整棵树的最大深度等于左右子树的最大深度取最大值,
# 然后再加上根节点自己
res = max(leftMax, rightMax) + 1
return res
本质上就是二叉树的后序遍历------先拿到子树,再合并子树
前文 手把手刷二叉树(纲领篇) 说二叉树问题可以分为两类思路,一类是遍历一遍二叉树的思路,另一类是分解问题的思路,根据上述类比,显然归并排序利用的是分解问题的思路(分治算法)。
归并排序的过程可以在逻辑上抽象成一棵二叉树,树上的每个节点的值可以认为是 nums[lo..hi]
,叶子节点的值就是数组中的单个元素:
然后,在每个节点的后序位置(左右子节点已经被排好序)的时候执行 merge
函数,合并两个子节点上的子数组:
这个 merge
操作会在二叉树的每个节点上都执行一遍,执行顺序是二叉树后序遍历的顺序。
后序遍历二叉树大家应该已经烂熟于心了,就是下图这个遍历顺序:
结合上述基本分析,我们把 nums[lo..hi]
理解成二叉树的节点,sort
函数理解成二叉树的遍历函数,整个归并排序的执行过程就是以下 GIF 描述的这样:
这样,归并排序的核心思路就分析完了,接下来只要把思路翻译成代码就行。
2、代码实现
只要拥有了正确的思维方式,理解算法思路是不困难的,但把思路实现成代码,也很考验一个人的编程能力。
毕竟算法的时间复杂度只是一个理论上的衡量标准,而算法的实际运行效率要考虑的因素更多,比如应该避免内存的频繁分配释放,代码逻辑应尽可能简洁等等。
python
class Merge:
# 用于辅助合并有序数组
temp = []
@staticmethod
def sort(nums):
# 先给辅助数组开辟内存空间
Merge.temp = [0] * len(nums)
# 排序整个数组(原地修改)
Merge._sort(nums, 0, len(nums) - 1)
# 定义:将子数组 nums[lo..hi] 进行排序
@staticmethod
def _sort(nums, lo, hi):
if lo == hi:
# 单个元素不用排序
return
# 这样写是为了防止溢出,效果等同于 (hi + lo) / 2
mid = lo + (hi - lo) // 2
# 先对左半部分数组 nums[lo..mid] 排序
Merge._sort(nums, lo, mid)
# 再对右半部分数组 nums[mid+1..hi] 排序
Merge._sort(nums, mid + 1, hi)
# 将两部分有序数组合并成一个有序数组
Merge.merge(nums, lo, mid, hi)
# 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
@staticmethod
def merge(nums, lo, mid, hi):
# 先把 nums[lo..hi] 复制到辅助数组中
# 以便合并后的结果能够直接存入 nums
for i in range(lo, hi+1):
Merge.temp[i] = nums[i]
# 数组双指针技巧,合并两个有序数组
i, j = lo, mid + 1
for p in range(lo, hi+1):
if i == mid + 1:
# 左半边数组已全部被合并
nums[p] = Merge.temp[j]
j += 1
elif j == hi + 1:
# 右半边数组已全部被合并
nums[p] = Merge.temp[i]
i += 1
elif Merge.temp[i] > Merge.temp[j]:
nums[p] = Merge.temp[j]
j += 1
else:
nums[p] = Merge.temp[i]
i += 1
有了之前的铺垫,这里只需要着重讲一下这个 merge
函数。
sort
函数对 nums[lo..mid]
和 nums[mid+1..hi]
递归排序完成之后,我们没有办法原地把它俩合并,所以需要 copy 到 temp
数组里面,然后通过类似于前文 单链表的六大技巧 中合并有序链表的双指针技巧将 nums[lo..hi]
合并成一个有序数组:
注意我们不是在 merge
函数执行的时候 new 辅助数组,而是提前把 temp 辅助数组 new 出来了,这样就避免了在递归中频繁分配和释放内存可能产生的性能问题。
3、复杂度分析
再说一下归并排序的时间复杂度,虽然大伙儿应该都知道是O(NlogN),但不见得所有人都知道这个复杂度怎么算出来的。
前文 动态规划详解 说过递归算法的复杂度计算,就是子问题个数 x 解决一个子问题的复杂度。对于归并排序来说,时间复杂度显然集中在 merge
函数遍历 nums[lo..hi]
的过程,但每次 merge
输入的 lo
和 hi
都不同,所以不容易直观地看出时间复杂度。
merge
函数到底执行了多少次?每次执行的时间复杂度是多少?总的时间复杂度是多少?这就要结合之前画的这幅图来看:
执行的次数是二叉树节点的个数,每次执行的复杂度就是每个节点代表的子数组的长度,所以总的时间复杂度就是整棵树中「数组元素」的个数。
所以从整体上看,这个二叉树的高度是 logN
,其中每一层的元素个数就是原数组的长度 N
,所以总的时间复杂度就是 O(NlogN)O(NlogN)。
4、力扣练习
912. 排序数组
给你一个整数数组 nums
,请你将该数组升序排列。
你必须在 不使用任何内置函数 的情况下解决问题,时间复杂度为 O(nlog(n))
,并且空间复杂度尽可能小。
示例 1:
输入:nums = [5,2,3,1]
输出:[1,2,3,5]
示例 2:
输入:nums = [5,1,1,2,0,0]
输出:[0,0,1,1,2,5]
提示:
1 <= nums.length <= 5 * 104
-5 * 104 <= nums[i] <= 5 * 104
我们可以直接套用归并排序代码模板:
python
class Solution:
def sortArray(self, nums: List[int]) -> List[int]:
Merge.sort(nums)
return nums
class Merge:
# 见上文
简化写法
python
class Solution:
def sortArray(self, nums: List[int]) -> List[int]:
tmp_arr = [0] * len(nums)
def sort(nums, lo, hi):
if lo == hi:
return
mid = lo + (hi - lo) // 2
sort(nums, lo, mid)
sort(nums, mid + 1, hi)
merge(nums, lo, mid, hi)
def merge(nums, lo, mid, hi):
for i in range(lo, hi + 1):
tmp_arr[i] = nums[i]
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if i == mid + 1:
nums[p] = tmp_arr[j]
j += 1
elif j == hi + 1:
nums[p] = tmp_arr[i]
i += 1
elif tmp_arr[i] < tmp_arr[j]:
nums[p] = tmp_arr[i]
i += 1
else:
nums[p] = tmp_arr[j]
j += 1
sort(nums, 0, len(nums) - 1)
return nums
315. 计算右侧小于当前元素的个数 | 力扣 | LeetCode | 🔴
给你一个整数数组 nums
,按要求返回一个新数组 counts
。数组 counts
有该性质: counts[i]
的值是 nums[i]
右侧小于 nums[i]
的元素的数量。
示例 1:
输入:nums = [5,2,6,1]
输出:[2,1,1,0]
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素
示例 2:
输入:nums = [-1]
输出:[0]
示例 3:
输入:nums = [-1,-1]
输出:[0,0]
提示:
1 <= nums.length <= 105
-104 <= nums[i] <= 104
我用比较数学的语言来描述一下(方便和后续类似题目进行对比),题目让你求出一个 count
数组,使得:
python
count[i] = COUNT(j) where j > i and nums[j] < nums[i]
拍脑袋的暴力解法就不说了,嵌套 for 循环,平方级别的复杂度。
这题和归并排序什么关系呢,主要在 merge
函数,我们在使用 merge
函数合并两个有序数组的时候,其实是可以知道一个元素 nums[i]
后边有多少个元素比 nums[i]
小的。
具体来说,比如这个场景:
这时候我们应该把 temp[i]
放到 nums[p]
上,因为 temp[i] < temp[j]
。
但就在这个场景下,我们还可以知道一个信息:5 后面比 5 小的元素个数就是 左闭右开区间 [mid + 1, j) 中的元素个数,即 2 和 4 这两个元素:
换句话说,在对 nums[lo..hi]
合并的过程中,每当执行 nums[p] = temp[i]
时,就可以确定 temp[i]
这个元素后面比它小的元素个数为 j - mid - 1
。
当然,nums[lo..hi]
本身也只是一个子数组,这个子数组之后还会被执行 merge
,其中元素的位置还是会改变。但这是其他递归节点需要考虑的问题,我们只要在 merge
函数中做一些手脚,叠加每次 merge
时记录的结果即可。
发现了这个规律后,我们只要在 merge
中添加两行代码即可解决这个问题,看解法代码:
python
class Solution:
class Pair:
def __init__(self, val, id):
# 记录数组的元素值
self.val = val
# 记录元素在数组中的原始索引
self.id = id
# 归并排序所用的辅助数组
temp = []
# 记录每个元素后面比自己小的元素个数
count = []
# 主函数
def countSmaller(self, nums):
n = len(nums)
self.count = [0] * n
self.temp = [self.Pair(0, 0)] * n
arr = []
# 记录元素原始的索引位置,以便在 count 数组中更新结果
for i in range(n):
arr.append(self.Pair(nums[i], i))
# 执行归并排序,本题结果被记录在 count 数组中
self.sort(arr, 0, n - 1)
return self.count
# 归并排序
def sort(self, arr, lo, hi):
if lo == hi:
return
mid = lo + (hi - lo) // 2
self.sort(arr, lo, mid)
self.sort(arr, mid + 1, hi)
self.merge(arr, lo, mid, hi)
# 合并两个有序数组
def merge(self, arr, lo, mid, hi):
for i in range(lo, hi + 1):
self.temp[i] = arr[i]
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if i == mid + 1:
arr[p] = self.temp[j]
j += 1
elif j == hi + 1:
arr[p] = self.temp[i]
i += 1
# 此时j来到hi+1处,说明右侧数组所有元素均小于arr[p]
self.count[arr[p].id] += j - mid - 1
elif self.temp[i].val > self.temp[j].val:
arr[p] = self.temp[j]
j += 1
else:
# 此时temp[mid+1, j)的的元素都小于temp[i]
arr[p] = self.temp[i]
i += 1
self.count[arr[p].id] += j - mid - 1
因为在排序过程中,每个元素的索引位置会不断改变,所以我们用一个 Pair 类封装每个元素及其在原始数组 nums 中的索引,以便 count 数组记录每个元素之后小于它的元素个数。
493. 翻转对 | 力扣 | LeetCode | 🔴
给定一个数组 nums
,如果 i < j
且 nums[i] > 2*nums[j]
我们就将 (i, j)
称作一个重要翻转对。
你需要返回给定数组中的重要翻转对的数量。
示例 1:
输入: [1,3,2,3,1]
输出: 2
示例 2:
输入: [2,4,3,5,1]
输出: 3
注意:
- 给定数组的长度不会超过
50000
。 - 输入数组中的所有数字都在32位整数的表示范围内。
我把这道题换个表述方式,你注意和上一道题目对比:
请你先求出一个 count
数组,其中:
python
count[i] = COUNT(j) where j > i and nums[i] > 2*nums[j]
然后请你求出这个 count
数组中所有元素的和。
你看,这样说其实和题目是一个意思,而且和上一道题非常类似,只不过上一题求的是 nums[i] > nums[j]
,这里求的是 nums[i] > 2*nums[j]
罢了。
所以解题的思路当然还是要在 merge
函数中做点手脚,当 nums[lo..mid]
和 nums[mid+1..hi]
两个子数组完成排序后,对于 nums[lo..mid]
中的每个元素 nums[i]
,去 nums[mid+1..hi]
中寻找符合条件的 nums[j]
就行了。
看一下我们对上一题 merge
函数的改造:
python
# 记录「翻转对」的个数
count = 0
# 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
def merge(nums, lo, mid, hi):
global count
temp = nums.copy()
# 在合并有序数组之前,加点私货
for i in range(lo, mid + 1):
# 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
for j in range(mid + 1, hi + 1):
# nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
if int(nums[i]) > 2 * int(nums[j]):
count += 1
# 数组双指针技巧,合并两个有序数组
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if (i == mid + 1):
nums[p] = temp[j]
j += 1
elif (j == hi + 1):
nums[p] = temp[i]
i += 1
elif temp[i] > temp[j]:
nums[p] = temp[j]
j += 1
else:
nums[p] = temp[i]
i += 1
不过呢,这样修改代码会超时,毕竟额外添加了一个嵌套 for 循环。怎么进行优化呢,注意子数组 nums[lo..mid]
是排好序的,也就是 nums[i] <= nums[i+1]
。
所以,对于 nums[i], lo <= i <= mid
,我们在找到的符合 nums[i] > 2*nums[j]
的 nums[j], mid+1 <= j <= hi
,也必然也符合 nums[i+1] > 2*nums[j]
。
换句话说,我们不用每次都傻乎乎地去遍历整个 nums[mid+1..hi]
,只要维护一个开区间边界 end
,维护 nums[mid+1..end-1]
是符合条件的元素即可。
看最终的解法代码:
python
class Solution:
def __init__(self):
# 记录「翻转对」的个数
self.count = 0
self.temp = []
def reversePairs(self, nums) -> int:
# 执行归并排序
self.temp = [0]*len(nums)
self.sort(nums, 0, len(nums) - 1)
return self.count
def sort(self, nums, lo, hi):
# 归并排序
if lo >= hi:
return
mid = lo + (hi - lo) // 2
self.sort(nums, lo, mid)
self.sort(nums, mid + 1, hi)
self.merge(nums, lo, mid, hi)
def merge(self, nums, lo, mid, hi):
for i in range(lo, hi+1):
self.temp[i] = nums[i]
# 维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
# 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
end = mid + 1
for i in range(lo, mid+1):
# nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
while end <= hi and nums[i] > nums[end] * 2:
end += 1
self.count += end - (mid + 1)
# 数组双指针技巧,合并两个有序数组
i, j = lo, mid + 1
for p in range(lo, hi+1):
if i == mid + 1:
nums[p] = self.temp[j]
j += 1
elif j == hi + 1:
nums[p] = self.temp[i]
i += 1
elif self.temp[i] > self.temp[j]:
nums[p] = self.temp[j]
j += 1
else:
nums[p] = self.temp[i]
i += 1
327. 区间和的个数 | 力扣 | LeetCode | 🔴
给你一个整数数组 nums
以及两个整数 lower
和 upper
。求数组中,值位于范围 [lower, upper]
(包含 lower
和 upper
)之内的 区间和的个数 。
区间和 S(i, j)
表示在 nums
中,位置从 i
到 j
的元素之和,包含 i
和 j
(i
≤ j
)。
示例 1:
输入:nums = [-2,5,-1], lower = -2, upper = 2
输出:3
解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。
示例 2:
输入:nums = [0], lower = 0, upper = 0
输出:1
提示:
1 <= nums.length <= 105
-231 <= nums[i] <= 231 - 1
-105 <= lower <= upper <= 105
- 题目数据保证答案是一个 32 位 的整数
简单说,题目让你计算元素和落在 [lower, upper]
中的所有子数组的个数。
拍脑袋的暴力解法我就不说了,依然是嵌套 for 循环,这里还是说利用归并排序实现的高效算法。
首先,解决这道题需要快速计算子数组的和,所以你需要阅读前文 前缀和数组技巧,创建一个前缀和数组 preSum
来辅助我们迅速计算区间和。
我继续用比较数学的语言来表述下这道题,题目让你通过 preSum
数组求一个 count
数组,使得:
python
count[i] = COUNT(j) where lower <= preSum[j] - preSum[i] <= upper
然后请你求出这个 count
数组中所有元素的和。
你看,这是不是和题目描述一样?preSum
中的两个元素之差其实就是区间和。
有了之前两道题的铺垫,我直接给出这道题的解法代码吧,思路见注释:
python
class Solution:
def __init__(self):
# 定义实例变量
self.lower = 0
self.upper = 0
self.count = 0
self.temp = []
def countRangeSum(self, nums, lower, upper):
# 赋值实例变量
self.lower = lower
self.upper = upper
# 构建前缀和数组,注意 int 可能溢出,用 long 存储
preSum = [0] * (len(nums) + 1)
for i in range(len(nums)):
preSum[i + 1] = nums[i] + preSum[i]
# 对前缀和数组进行归并排序
self.temp = [0] * len(preSum)
self.sort(preSum, 0, len(preSum) - 1)
return self.count
def sort(self, nums, lo, hi):
if lo == hi:
return
mid = lo + (hi - lo) // 2
self.sort(nums, lo, mid)
self.sort(nums, mid + 1, hi)
self.merge(nums, lo, mid, hi)
def merge(self, nums, lo, mid, hi):
# 填充 temp 数组
for i in range(lo, hi + 1):
self.temp[i] = nums[i]
# 在合并有序数组之前加点私货(这段代码会超时)
# for i in range(lo, mid + 1):
# for j in range(mid + 1, hi + 1):
# 寻找符合条件的 nums[j]
#
# delta = nums[j] - nums[i]
# if delta <= self.upper and delta >= self.lower:
# self.count += 1
# 进行效率优化
# 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
start = mid + 1
end = mid + 1
for i in range(lo, mid + 1):
# 如果 nums[i] 对应的区间是 [start, end),
# 那么 nums[i+1] 对应的区间一定会整体右移,类似滑动窗口
while start <= hi and nums[start] - nums[i] < self.lower:
start += 1
while end <= hi and nums[end] - nums[i] <= self.upper:
end += 1
self.count += end - start
# 数组双指针技巧,合并两个有序数组
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if i == mid + 1:
nums[p] = self.temp[j]
j += 1
elif j == hi + 1:
nums[p] = self.temp[i]
i += 1
elif self.temp[i] > self.temp[j]:
nums[p] = self.temp[j]
j += 1
else:
nums[p] = self.temp[i]
i += 1
我们依然在 merge
函数合并有序数组之前加了一些逻辑,如果看过前文 滑动窗口核心框架,这个效率优化有点类似维护一个滑动窗口,让窗口中的元素和 nums[i]
的差落在 [lower, upper]
中。
归并排序相关的题目到这里就讲完了,你现在回头体会下我在本文开头说那句话:
所有递归的算法,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码。你要写递归算法,本质上就是要告诉每个节点需要做什么。
比如本文讲的归并排序算法,递归的 sort
函数就是二叉树的遍历函数,而 merge
函数就是在每个节点上做的事情,有没有品出点味道?