K 小数问题

最小的 k 个数

面试题 17.14. 最小K个数

设计一个算法,找出数组中 最小的 k 个数 。以 任意顺序 返回这 k 个数均可。

示例:

复制代码
输入: arr = [1,3,5,7,2,4,6,8], k = 4
输出: [1,2,3,4]

提示:

  • 0 <= len(arr) <= 100000
  • 0 <= k <= min(100000, len(arr))

快速选择算法(分治)

利用快速排序恰好有 「左侧元素」 小于等于 「基准值」,「右侧元素」 大于 「基准值」 的特性,我们可以基于快速排序算法划分数组,也成为快速选择。

注意任意顺序是核心,否则不能使用快速选择算法。

cpp 复制代码
// 注意到题目要求「任意顺序返回这 k 个数即可」,因此我们只需要确保前 k 小的数都出现在下标为 [0,k) 的位置即可。
// 而快速排序恰好有 「左侧元素」 小于等于 「基准值」,「右侧元素」 大于 「基准值」 的特性,因此我们可以利用快速排序划分数组

// 在面试中,如果面试官追加一句:"如果是 100 亿个数据,内存放不下怎么办?"这时候就必须用堆了。
// 我们只要维护一个容量为 k 的大根堆:新来的元素如果比堆顶小,就把堆顶踢掉,自己进去。时间复杂度是 O(Nlogk)。

class Solution {
    // void quick_select(vector<int> &arr, int target_index, int l, int r) {
    //     if(l >= r)  return ;
    //     int pivot = arr[(l + r) / 2];
    //     int i = l - 1, j = r + 1;
    //     while(i < j) {
    //         do i ++ ; while(arr[i] < pivot);
    //         do j -- ; while(arr[j] > pivot);
    //         if(i < j) swap(arr[i], arr[j]);
    //     }
    //     if(j > target_index) quick_select(arr, target_index, l, j);
    //     if(j + 1 <= target_index) quick_select(arr, target_index, j + 1, r);
    // }

    void quick_select(vector<int> &arr, int target_index, int l, int r) {
        if(l >= r)  return ;
        int pivot = arr[(l + r + 1) / 2];
        int i = l - 1, j = r + 1;
        while(i < j) {
            do i ++ ; while(arr[i] < pivot);
            do j -- ; while(arr[j] > pivot);
            if(i < j) swap(arr[i], arr[j]);
        }
        if(i - 1 > target_index) quick_select(arr, target_index, l, i - 1);
        if(i <= target_index) quick_select(arr, target_index, i, r);
    }

    // 注意如果我们使用随机数生成 pivot,最好使用 j 作为边界而不是 i
public:
    vector<int> smallestK(vector<int>& arr, int k) {
        if(k == 0)  return {};
        if(k >= arr.size()) return arr;
        quick_select(arr, k - 1, 0, arr.size() - 1);
        return vector<int>(arr.begin(), arr.begin() + k);
    }
};

时间复杂度 :平均 O ( N ) O(N) O(N)。因为每次划分大概能排除掉一半的数据,计算量递减( N + N / 2 + N / 4 + ⋯ ≈ 2 N N + N/2 + N/4 + \dots \approx 2N N+N/2+N/4+⋯≈2N)。最坏情况是数组已经有序,每次只排除一个元素,会退化到 O ( N 2 ) O(N^2) O(N2)(实际工程中通常会引入随机挑选基准来避免这种极端情况)。

空间复杂度 :平均 O ( log ⁡ N ) O(\log N) O(logN)。主要开销来自递归调用的函数栈深度。最坏情况下(时间退化的同时),递归树变成一条直线,空间复杂度退化为 O ( N ) O(N) O(N)。

什么时候用堆排序

在面试中,如果面试官追加一句:"如果是 100 亿个数据,内存放不下怎么办?"这时候就必须用堆了。

我们只要维护一个容量为 k k k 的大根堆:新来的元素如果比堆顶小,就把堆顶踢掉,自己进去。时间复杂度是 O ( N log ⁡ k ) O(N \log k) O(Nlogk)。

第 k 个数

215. 数组中的第K个最大元素

给定整数数组 nums 和整数 k,请返回数组中第 **k** 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。

示例 1:

复制代码
输入: [3,2,1,5,6,4], k = 2
输出: 5

示例 2:

复制代码
输入: [3,2,3,1,2,4,5,5,6], k = 4
输出: 4

提示:

  • 1 <= k <= nums.length <= 105
  • -104 <= nums[i] <= 104

快速选择算法(分治)

