"求数组中第 k 大的元素"是一个非常经典的问题。
最直接的做法是先排序,再取排序后的第 k 个元素。但排序的时间复杂度是 O(n log n)。能不能做到 O(n)?
问题定义
给定一个长度为 n 的数组 nums,求其中第 k 大的元素。
例如:
python
nums = [3, 2, 1, 5, 6, 4]
k = 2
第 2 大的元素是 5。
明确指标,我们这里说的是"第 k 大元素",不是"第 k 个不同的元素"。重复元素也要参与排名。
最常用的方法:Quickselect
这是工程和面试里最常见的做法。思想来自快速排序。
快速排序会选一个基准值 pivot,然后把数组划分成两部分,再递归处理左右两边。而 Quickselect 不一样,它只关心第 k 大在哪一边,因此 每次只递归进入一侧。
假设一次划分后,基准值左边有一部分更大的数,右边有一部分更小的数:
- 如果第
k大就在左边,那就只去左边找 - 如果刚好等于基准值,那直接返回
- 否则去右边找
每一轮的 partition 是 O(n),但不像快排那样两边都递归,而是只处理一边,所以平均复杂度是:
text
O(n) + O(n/2) + O(n/4) + ... = O(n)
为了更好处理有大量重复元素的情况,实现上推荐使用三路划分:
- 大于
pivot - 等于
pivot - 小于
pivot
python
import random
def kth_largest(nums, k):
def select(l, r, k_idx):
pivot = nums[random.randint(l, r)]
i, j, t = l, l, r
# 三路划分:
# [l, i-1] > pivot
# [i, j-1] == pivot
# [j, t] 未处理
# [t+1, r] < pivot
while j <= t:
if nums[j] > pivot:
nums[i], nums[j] = nums[j], nums[i]
i += 1
j += 1
elif nums[j] < pivot:
nums[j], nums[t] = nums[t], nums[j]
t -= 1
else:
j += 1
if k_idx < i:
return select(l, i - 1, k_idx)
elif k_idx <= t:
return pivot
else:
return select(t + 1, r, k_idx)
return select(0, len(nums) - 1, k - 1)
- 平均时间复杂度:
O(n) - 最坏时间复杂度:
O(n^2)(可能和快速排序一样在最倒霉的情况下退化)
理论上最强的方法:BFPRT
BFPRT 也叫 Median of Medians,中文常叫"中位数的中位数算法"。最坏时间复杂度也能保证是 O(n)
Quickselect 的问题在于:如果基准值选得很差,就可能退化。BFPRT 的思路是:不要随便选 pivot,而是精心挑一个"足够好"的 pivot。
流程:
- 把数组每 5 个元素分成一组
- 对每组排序,取出每组中位数
- 递归地求这些中位数的中位数
- 用这个"中位数的中位数"作为 pivot
这个 pivot 可以保证不会太差,于是每一轮都能排除掉相当一部分元素,最终把最坏复杂度压到 O(n)。
为什么分 5 个一组?可以证明:每轮都能稳定丢掉一个固定比例的元素 。这样递推式就会收敛到 O(n)。当然,分 3 个、7 个也可以讨论,但 5 是一个惯例选择。
最坏时间复杂度和平均时间复杂度都是O(n)
但是实际上,这算是在 quickselect 基础上增加程序避免落入最坏情况。实现复杂,常数较大,实际工程中随机数据通常跑不过随机化 Quickselect
值域小时的线性解法:计数 / 桶统计
如果数组中的数值范围不大,还有一种简单直接的线性做法:计数排序思想 。开一个计数数组 cnt表示某值出现了多少次,然后从大到小扫描。当累计个数达到 k 时,对应值就是第 k 大
python
def kth_largest_counting(nums, k):
mx = max(nums)
mn = min(nums)
offset = -mn
cnt = [0] * (mx - mn + 1)
for x in nums:
cnt[x + offset] += 1
s = 0
for i in range(len(cnt) - 1, -1, -1):
s += cnt[i]
if s >= k:
return i - offset
时间复杂度:O(n + U)。空间复杂度:O(U)。其中 U 是值域大小。只有在 U 不大时,这种方法才真正有优势。如果值域非常大,比如元素分布在 [-10^9, 10^9],那这个方法就不现实了。