寻找两个正序数组的中位数
题目描述
给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的中位数。
算法的时间复杂度应该为 O(log(m+n))。
示例 1:
输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2
示例 2:
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5
提示:
- nums1.length == m
- nums2.length == n
- 0 <= m <= 1000
- 0 <= n <= 1000
- 1 <= m + n <= 2000
- -10^6 <= nums1i, nums2i <= 10^6
解题思路总览
| 方法 | 核心思想 | 时间复杂度 | 空间复杂度 | 特点 |
|---|---|---|---|---|
| 二分查找(第 k 小) | 用二分查找找第 (m+n)/2 小和第 (m+n+1)/2 小的数 | O(log(m+n)) | O(1) | 标准解法,满足要求 |
| 双指针合并 | 类似合并两个有序链表,逐步合并 | O(m+n) | O(m+n) | 直观易懂,不满足复杂度要求 |
| 暴力合并排序 | 先合并再排序 | O((m+n)log(m+n)) | O(m+n) | 代码最简洁,但效率最低 |
| 寻找第 k 小(双指针) | 维护两个指针,每次跳过不可能是第 k 小的元素 | O(k) | O(1) | k 小时效率高 |
方法一:二分查找(第 k 小元素)
代码实现
cpp
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int total = m + n;
// 如果总长度为奇数,直接找第 (total+1)/2 小的数
if (total % 2 == 1) {
return findKth(nums1, 0, m, nums2, 0, n, (total + 1) / 2);
} else {
// 如果总长度为偶数,找第 total/2 小和第 total/2+1 小的数
int k1 = total / 2;
int k2 = total / 2 + 1;
double a = findKth(nums1, 0, m, nums2, 0, n, k1);
double b = findKth(nums1, 0, m, nums2, 0, n, k2);
return (a + b) / 2.0;
}
}
private:
// 在两个有序数组中找第 k 小的数(k 从 1 开始)
double findKth(vector<int>& nums1, int i, int m, vector<int>& nums2, int j, int n, int k) {
// 确保 nums1 是较短的数组,简化逻辑
if (m > n) {
return findKth(nums2, j, n, nums1, i, m, k);
}
// 如果 nums1 已经遍历完,直接在 nums2 中找第 k 小
if (m == 0) {
return nums2[j + k - 1];
}
// 如果 k == 1,找两个数组中最小的那个
if (k == 1) {
return min(nums1[i], nums2[j]);
}
// 比较两个数组中第 k/2 个元素的大小
int mid1 = min(i + k / 2, i + m);
int mid2 = min(j + k / 2, j + n);
if (nums1[mid1 - 1] < nums2[mid2 - 1]) {
// nums1 的前 k/2 个元素都小于 nums2 的第 k/2 个元素
// 排除 nums1 的前 (mid1-i) 个元素
return findKth(nums1, mid1, m - (mid1 - i), nums2, j, n, k - (mid1 - i));
} else {
// nums2 的前 k/2 个元素都小于等于 nums1 的第 k/2 个元素
// 排除 nums2 的前 (mid2-j) 个元素
return findKth(nums1, i, m, nums2, mid2, n - (mid2 - j), k - (mid2 - j));
}
}
};
核心思想
中位数的定义:
- 奇数长度:中位数是第 (m+n+1)/2 小的数
- 偶数长度:中位数是第 (m+n)/2 小和第 (m+n)/2+1 小的数的平均值
因此问题转化为:在两个有序数组中找第 k 小的数。
关键技巧:每次排除 k/2 个不可能是第 k 小的元素,通过比较 nums1mid1 和 nums2mid2 来决定排除哪一段。
算法流程图
以 nums1 = [1,3], nums2 = [2] 为例,total=3(奇数)
找第 (3+1)/2 = 2 小的数:
findKth(nums1=[1,3], i=0, m=2, nums2=[2], j=0, n=1, k=2):
m=2 <= n=1?不成立,继续
m != 0,k != 1
mid1 = min(0+1, 0+2) = 1
mid2 = min(0+1, 0+1) = 1
nums1[0]=1 < nums2[0]=2
排除 nums1 的前 1 个元素
return findKth(nums1=[3], i=1, m=1, nums2=[2], j=0, n=1, k=1)
findKth(nums1=[3], i=1, m=1, nums2=[2], j=0, n=1, k=1):
k == 1
return min(nums1[1]=3, nums2[0]=2) = 2
结果:中位数 = 2
逐行解析
cpp
int total = m + n;
计算两个数组的总长度。
cpp
if (total % 2 == 1) {
return findKth(nums1, 0, m, nums2, 0, n, (total + 1) / 2);
}
奇数长度时,直接找第 (total+1)/2 小的数作为中位数。
cpp
int k1 = total / 2;
int k2 = total / 2 + 1;
double a = findKth(nums1, 0, m, nums2, 0, n, k1);
double b = findKth(nums1, 0, m, nums2, 0, n, k2);
return (a + b) / 2.0;
偶数长度时,找第 total/2 小和第 total/2+1 小的数,取平均值。
cpp
if (m > n) {
return findKth(nums2, j, n, nums1, i, m, k);
}
确保 nums1 是较短的数组,简化后续逻辑。
cpp
if (m == 0) {
return nums2[j + k - 1];
}
如果 nums1 已遍历完,第 k 小的数必在 nums2 中。
cpp
if (k == 1) {
return min(nums1[i], nums2[j]);
}
k == 1 时,找两个数组当前开头的最小值。
cpp
int mid1 = min(i + k / 2, i + m);
int mid2 = min(j + k / 2, j + n);
计算两个数组中的"候选"位置,取 min 是为了防止越界。
cpp
if (nums1[mid1 - 1] < nums2[mid2 - 1]) {
return findKth(nums1, mid1, m - (mid1 - i), nums2, j, n, k - (mid1 - i));
} else {
return findKth(nums1, i, m, nums2, mid2, n - (mid2 - j), k - (mid2 - j));
}
关键排除逻辑:
- 如果 nums1mid1-1 < nums2mid2-1,说明 nums1 的前 k/2 个元素都小于 nums2 的第 k/2 个元素,因此可以排除这些元素
- 否则排除 nums2 的前 k/2 个元素
复杂度分析
| 复杂度 | 分析 |
|---|---|
| 时间 | 每次递归将 k 减少约一半,最多递归 O(log(m+n)) 次 |
| 空间 | O(log(m+n)),递归栈深度 |
方法二:双指针合并
代码实现
cpp
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
vector<int> merged;
int m = nums1.size();
int n = nums2.size();
int i = 0, j = 0;
// 合并两个有序数组
while (i < m && j < n) {
if (nums1[i] <= nums2[j]) {
merged.push_back(nums1[i++]);
} else {
merged.push_back(nums2[j++]);
}
}
// 处理剩余元素
while (i < m) {
merged.push_back(nums1[i++]);
}
while (j < n) {
merged.push_back(nums2[j++]);
}
int total = m + n;
int k = (total - 1) / 2;
if (total % 2 == 1) {
return merged[k];
} else {
return (merged[k] + merged[k + 1]) / 2.0;
}
}
};
核心思想
类似合并两个有序链表,使用双指针逐步合并两个数组,将较小元素加入结果数组,最后根据总长度计算中位数。
算法流程图
以 nums1 = [1,2], nums2 = [3,4] 为例:
合并过程:
第1轮:nums1[0]=1 <= nums2[0]=3,merged=[1],i=1
第2轮:nums1[1]=2 <= nums2[0]=3,merged=[1,2],i=2
nums1 已遍历完,将 nums2 剩余元素加入
merged=[1,2,3,4]
计算中位数:
total=4, k=(4-1)/2=1
total 为偶数
merged[1]=2, merged[2]=3
中位数 = (2+3)/2 = 2.5
复杂度分析
| 复杂度 | 分析 |
|---|---|
| 时间 | O(m+n),需要遍历两个数组各一次 |
| 空间 | O(m+n),需要额外数组存储合并结果 |
方法三:暴力合并排序(代码版本)
代码实现
cpp
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
vector<int> merged = nums1;
merged.insert(merged.end(), nums2.begin(), nums2.end());
ranges::sort(merged);
int n = merged.size();
int k = (n - 1) / 2;
return n % 2 ? merged[k] : (merged[k] + merged[k + 1]) / 2.0;
}
};
复杂度分析
| 复杂度 | 分析 |
|---|---|
| 时间 | O((m+n)log(m+n)),合并 O(m+n),排序 O((m+n)log(m+n)) |
| 空间 | O(m+n),merged 数组 |
为什么不推荐?
题目要求时间复杂度为 O(log(m+n)),但排序需要 O((m+n)log(m+n)),不满足要求。不过代码写法简洁,面试时可以作为"最直观解法"先写出,再优化为二分查找。
方法四:优化双指针(找第 k 小)
代码实现
cpp
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int total = m + n;
int k1 = (total + 1) / 2;
int k2 = (total + 2) / 2;
double a = findKth(nums1, 0, m, nums2, 0, n, k1);
double b = findKth(nums1, 0, m, nums2, 0, n, k2);
return (a + b) / 2.0;
}
private:
double findKth(vector<int>& nums1, int i, int m, vector<int>& nums2, int j, int n, int k) {
// 确保 nums1 是较短的数组
if (m > n) {
return findKth(nums2, j, n, nums1, i, m, k);
}
if (m == 0) {
return nums2[j + k - 1];
}
if (k == 1) {
return min(nums1[i], nums2[j]);
}
int mid1 = min(i + k / 2, i + m);
int mid2 = min(j + k / 2, j + n);
if (nums1[mid1 - 1] < nums2[mid2 - 1]) {
return findKth(nums1, mid1, m - (mid1 - i), nums2, j, n, k - (mid1 - i));
} else {
return findKth(nums1, i, m, nums2, mid2, n - (mid2 - j), k - (mid2 - j));
}
}
};
核心思想
与方法一类似,但用 (total+1)/2 和 (total+2)/2 两个 k 值来兼容奇偶情况,避免单独处理奇偶。
注意:(total+1)/2 和 (total+2)/2 在奇数时相等,在偶数时相差1。
复杂度分析
| 复杂度 | 分析 |
|---|---|
| 时间 | O(log(m+n)) |
| 空间 | O(log(m+n)),递归栈 |
边界情况分析
情况1:其中一个数组为空
输入: nums1 = [], nums2 = [1]
分析: total = 1(奇数)
findKth 返回 nums2[0] = 1
结果: 1.0
情况2:两个数组都只有一个元素
输入: nums1 = [1], nums2 = [2]
分析: total = 2(偶数)
k1 = (2+1)/2 = 1, k2 = (2+2)/2 = 2
findKth(..., k1=1) = min(1,2) = 1
findKth(..., k2=2) = max(1,2) = 2
中位数 = (1+2)/2 = 1.5
结果: 1.5
情况3:两个数组完全相同
输入: nums1 = [1,2,3], nums2 = [1,2,3]
分析: total = 6(偶数)
中位数 = (2+3)/2 = 2.5
结果: 2.5
情况4:nums1 完全大于 nums2
输入: nums1 = [5,6,7], nums2 = [1,2,3]
分析: total = 6(偶数)
中位数 = (3+5)/2 = 4.0
合并后 [1,2,3,5,6,7],中位数 (3+5)/2 = 4.0
结果: 4.0
情况5:奇数长度,元素个数为 1
输入: nums1 = [1], nums2 = [2,3,4]
分析: total = 4(偶数)
中位数 = (2+3)/2 = 2.5
结果: 2.5
中位数概念详解
中位数定义:将数组排序后位于中间位置的数
奇数长度数组(如 5 个元素):
位置: 1 2 3 4 5
中位数位置 = (5+1)/2 = 3,即第 3 小的数
偶数长度数组(如 6 个元素):
位置: 1 2 3 4 5 6
中位数位置 = (6+1)/2 = 3 和 (6+2)/2 = 4
即第 3 小和第 4 小的数的平均值
两个有序数组合并后的中位数:
合并后长度 = m + n
奇数:第 (m+n+1)/2 小的数
偶数:第 (m+n)/2 小和第 (m+n)/2+1 小的数的平均值
面试追问 FAQ
| 问题 | 回答 |
|---|---|
| 为什么时间复杂度要求是 O(log(m+n))? | 如果用归并排序,时间复杂度是 O((m+n)log(m+n))。二分查找可以将时间复杂度降到 O(log(m+n)) |
| 方法一的递归深度是多少? | 最多 O(log(m+n)) 次递归,每次 k 减少约一半 |
| 如何处理整数溢出? | 使用 long long 类型存储 m+n 的值,或者用 double 类型 |
| 如果 m 和 n 差距很大怎么办? | 方法一已经通过交换确保 nums1 是较短的数组,效率不受影响 |
| 为什么方法一要用 k - (mid1 - i) 而不是 k - k/2? | 因为 nums1 的有效长度可能小于 k/2,所以要减去实际排除的元素个数 |
| 二分查找和双指针哪个更好? | 对于本题,二分查找(方法一/四)时间复杂度更优;双指针(方法二)更直观易懂 |
| 如果 nums1 和 nums2 都很大怎么办? | 方法一/四每次排除约一半元素,效率高;方法三排序效率低 |
相关题目
| 题目 | 难度 | 核心区别 |
|---|---|---|
| 4. 寻找两个正序数组的中位数(本题) | 困难 | 找两个有序数组的中位数 |
| 23. 合并 K 个升序链表 | 困难 | 合并多个有序链表 |
| 21. 合并两个有序链表 | 简单 | 合并两个有序链表 |
| 876. 链表的中间结点 | 简单 | 找链表中点 |
| 154. 寻找旋转排序数组中的最小值 II | 困难 | 旋转数组找最小值 |
| 33. 搜索旋转排序数组 | 中等 | 在旋转数组中搜索 |
总结
| 要点 | 说明 |
|---|---|
| 核心思想 | 将中位数问题转化为"找第 k 小的数",用二分查找优化 |
| 关键技巧 | 每次比较 nums1mid1 和 nums2mid2,排除 k/2 个不可能是第 k 小的元素 |
| 奇偶处理 | 奇数直接返回第 (m+n+1)/2 小的数;偶数返回第 (m+n)/2 小和第 (m+n)/2+1 小的数的平均值 |
| 时间复杂度 | O(log(m+n)) |
| 空间复杂度 | O(log(m+n)),递归栈 |
| 防溢出 | 使用 min(i + k/2, ...) 防止数组越界 |