cpp 复制代码
// 快速选择算法 O(n) + O(n)
int findKthLargest(vector<int>& nums, int k) {
    // 由于我们后面需要 % nums.size()
    // 因此这里要确保 nums.size() 不为 0 
    if(!nums.size())    return 0;
    vector<int> small, equal, big;
    int pivot = nums[rand() % nums.size()];
    for(auto &x : nums) {
        if(x == pivot) equal.push_back(x);
        else if(x > pivot) big.push_back(x);
        else small.push_back(x);
    }
    if(k <= big.size())  return findKthLargest(big, k);
    if(k <= big.size() + equal.size())  return pivot;
    return findKthLargest(small, k - (big.size() + equal.size()));
}

计数排序

cpp 复制代码
// 计数排序 O(n) + O(n)
int findKthLargest_1(vector<int>& nums, int k) {
    k = nums.size() - k + 1;
    int l = INT_MAX, r = INT_MIN;
    unordered_map<int,int> cnt;
    for(auto &x : nums) {
        ++ cnt[x];
        l = min(l, x);
        r = max(r, x);
    }
    for(int i = l; i <= r; i ++ ) {
        if((k -= cnt[i]) <= 0)  return i;
    }
    // never go there
    return INT_MAX;
}

第 k 个数(拓展)

4. 寻找两个正序数组的中位数

给定两个大小分别为 mn 的正序(从小到大)数组 nums1nums2。请你找出并返回这两个正序数组的 中位数

算法的时间复杂度应该为 O(log (m+n))

示例 1:

复制代码
输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2

示例 2:

复制代码
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

提示:

  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -106 <= nums1[i], nums2[i] <= 106

分治 + 递归

这道题的意思其实很简单,就是让我们求两个数组的中位数。我们假设我们已经将两个数组排序成一个长度为 n+m 的数组,那么:

  • n+m 为奇数时:求第 n 2 \frac{n}{2} 2n 个元素
  • n+m 为偶数时,求第 n − 1 2 \frac{n-1}{2} 2n−1 和第 n 2 \frac{n}{2} 2n 个元素取平均

注意都是下取整

那么其实我们就把问题转换为了求第 k k k 小数,只不过相较于朴素的第 k 小数,这里我们要在两个有序数组中找到第 k k k 小的数。

我们可以基于这样的思想:每次排除掉 k 2 \frac{k}{2} 2k 个元素,直到最后只剩一个元素。

cpp 复制代码
// REF:windliang
// 在两个有序数组中求第 k 小数,每次排除掉 k/2 个元素,k-=k/2

class Solution {
    // 在两个有序数组中求第 k 小数
    double recursion(vector<int> &nums1, int l1, int r1, vector<int> &nums2, int l2, int r2, int k) {
        // 放了方便处理数组为空的情况,我们总是令 nums1 的长度大于等于 nums2 的长度
        if(r2 - l2 > r1 - l1)   return recursion(nums2, l2, r2, nums1, l1, r1, k);

        // 递归结束条件
        if(l2 > r2)   return nums1[l1 + k - 1];
        if(k == 1)  return min(nums1[l1], nums2[l2]);
        
        // 每次排除掉 k/2 个元素,注意 nums2 可能不够 k/2 个元素,但 nums1 一定有 k/2 个元素(不做证明)
        int cnt1 = min(k / 2, r1 - l1 + 1);
        int cnt2 = min(k / 2, r2 - l2 + 1);
        
        // 子递归
        if(nums1[l1 + cnt1 - 1] < nums2[l2 + cnt2 - 1])   
            return recursion(nums1, l1 + cnt1, r1, nums2, l2, r2, k - cnt1); // 抛弃掉 nums1[l1, l1 + cnt1 - 1] 共 cnt1 个元素
        return recursion(nums1, l1, r1, nums2, l2 + cnt2, r2, k - cnt2);     // 抛弃掉 nums2[l2, l2 + cnt2 - 1] 共 cnt2 个元素
    }
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        // 我们这里规定第 1 小数对应 arr[0],即数组下标从 0 开始
        
        // 0 1 (2) 3 4       ->    idx=size/2=5/2=2     ->  a[idx]      
        // 0 1 (2) (3) 4 5   ->    idx=size/2=6/2=3     ->  (a[idx-1] + a[idx]) / 2

        int n = nums1.size(), m = nums2.size();
        int idx = (n + m) / 2;
            // 注意由数组下标 idx 转换为 k 需要 +1,即 k=idx+1
        if((n + m) & 1) return recursion(nums1, 0, n - 1, nums2, 0, m - 1, idx + 1);
        return (recursion(nums1, 0, n - 1, nums2, 0, m - 1, idx) + recursion(nums1, 0, n - 1, nums2, 0, m - 1, idx + 1)) / 2.0;
    }
};
  • 时间复杂度: O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n)) 咱们的核心逻辑是"找第 k 小的数"。每次递归,都会对比并精准淘汰掉 k /2 个绝对不可能的元素。这相当于每次都把查找范围暴力缩小一半。因为 k 的初始值差不多是总长度的一半,这种不断"折半砍"的动作,执行次数也就是 log(m +n) 级别。
  • 空间复杂度: O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n)) 代码里咱们全程都是拿着索引(l1r1 等)在原数组上比对,没有新建任何数组或哈希表。唯一的内存消耗来自递归调用时产生的系统函数栈。既然递归往下走了 log(m +n ) 层,栈的深度自然就是 O (log(m +n))。

