
「寻找两个正序数组的中位数」是一道来自LeetCode Hot100题单 的Hard难度练习题。同时它也是2011年408笔试的手撕算法真题。
本文是笔者解答该题的整个思维过程的完整记录,以及从中提炼出的更具有一般性的解题思路。希望能给其他同学提供一些参考。
暴力解法
这道题目一个很直接的办法便是对两个数组进行线性扫描+归并。这在大一的《C语言》或者《数据结构》课程中我们早已经接触过了。很显然,它的时间复杂度和空间复杂度均为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( m + n ) O(m+n) </math>O(m+n),距离题目所要求的 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( log ( m + n ) ) O(\log(m+n)) </math>O(log(m+n))还有不小的距离。
C++
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int len = m + n;
std::vector<int> merged;
merged.reserve(len);
int i = 0;
int j = 0;
while (i < m && j < n) {
if (nums1[i] < nums2[j]) {
merged.push_back(nums1[i++]);
} else {
merged.push_back(nums2[j++]);
}
}
while (i < m) {
merged.push_back(nums1[i++]);
}
while (j < n) {
merged.push_back(nums2[j++]);
}
return len % 2 == 1 ? merged[len / 2] : (merged[len / 2 - 1] + merged[len / 2]) / 2.0;
}
};
思考1:重新审视问题,反思我们的目的到底是啥?
我们的目的是要求出中位数。在刚才我写的代码中,为了找到中位数,把两个数组完完整整地合并成了一个 merged 数组。
于是我的第一个反思是:
为了找到中位数,我们真的需要耗费 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( m + n ) O(m+n) </math>O(m+n)这么多的时间和空间,去知道 merged 数组里所有的元素吗?还是说,我们其实只需要关心特定位置上的那一、两个数就可以了?
针对这个疑问,我又重新整理了思路,给出如下结论:
- 如果
nums1[m-1]≤nums2[0]
,即max(nums1)≤min(nums2)
,这意味着两个数组在逻辑上已经完成了归并,此时我们直接通过简单的下标访问即可得出答案 - 如果
nums2[n-1]≤nums1[0]
,与之类似,我们也不需要人工进行归并排序,即可直接通过下标访问得出答案 - 针对
nums1
和nums2
中的数值从小到大交错排列的一般情况,我认为也不是总是需要进行完整的数组归并。从我的代码中也可以看到,实际上我们只关心merged[len / 2 - 1]
和merged[len / 2]
这两个值 。我认为这意味着我们只需要完成merged
数组前半部分的计算,即可直接得出答案,后面的归并操作事实上是多余的。
不过到这里仍然很遗憾。这仍然是一个线性复杂度的操作 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( ( m + n ) / 2 ) O((m+n)/2) </math>O((m+n)/2)。虽然看上去又有了些收获,但实际上我并没有给出更优的办法。
思考2:怎么做才有可能实现对数级别的复杂度呢?
从刚才的思考中,我意识到了一件很重要的事情------我们只关心merged
数组中间的那一两个数 。而为了找到它们,我提出进行一半的归并。只不过遗憾的是,复杂度依然是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( ( m + n ) / 2 ) O((m+n)/2) </math>O((m+n)/2),也就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( m + n ) O(m+n) </math>O(m+n),这意味着这个解题方向大概率到这里是可以放弃了。
在解答算法题的过程中,这事实上是一个强烈的信号:我需要一种方法,能够一次性地排除掉一批我们不关心的元素 ,而不是一个一个地排除。通过在迭代或递归的过程中,每一次都能做到一次性大幅度地缩小问题,我才有可能从 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( m + n ) O(m+n) </math>O(m+n)跨越到 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( log ( m + n ) ) O(\log (m+n)) </math>O(log(m+n)),就像二分查找一样。
好,我决定继续顺着这个思路往下,最终实现消除"遍历"的目的。
思考3:咋实现一次性地将问题的规模进行缩小呢?
到此为止我便进入了解答本题的最困难的阶段。起初,我只盯着题目要求的"中位数"不放,结果未能找到任何突破口。
一般这种情况下,就需要考虑对问题进行泛化了。
我想到:我们不去找中位数,我们先试着在两个有序数组中找到第k
小的元素 。如果能在解决这个问题,那么找中位数就简单了,因为要求中位数,本质上就是去找第(m+n)/2
个或者第(m+n)/2 + 1
个元素(取决于总长度的奇偶)。
与中学时代长期做题训练中强调的"把问题不断简化"不同,有时候把一个具体的问题(找中位数)抽象成一个更通用、更灵活的问题(找第k小数),反而能让我们的核心逻辑更清晰。
接着我又注意到,在算法运行过程中,这个k
值其实是有可能被我们不断地进行缩小的。比如说,我们能不能每次都让它缩小一半呢?即原问题有没有希望能从求第k
小个数,被简化为求第k/2
个小的数,然后一直迭代(递归)下去,直到k=1
?
如果可以办到,我们应该就有希望写出一个耗时为 <math xmlns="http://www.w3.org/1998/Math/MathML"> T ( k ) = T ( k / 2 ) + O ( 1 ) T(k)=T(k/2) + O(1) </math>T(k)=T(k/2)+O(1)的算法。学过《算法分析与设计》课程的我们应该很容易就能用主方法计算出,这个算法的时间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( log k ) O(\log k) </math>O(logk)。
同时刚才我已经想到了,原题本质上就是当k=(m+n)/2
或者k=(m+n)/2+1
时的特殊情况。将k
代入表达式,可以得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( log ( m + n ) ) O(\log (m+n)) </math>O(log(m+n)),完美符合原题的要求!!!
思考4:分析缩小问题规模方案的可行性
找到了疑似的突破口,接下来我还要进一步分析这个方案的可行性。
我首先意识到,要解决这个问题,数组nums1
和nums2
的规模(或者说"区间范围")在逻辑上应该是需要随着k
一起在算法迭代过程中被不断缩小的。不然我们可以想象一下:当k=1
时,假如这两个数组还是原来的规模,我们又该如何确定返回它们中的哪个元素作为答案呢?这很难,对吧。
这就与我前面想到的一次性地排除掉一批我们不关心的元素 呼应起来了。那么nums1
和nums2
的规模一次性可以被缩小多少呢?
很显然,我们首先就应该去考虑能不能一次性地从nums1
或者nums2
当中排除掉k/2
个元素。即下一个子问题中输入数组规模,和所求目标的规模均缩小k/2
------这看上去凭感觉应该是最和谐统一、便于实现的。
为了验证这个最直接的想法,我又引入了两个辅助变量pivot1=nums1[k/2-1]
和pivot2=nums2[k/2-1]
。结合之前的想法,现在的问题在于:我们能否在 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1)时间内,将nums1[0, pivot1]
或nums2[0, pivot2]
范围内的元素(恰好有k/2
个),从原问题的输入当中排除出去呢?
<math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1)时间一般暗示我们只能进行常数次的比较。这里我尝试直接比较pivot1
和pivot2
,看看根据比较结果能否推导出一些能指导算法进行大规模排除的结论。
例如,对于pivot1 ≤ pivot2
,即... ≤ nums1[k/2 - 1] ≤ nums2[k/2 - 1] ≤ ...
,我这里不妨考虑一种极端情况:从nums1[0]
到nums1[k/2 - 2
](共k/2-1
个数),以及从nums2[0]
到nums2[k/2-2]
(共k/2-2
个数),这共计k - 2
个数都会落在这个不等式当中nums1[k/2 - 1]
的左侧。同时注意到在这种情况下nums2[k/2 - 1]
即为我们要求的第k小的数。
换句话说,通过这个极端例子,我意识到在pivot1 ≤ pivot2
的前提下,从nums1[0]
到nums1[k/2 - 1]
这k/2
个数,绝无成为我们要求解的第k
小的数的机会。在下一轮的算法迭代过程中,我们可以直接将它们排除出去。而对于pivot1 > pivot2
,结论是对称的。并且这整个过程的时间开销恰好只需我们梦寐以求的 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1)。
到这里,我已经基本验证了前面提出的算法方案的可行性。解决问题的最大障碍几乎已经被消除了!
思考5:递归的终止条件
我们来整理一下目前的逻辑:
- 为了找到第
k
小的数,我们比较nums1[k/2 - 1]
和nums2[k/2 - 1]
。 - 如果
nums1[k/2 - 1] ≤ nums2[k/2 - 1]
,我们就知道nums1
的前k/2
个数可以被排除。问题转化为:在nums1
剩下的部分和整个nums2
中寻找第k - k/2
小的数。 - 同理,如果
nums2[k/2 - 1] < nums1[k/2 - 1]
,我们就可以排除nums2
的前k/2
个数。问题转化为:在整个nums1
和nums2
剩下的部分中寻找第k - k/2
小的数。
这个递归的思路已经非常清晰了。在真正动手写代码之前,我们还需要考虑最后一步,也是编程中至关重要的一步。
这个递归过程会不断地排除元素、缩小 k,那它什么时候会停下来呢?
换句话说,我需要考虑清楚在不断排除元素和缩小 k 的过程中,我们会遇到哪些最简单、最极端,以至于我们不需要再递归,可以直接返回答案的情况?
首先我们来约定几个符号:nums1'
(经过若干轮递归后被不断缩减的nums1
),nums2'
(经过若干轮递归后被不断缩减的nums2
),k'
(经过若干轮递归后被不断缩减的k)
以下是我刚开始思考得出的结论:
- 情况0 :
- 我们的递归不变式中还有一个重要的条件------为了能够推动递归继续向下进行,我们必须保证
k'/2≥1
,也就是k'2≥2
。 - 这意味着,当
k'=1
时,我们需要结束递归了。 - 同时,在这种情况下,只有
nums1'[0]
(如果nums1'
不为空的话)或者nums2'[0]
(如果nums2'
不为空的话)有机会成为合并数组中第1小的元素 。我们只需要取它们俩的最小值返回,即为最终答案。
- 我们的递归不变式中还有一个重要的条件------为了能够推动递归继续向下进行,我们必须保证
- 情况1-1 :
nums1'
和nums2'
均不为空,同时nums1'
的大小已经缩减到k'/2
以下,而nums2'
的大小仍然≥k'/2
。- 这时我们直接对比
nums1'.back()
和nums2'[2k-1]
。 - 如果
nums1'.back() < nums2'[2k-1]
,那么直接将nums1'
中的所有元素全部排除掉。然后转入情况3-1。 - 如果
nums1'.back() ≥ nums2'[2k-1]
,那么直接将nums2'
中前k'/2
个元素全部排除掉。然后再次调用递归函数。
- 这时我们直接对比
- 情况1-2 :
nums1'
和nums2'
均不为空,同时nums2'
的大小已经缩减到k'/2
以下,而nums1'
的大小仍然≥k'/2
。- 这种情况与情况1-1是对称的。
- 情况2 :
nums1'
和nums2'
均不为空,同时nums1'
和nums2'
的大小均已经缩减到k'/2
以下。- 这种情况下显然是无法继续进行递归了。此时我们需要将算法退化为最初我实现的线性扫描+合并吗?
- 不需要,事实上这种情况是不存在的,压根不需要考虑 。因为我们要计算的是第
k'
小的数,这意味着至少应该有nums1'.size()+nums2'.size()≥k'
。 - 而这种情况下
nums1'.size() + nums2'.size() < k'/2 + k'/2
,即nums1'.size() + nums2'.size() < k'
,这与我们的递归不变式是矛盾的。
- 情况3-1 :
nums1'
已经变为空,而nums2'
不为空。- 这时候我们直接返回
nums2[k' - 1]
,这就是最终答案。
- 这时候我们直接返回
- 情况3-2 :
nums1'
不为空,而nums2'
已经变为空。- 这与情况3-1是对称的。
思考6(可选):进一步化简递归出口
事实上,完成前面的分析后,我就可以写代码了。但在coding前,我又注意到前述的情况1-1和情况1-2实际上还是有进一步被化简的空间的。
比如说,对于情况1-1,我们可以直接排除掉从nums2'[0]
到nums2'[k/2 - 1]
这k/2
个元素,然后进入下一轮递归。
这是啥原理呢?
其实很好解释。我们这里仍然考虑一种极端情况:
假设nums1'[0] ≤ nums1'[1] ≤ ... ≤ nums1'.back() ≤ nums2'[0] ≤ nums2'[1] ≤ ... ≤ nums2'[k'/2-1]
,
即merged
数组中nums1'
中的元素全部排在nums2'
的前面。那么在这种情况下,nums2'[k'/2-1]
充其量也不过是第(k'/2-1)+k'/2=k'-1
小的数,绝无成为第k'
小的数的可能性。换句话来说,我们可以理直气壮地将nums2'[0]
到nums2'[k/2 - 1]
这k/2
个元素排除出去,而无需担心会错过正确答案!
在具体代码实现中,对于这种情况下我们可以简洁地通过直接将pivot1
置为无穷大来做到这一点。
情况1-2的分析是对称的。
代码实现
C++
#define INF (std::numeric_limits<int>::max())
class Solution {
public:
int GetKthElement(const std::vector<int>& nums1, int begin1, const std::vector<int>& nums2, int begin2, int k) {
assert(k >= 1);
int len_nums1 = nums1.size() - begin1;
int len_nums2 = nums2.size() - begin2;
if (k == 1) {
int first_in_nums1 = len_nums1 == 0 ? INF : nums1[begin1];
int first_in_nums2 = len_nums2 == 0 ? INF : nums2[begin2];
assert(first_in_nums1 != INF || first_in_nums2 != INF);
return std::min(first_in_nums1, first_in_nums2);
}
if (len_nums1 == 0) {
assert(len_nums2 > 0);
return nums2[begin2 + k - 1];
}
if (len_nums2 == 0) {
assert(len_nums1 > 0);
return nums1[begin1 + k - 1];
}
int pivot1 = len_nums1 < k / 2 ? INF : nums1[begin1 + k / 2 - 1];
int pivot2 = len_nums2 < k / 2 ? INF : nums2[begin2 + k / 2 - 1];
if (pivot1 <= pivot2) {
// 淘汰从nums1[0]到nums1[k/2-1]的全部元素
return GetKthElement(nums1, begin1 + k / 2, nums2, begin2, k - k / 2);
} else {
// 淘汰从nums2[0]到nums2[k/2-1]的全部元素
return GetKthElement(nums1, begin1, nums2, begin2 + k / 2, k - k / 2);
}
}
double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
int total_len = nums1.size() + nums2.size();
if (total_len % 2 == 1) {
return GetKthElement(nums1, 0, nums2, 0, total_len / 2 + 1);
} else {
int a = GetKthElement(nums1, 0, nums2, 0, total_len / 2);
int b = GetKthElement(nums1, 0, nums2, 0, total_len / 2 + 1);
return (a + b) / 2.0;
}
}
};
提炼一般的方法论:如何攻克一道困难的手撕题?
第一步:寻找"笨办法",建立直觉 (Find the Brute-Force Solution)
目标: 确保你完全理解了问题,并有一个保底的、能 work 的方案。
做法:
- 忘掉所有限制: 先忽略题目中关于时间/空间复杂度的要求(比如这道题的 O(log(m+n)))。
- 用最直观、最符合人类思维的方式解决它: 就像我最开始做的那样,"要找中位数,我就先把两个数组合并,然后直接取中间的数"。这就是最直观的暴力解法。
- 分析"笨办法": 写出(或在脑中想出)这个笨办法的代码,并分析它的时间/空间复杂度。这为你后续的优化提供了一个明确的 Baseline(基准线) 。
在面试中的作用:
- 向面试官证明你读懂了题目,并且具备基本的编程实现能力。
- 破冰,避免冷场。即使暂时想不到最优解,你也可以说:"首先,一个最直接的想法是... 它的复杂度是... 但这不满足要求,我们可以尝试优化。"
第二步:识别瓶颈,找到浪费 (Identify the Bottleneck)
目标: 找出你的"笨办法"为什么慢/为什么耗空间。
做法:
- 审视核心操作: 在你的 O(m+n) 解法中,最耗时的操作是什么?是 while 循环里一次又一次的比较和 push_back。
- 寻找"冗余计算": 例如在我解答本题最开始的地方,我反思的第一个问题就是:"我们真的需要知道 merged 数组里所有 的元素吗?" 我的结论是"不,我们只关心中间那一两个"。这就找到了问题的关键------为了得到那1、2个结果,我们计算了(m+n)/2个中间值,这里存在巨大的计算浪费。
在面试中的作用:
- 展现你的分析能力。能清晰地指出当前解法的瓶颈,是找到优化方向的前提。
- 引导对话,向面试官展示你的思考路径:"这个解法的瓶颈在于...,因为它做了很多不必要的计算。我的优化思路就是如何避免这些冗余计算。"
第三步:联想经典模型,寻找"跨越式"解法 (Match with Classic Models)
目标: 将问题与你所学的经典算法模型进行匹配,实现从"线性"到"对数/常数"的飞跃。
做法:
-
思考"排除法": 当你发现瓶颈是"遍历/线性扫描"时,脑海里要立刻亮起一盏灯:有没有办法不一个一个看,而是一片一片地扔掉?
-
匹配模型:
- "一次扔一半" -> 二分查找 (Binary Search)
- "问题能分解成结构相同的子问题" -> 递归/分治 (Recursion / Divide & Conquer)
- "需要存储中间计算结果以避免重复计算" -> 动态规划 (Dynamic Programming)
- "从局部最优推出全局最优" -> 贪心算法 (Greedy Algorithm)
比如,在这道题中,我发现可以"排除一半",立刻就联想到了二分查找和分治。这就是最关键的一步。于是,我引入了 pivot 的概念,尝试将问题规模减半,自然而然地走上了 O(log n) 的道路。
在面试中的作用:
- 这是整个解题过程中最高光的时刻,体现了你的算法知识储备和应用能力。
- 即使不能立刻设计出完美方案,你也可以说:"这个'一次排除一半'的特性,让我联想到了二分查找的思想。也许我们可以尝试定义两个指针,通过比较它们来缩小搜索范围..."
第四步:处理边界,完善代码 (Handle Edge Cases)
目标: 将优化的思路转化为健壮、完整的代码。
做法:
-
思考递归的终止条件: 例如我通过分析明确当 k=1 或"一个数组被排空"时,递归应该停止。这是分治算法的"底"。
-
思考极端输入:
- 一个数组为空怎么办?
- 两个数组都为空怎么办?
- 数组长度差别巨大怎么办?
- 我们要找的 k/2 比某个数组的长度还大怎么办?
-
打磨细节: 比如 total_len 的奇偶判断,下标是 k 还是 k-1,整数除法 /2.0 的问题等。
在面试中的作用:
- 展现你严谨的思维和代码功底。能考虑到各种边界情况的候选人,在工程实践中会更可靠。
- 即使最终代码没写完,清晰地讲出你对这些边界情况的考虑,也能获得很高的评价。
总结
当你再遇到一个没见过的题目或者难题时,就启动这个四步流程:
- 先求有,再求好: 搞定暴力解法。
- 找浪费,定方向: 分析暴力解法的瓶颈。
- 套模型,求飞跃: 联想二分、分治等经典模型。
- 补细节,保周全: 处理好所有边界和终止条件。