
思路:
1.第k小/大问题的通用转化方法:
(1)第k小等价于:求最小的x,满足<=x的数至少有k个(注意是至少不是恰好)。
(2)第k大等价于:求最大的x,满足>=x的数至少有k个(注意是至少不是恰好)。
2.在本题中是找第k小的元素,因此就是找最小的target,满足 <= target的数至少有k个。因此target越大,越能找到k个数;target越小,越不能找到k个数。据此,可以二分猜答案。
3.本题可以转化为给定整数target,统计有序矩阵中 <= target的元素个数cnt,判断是否满足cnt >= k。
4.如何做到高效统计cnt?通过双指针实现,如下图所示。

5.细节:本代码中采用闭区间二分,使用开区间或者半开半闭区间也是可以的。
(1)闭区间左端点的初始值:matrix[0][0]。由于不会存在比最小值还小的数,也就是有0个数 <= matrix[0][0],由于0 < k,所以无法满足要求。
(2)闭区间右端点的初始值:matrix[n - 1][n - 1],有n^2个数 <= matrix[n - 1][n - 1],由于n^2 >= k,所以一定满足要求。
(3)二分查找: 寻找第k小的元素。如果矩阵中 <= mid的元素个数 >= k,那么第k小的元素一定 <= mid。此时记录mid作为一个候选答案,我们还要继续向左寻找,看能否找到更小的满足条件的值。最终返回的就是最小的满足条件的mid。
6.疑问:为什么二分结束后,答案ans一定在矩阵中?
答:
虽在在过程中ans记录的可能不是矩阵中的元素,但由于我们不断缩小范围,最终会下降到矩阵中实际存在的元素,这是因为:任何不在矩阵中的元素x,如果它满足小于等于它的数cnt >= k,那么比它小的下一个矩阵元素也一定满足cnt >= k。因此二分最终会收敛到矩阵中实际存在的数字。所以ans最终就是矩阵中第k小的元素。
7.复杂度分析:
(1)时间复杂度:O(nlog(U))。其中n是matrix的行数和列数,U = matrix[n - 1][n - 1] - matrix[0][0]。二分O(log(U))次,每次需要跑一个O(n)的双指针。
(2)空间复杂度:O(1)。
附代码:
java
class Solution {
public int kthSmallest(int[][] matrix, int k) {
int n = matrix.length;
int left = matrix[0][0];
int right = matrix[n - 1][n - 1];
int ans = -1;
while (left <= right) {
int mid = left + (right - left) / 2;
// 说明矩阵中 <= mid的元素个数 >= k
// 说明第k小的元素一定 <= mid
if (check(matrix,mid,k)) {
// 此时记录的mid是一个候选答案
ans = mid;
// 缩小范围继续往左边找,看能否找到更小的满足条件的值
right = mid - 1;
} else {
// 不满足条件就向右找
left = mid + 1;
}
}
// 最终返回最小的满足条件的值
return ans;
}
private boolean check(int[][] matrix,int target,int k) {
int n = matrix.length;
int cnt = 0; // matrix中的 <= target的元素个数
int i = 0;
int j = n - 1; // 从右上角开始
while (i < n && j >= 0 && cnt < k) {
if (matrix[i][j] > target) {
j--; // 排除第j列
} else {
cnt += j + 1; // 说明整行元素都 <= target,cnt加上这行元素的个数
i++; // 第i行加完也排除
}
}
return cnt >= k; // 判断matrix中的 <= target的元素个数是否 >= k
}
}
ACM模式:
java
import java.util.Scanner;
class Solution {
public int kthSmallest(int[][] matrix, int k) {
int n = matrix.length;
int left = matrix[0][0];
int right = matrix[n - 1][n - 1];
int ans = -1;
while (left <= right) {
int mid = left + (right - left) / 2;
if (check(matrix,mid,k)) {
ans = mid;
right = mid - 1;
} else {
left = mid + 1;
}
}
return ans;
}
private boolean check(int[][] matrix,int target,int k) {
int n = matrix.length;
int cnt = 0;
int i = 0;
int j = n - 1;
while (i < n && j >= 0 && cnt < k) {
if (matrix[i][j] > target) {
j--;
} else {
cnt += j + 1;
i++;
}
}
return cnt >= k;
}
}
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
// 读取矩阵大小
int n = scanner.nextInt();
// 读取矩阵
int[][] matrix = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
matrix[i][j] = scanner.nextInt();
}
}
int k = scanner.nextInt();
// 计算第k小的元素
Solution solution = new Solution();
int result = solution.kthSmallest(matrix, k);
System.out.println(result);
scanner.close();
}
}