为什么每次排除掉 1 2 \frac{1}{2} 21 的元素

简单来说,选 k / 2 k/2 k/2 是为了在**"保证绝对安全"的前提下,做到"最极致的效率"**。核心就这 3 点:

  • 追求二分的速度: 如果我们每次只比较并排除 1 个元素(比如用双指针挨个比),那找第 k k k 个数就要循环 k k k 次,时间复杂度就退化成了龟速的 O ( k ) O(k) O(k)。为了满足题目对数级别的要求,我们必须像二分查找一样,每次把目标范围直接"砍掉一半"。
  • 数学上的绝对安全: 为什么排掉较小的那 k / 2 k/2 k/2 个数绝对不会"误杀"目标?假设我们比较了两个数组的第 k / 2 k/2 k/2 个数(设为 A A A 和 B B B)。如果 A < B A < B A<B,那么 A A A 这个数,撑死了也只能大于 nums1 里的 k / 2 − 1 k/2 - 1 k/2−1 个数,以及 nums2 里的 k / 2 − 1 k/2 - 1 k/2−1 个数。加起来总共才 k − 2 k-2 k−2 个数。这意味着, A A A 充其量也就是全局第 k − 1 k-1 k−1 小的数。所以,A A A 以及它前面的所有数,连成为第 k k k 小数的资格都没有,直接整锅端掉是绝对安全的。
  • 为什么不排除更多?: 比如你想一次性排除 2 k / 3 2k/3 2k/3 个元素来加速,那就无法保证上面的数学推导了,你极有可能会把真正的答案给误删掉。 k / 2 k/2 k/2 就是理论上能一口气安全切掉的最大极限。

每次切掉一半的 k k k,这才是真正的"分治"艺术。

迭代写法

把递归改成迭代(也就是用 while 循环),核心思想完全没变,只是把函数自己调用自己,变成了不断更新指针。

这样一来,连递归那点系统栈的内存都省了,空间复杂度直接干到绝对的 O ( 1 ) O(1) O(1)。

cpp 复制代码
class Solution {
    // 迭代版的求第 K 小数
    double getKth(vector<int>& nums1, vector<int>& nums2, int k) {
        int m = nums1.size(), n = nums2.size();
        // l1 和 l2 相当于两个指针,记录两个数组当前还没被淘汰的起始位置
        int l1 = 0, l2 = 0;

        while (true) {
            // 边界情况 1:nums1 里的数全被淘汰光了,直接去 nums2 里找剩下的第 k 个
            if (l1 == m) return nums2[l2 + k - 1];
            // 边界情况 2:nums2 里的数全被淘汰光了
            if (l2 == n) return nums1[l1 + k - 1];
            // 边界情况 3:k 减到了 1,说明下一个就是要找的数,直接挑俩指针当前指着的最小值
            if (k == 1) return min(nums1[l1], nums2[l2]);

            // 各自往前看 k/2 个元素,注意加上当前的偏移量 l1/l2,并且千万别越界
            int idx1 = min(l1 + k / 2 - 1, m - 1);
            int idx2 = min(l2 + k / 2 - 1, n - 1);

            // 谁的"界碑"小,就大胆淘汰谁及它前面的那截元素
            if (nums1[idx1] <= nums2[idx2]) {
                // 精准扣减 k 值:减去这次实际淘汰的元素个数
                k -= (idx1 - l1 + 1);
                // 把 nums1 的指针推到被淘汰元素的下一个位置
                l1 = idx1 + 1;
            } else {
                k -= (idx2 - l2 + 1);
                l2 = idx2 + 1;
            }
        }
    }

public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int n = nums1.size(), m = nums2.size();
        int totalLength = n + m;
        
        // 奇偶情况拆分,复用 getKth 逻辑
        if (totalLength % 2 == 1) {
            return getKth(nums1, nums2, totalLength / 2 + 1);
        } else {
            return (getKth(nums1, nums2, totalLength / 2) + 
                    getKth(nums1, nums2, totalLength / 2 + 1)) / 2.0;
        }
    }
};

迭代版的核心变化就这三点:

  • 循环代替递归: 用一个 while (true) 死循环罩住,满足边界条件就直接 return 结果。
  • 指针推进: 之前是靠传递各种 l1, r1 缩小范围,现在只需要维护 l1l2 两个起始指针,不断向右推,相当于逻辑上把左边的元素给"截断"了。
  • 内存极简: 除了几个整型变量,什么额外空间都没开辟。