最小的 k 个数
面试题 17.14. 最小K个数
设计一个算法,找出数组中 最小的 k 个数 。以 任意顺序 返回这 k 个数均可。
示例:
输入: arr = [1,3,5,7,2,4,6,8], k = 4
输出: [1,2,3,4]
提示:
0 <= len(arr) <= 1000000 <= k <= min(100000, len(arr))
快速选择算法(分治)
利用快速排序恰好有 「左侧元素」 小于等于 「基准值」,「右侧元素」 大于 「基准值」 的特性,我们可以基于快速排序算法划分数组,也成为快速选择。
注意任意顺序是核心,否则不能使用快速选择算法。
cpp
// 注意到题目要求「任意顺序返回这 k 个数即可」,因此我们只需要确保前 k 小的数都出现在下标为 [0,k) 的位置即可。
// 而快速排序恰好有 「左侧元素」 小于等于 「基准值」,「右侧元素」 大于 「基准值」 的特性,因此我们可以利用快速排序划分数组
// 在面试中,如果面试官追加一句:"如果是 100 亿个数据,内存放不下怎么办?"这时候就必须用堆了。
// 我们只要维护一个容量为 k 的大根堆:新来的元素如果比堆顶小,就把堆顶踢掉,自己进去。时间复杂度是 O(Nlogk)。
class Solution {
// void quick_select(vector<int> &arr, int target_index, int l, int r) {
// if(l >= r) return ;
// int pivot = arr[(l + r) / 2];
// int i = l - 1, j = r + 1;
// while(i < j) {
// do i ++ ; while(arr[i] < pivot);
// do j -- ; while(arr[j] > pivot);
// if(i < j) swap(arr[i], arr[j]);
// }
// if(j > target_index) quick_select(arr, target_index, l, j);
// if(j + 1 <= target_index) quick_select(arr, target_index, j + 1, r);
// }
void quick_select(vector<int> &arr, int target_index, int l, int r) {
if(l >= r) return ;
int pivot = arr[(l + r + 1) / 2];
int i = l - 1, j = r + 1;
while(i < j) {
do i ++ ; while(arr[i] < pivot);
do j -- ; while(arr[j] > pivot);
if(i < j) swap(arr[i], arr[j]);
}
if(i - 1 > target_index) quick_select(arr, target_index, l, i - 1);
if(i <= target_index) quick_select(arr, target_index, i, r);
}
// 注意如果我们使用随机数生成 pivot,最好使用 j 作为边界而不是 i
public:
vector<int> smallestK(vector<int>& arr, int k) {
if(k == 0) return {};
if(k >= arr.size()) return arr;
quick_select(arr, k - 1, 0, arr.size() - 1);
return vector<int>(arr.begin(), arr.begin() + k);
}
};
时间复杂度 :平均 O ( N ) O(N) O(N)。因为每次划分大概能排除掉一半的数据,计算量递减( N + N / 2 + N / 4 + ⋯ ≈ 2 N N + N/2 + N/4 + \dots \approx 2N N+N/2+N/4+⋯≈2N)。最坏情况是数组已经有序,每次只排除一个元素,会退化到 O ( N 2 ) O(N^2) O(N2)(实际工程中通常会引入随机挑选基准来避免这种极端情况)。
空间复杂度 :平均 O ( log N ) O(\log N) O(logN)。主要开销来自递归调用的函数栈深度。最坏情况下(时间退化的同时),递归树变成一条直线,空间复杂度退化为 O ( N ) O(N) O(N)。
什么时候用堆排序
在面试中,如果面试官追加一句:"如果是 100 亿个数据,内存放不下怎么办?"这时候就必须用堆了。
我们只要维护一个容量为 k k k 的大根堆:新来的元素如果比堆顶小,就把堆顶踢掉,自己进去。时间复杂度是 O ( N log k ) O(N \log k) O(Nlogk)。
第 k 个数
215. 数组中的第K个最大元素
给定整数数组 nums 和整数 k,请返回数组中第 **k** 个最大的元素。
请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。
你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。
示例 1:
输入: [3,2,1,5,6,4], k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6], k = 4
输出: 4
提示:
1 <= k <= nums.length <= 105-104 <= nums[i] <= 104
快速选择算法(分治)
cpp
// 快速选择算法 O(n) + O(n)
int findKthLargest(vector<int>& nums, int k) {
// 由于我们后面需要 % nums.size()
// 因此这里要确保 nums.size() 不为 0
if(!nums.size()) return 0;
vector<int> small, equal, big;
int pivot = nums[rand() % nums.size()];
for(auto &x : nums) {
if(x == pivot) equal.push_back(x);
else if(x > pivot) big.push_back(x);
else small.push_back(x);
}
if(k <= big.size()) return findKthLargest(big, k);
if(k <= big.size() + equal.size()) return pivot;
return findKthLargest(small, k - (big.size() + equal.size()));
}
计数排序
cpp
// 计数排序 O(n) + O(n)
int findKthLargest_1(vector<int>& nums, int k) {
k = nums.size() - k + 1;
int l = INT_MAX, r = INT_MIN;
unordered_map<int,int> cnt;
for(auto &x : nums) {
++ cnt[x];
l = min(l, x);
r = max(r, x);
}
for(int i = l; i <= r; i ++ ) {
if((k -= cnt[i]) <= 0) return i;
}
// never go there
return INT_MAX;
}
第 k 个数(拓展)
4. 寻找两个正序数组的中位数
给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。
算法的时间复杂度应该为 O(log (m+n)) 。
示例 1:
输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2
示例 2:
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5
提示:
nums1.length == mnums2.length == n0 <= m <= 10000 <= n <= 10001 <= m + n <= 2000-106 <= nums1[i], nums2[i] <= 106
分治 + 递归
这道题的意思其实很简单,就是让我们求两个数组的中位数。我们假设我们已经将两个数组排序成一个长度为 n+m 的数组,那么:
- 当
n+m为奇数时:求第 n 2 \frac{n}{2} 2n 个元素 - 当
n+m为偶数时,求第 n − 1 2 \frac{n-1}{2} 2n−1 和第 n 2 \frac{n}{2} 2n 个元素取平均
注意都是下取整
那么其实我们就把问题转换为了求第 k k k 小数,只不过相较于朴素的第 k 小数,这里我们要在两个有序数组中找到第 k k k 小的数。
我们可以基于这样的思想:每次排除掉 k 2 \frac{k}{2} 2k 个元素,直到最后只剩一个元素。
cpp
// REF:windliang
// 在两个有序数组中求第 k 小数,每次排除掉 k/2 个元素,k-=k/2
class Solution {
// 在两个有序数组中求第 k 小数
double recursion(vector<int> &nums1, int l1, int r1, vector<int> &nums2, int l2, int r2, int k) {
// 放了方便处理数组为空的情况,我们总是令 nums1 的长度大于等于 nums2 的长度
if(r2 - l2 > r1 - l1) return recursion(nums2, l2, r2, nums1, l1, r1, k);
// 递归结束条件
if(l2 > r2) return nums1[l1 + k - 1];
if(k == 1) return min(nums1[l1], nums2[l2]);
// 每次排除掉 k/2 个元素,注意 nums2 可能不够 k/2 个元素,但 nums1 一定有 k/2 个元素(不做证明)
int cnt1 = min(k / 2, r1 - l1 + 1);
int cnt2 = min(k / 2, r2 - l2 + 1);
// 子递归
if(nums1[l1 + cnt1 - 1] < nums2[l2 + cnt2 - 1])
return recursion(nums1, l1 + cnt1, r1, nums2, l2, r2, k - cnt1); // 抛弃掉 nums1[l1, l1 + cnt1 - 1] 共 cnt1 个元素
return recursion(nums1, l1, r1, nums2, l2 + cnt2, r2, k - cnt2); // 抛弃掉 nums2[l2, l2 + cnt2 - 1] 共 cnt2 个元素
}
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
// 我们这里规定第 1 小数对应 arr[0],即数组下标从 0 开始
// 0 1 (2) 3 4 -> idx=size/2=5/2=2 -> a[idx]
// 0 1 (2) (3) 4 5 -> idx=size/2=6/2=3 -> (a[idx-1] + a[idx]) / 2
int n = nums1.size(), m = nums2.size();
int idx = (n + m) / 2;
// 注意由数组下标 idx 转换为 k 需要 +1,即 k=idx+1
if((n + m) & 1) return recursion(nums1, 0, n - 1, nums2, 0, m - 1, idx + 1);
return (recursion(nums1, 0, n - 1, nums2, 0, m - 1, idx) + recursion(nums1, 0, n - 1, nums2, 0, m - 1, idx + 1)) / 2.0;
}
};
- 时间复杂度: O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n)) 咱们的核心逻辑是"找第 k 小的数"。每次递归,都会对比并精准淘汰掉 k /2 个绝对不可能的元素。这相当于每次都把查找范围暴力缩小一半。因为 k 的初始值差不多是总长度的一半,这种不断"折半砍"的动作,执行次数也就是 log(m +n) 级别。
- 空间复杂度: O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n)) 代码里咱们全程都是拿着索引(
l1、r1等)在原数组上比对,没有新建任何数组或哈希表。唯一的内存消耗来自递归调用时产生的系统函数栈。既然递归往下走了 log(m +n ) 层,栈的深度自然就是 O (log(m +n))。
为什么每次排除掉 1 2 \frac{1}{2} 21 的元素
简单来说,选 k / 2 k/2 k/2 是为了在**"保证绝对安全"的前提下,做到"最极致的效率"**。核心就这 3 点:
- 追求二分的速度: 如果我们每次只比较并排除 1 个元素(比如用双指针挨个比),那找第 k k k 个数就要循环 k k k 次,时间复杂度就退化成了龟速的 O ( k ) O(k) O(k)。为了满足题目对数级别的要求,我们必须像二分查找一样,每次把目标范围直接"砍掉一半"。
- 数学上的绝对安全: 为什么排掉较小的那 k / 2 k/2 k/2 个数绝对不会"误杀"目标?假设我们比较了两个数组的第 k / 2 k/2 k/2 个数(设为 A A A 和 B B B)。如果 A < B A < B A<B,那么 A A A 这个数,撑死了也只能大于
nums1里的 k / 2 − 1 k/2 - 1 k/2−1 个数,以及nums2里的 k / 2 − 1 k/2 - 1 k/2−1 个数。加起来总共才 k − 2 k-2 k−2 个数。这意味着, A A A 充其量也就是全局第 k − 1 k-1 k−1 小的数。所以,A A A 以及它前面的所有数,连成为第 k k k 小数的资格都没有,直接整锅端掉是绝对安全的。 - 为什么不排除更多?: 比如你想一次性排除 2 k / 3 2k/3 2k/3 个元素来加速,那就无法保证上面的数学推导了,你极有可能会把真正的答案给误删掉。 k / 2 k/2 k/2 就是理论上能一口气安全切掉的最大极限。
每次切掉一半的 k k k,这才是真正的"分治"艺术。
迭代写法
把递归改成迭代(也就是用 while 循环),核心思想完全没变,只是把函数自己调用自己,变成了不断更新指针。
这样一来,连递归那点系统栈的内存都省了,空间复杂度直接干到绝对的 O ( 1 ) O(1) O(1)。
cpp
class Solution {
// 迭代版的求第 K 小数
double getKth(vector<int>& nums1, vector<int>& nums2, int k) {
int m = nums1.size(), n = nums2.size();
// l1 和 l2 相当于两个指针,记录两个数组当前还没被淘汰的起始位置
int l1 = 0, l2 = 0;
while (true) {
// 边界情况 1:nums1 里的数全被淘汰光了,直接去 nums2 里找剩下的第 k 个
if (l1 == m) return nums2[l2 + k - 1];
// 边界情况 2:nums2 里的数全被淘汰光了
if (l2 == n) return nums1[l1 + k - 1];
// 边界情况 3:k 减到了 1,说明下一个就是要找的数,直接挑俩指针当前指着的最小值
if (k == 1) return min(nums1[l1], nums2[l2]);
// 各自往前看 k/2 个元素,注意加上当前的偏移量 l1/l2,并且千万别越界
int idx1 = min(l1 + k / 2 - 1, m - 1);
int idx2 = min(l2 + k / 2 - 1, n - 1);
// 谁的"界碑"小,就大胆淘汰谁及它前面的那截元素
if (nums1[idx1] <= nums2[idx2]) {
// 精准扣减 k 值:减去这次实际淘汰的元素个数
k -= (idx1 - l1 + 1);
// 把 nums1 的指针推到被淘汰元素的下一个位置
l1 = idx1 + 1;
} else {
k -= (idx2 - l2 + 1);
l2 = idx2 + 1;
}
}
}
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size(), m = nums2.size();
int totalLength = n + m;
// 奇偶情况拆分,复用 getKth 逻辑
if (totalLength % 2 == 1) {
return getKth(nums1, nums2, totalLength / 2 + 1);
} else {
return (getKth(nums1, nums2, totalLength / 2) +
getKth(nums1, nums2, totalLength / 2 + 1)) / 2.0;
}
}
};
迭代版的核心变化就这三点:
- 循环代替递归: 用一个
while (true)死循环罩住,满足边界条件就直接return结果。 - 指针推进: 之前是靠传递各种
l1, r1缩小范围,现在只需要维护l1和l2两个起始指针,不断向右推,相当于逻辑上把左边的元素给"截断"了。 - 内存极简: 除了几个整型变量,什么额外空间都没开辟。