快速排序(QuickSort)是一种高效的分治排序算法,其核心思想是选取一个基准值(pivot),将数组分成两部分:一部分所有元素小于等于基准值,另一部分所有元素大于基准值。然后对这两部分递归应用相同的操作,直到子数组规模为 0 或 1。
快速排序的步骤分三步:
- 选择基准值:通常选数组的第一个、最后一个、中间或随机元素作为 pivot。例如,选择最后一个元素。
- 分区(Partition):将数组重新排列,使得左边部分的所有元素 <= pivot,右边部分的所有元素 > pivot。最终 pivot 位于其正确排序位置。实现方式是用两个指针 i 和 j ,i 指向小于 pivot 的部分的边界,j 遍历数组,交换元素以满足条件。
- 递归排序:对 pivot 左侧和右侧的子数组递归调用快速排序。
Java 代码实现:
ini
public void quickSort(int[] arr, int low, int high) {
if (low < high) {
int pi = partition(arr, low, high);
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
private int partition(int[] arr, int low, int high) {
int pivot = arr[high];
int i = low - 1;
for (int j = low; j < high; j++) {
if (arr[j] <= pivot) {
i++;
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
int temp = arr[i + 1];
arr[i + 1] = arr[high];
arr[high] = temp;
return i + 1;
}
递归的终止条件:如果 low >= high,说明子数组为空或只有一个元素,无需排序,直接返回。
快速排序的时间复杂度取决于分区平衡性和递归深度,主要分以下几种情况:
-
平均情况:O(nlogn)
假设每次 partition 将数组大致分成两等份,pivot 接近中位数,每次 partition 遍历子数组,比较和交换操作需要 O(n) 时间(n 是子数组大小)。假设数组每次均分为两半,递归树的深度为 O(logn),因为每次子数组规模减半。每一层的总工作量是O(n)(所有子数组的元素总数约为 n )。总时间复杂度 = 每层工作量 O(n) × 层数 O(logn) = O(nlogn)。这种情况出现在随机数据或 pivot 选择较好时。
-
最坏情况:O(n²)
假设每次 partition 分区极不平衡,比如数组已排序(升序或降序),且 pivot 选最后一个元素,或者数组元素全部相同。每次 partition 将数组分为两部分:一部分为空或只有一个元素,另一部分包含剩余 n-1 个元素。递归树的深度为O(n),因为每次只减少一个元素。每次 partition 仍需 O(n) 时间(遍历整个子数组)。总时间复杂度 = 每层工作量 O(n) + O(n-1) + ... + O(1) = O(n²)。这种情况出现在已排序数组、逆序数组或全相同元素且pivot 固定选最后一个元素时。
-
最好情况:O(nlogn)
每次 partition 完美地将数组分成两等份(pivot 恰好是中位数)。与平均情况类似,递归深度为 O(log n),每层工作量为 O(n),总时间为 O(n log n)。这种情况极少发生,除非 pivot 选择机制(如三数取中)确保接近中位数。
快速排序的空间复杂度主要由递归调用栈和辅助变量决定。在递归调用栈中,平均情况下,当分区平衡时,递归树的深度为 O(logn),每次递归调用在栈上存储常量级的变量(low、high、pi 等),因此栈空间为O(log n)。最坏情况下,当分区极不平衡时(例如已排序数组),递归深度为 O(n)。每次递归调用仍存储常数级变量,栈空间为 O(n)。partition 方法中使用局部变量:pivot、i、j、temp,均为 O(1)。排序是原地进行的(通过交换元素,不需要额外数组),因此没有额外的 O(n) 空间。所以快速排序的总空间复杂度由递归栈主导,平均情况是O(log n),最坏情况是O(n)。
快速排序的四种优化算法:
- 随机化 Pivot,通过随机选择 pivot,避免最坏情况(如已排序数组或全相同元素),使分区更平衡,平均时间复杂度更稳定为 O(n log n)。
ini
import java.util.Random;
static void quickSort(int[] arr, int low, int high) {
if (low < high) {
int pi = partition(arr, low, high);
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
static int partition(int[] arr, int low, int high) {
// 随机化 pivot
Random rand = new Random();
int randomIndex = low + rand.nextInt(high - low + 1);
// 交换随机索引与 high
int temp = arr[randomIndex];
arr[randomIndex] = arr[high];
arr[high] = temp;
int pivot = arr[high];
int i = low - 1;
for (int j = low; j < high; j++) {
if (arr[j] <= pivot) {
i++;
temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
temp = arr[i + 1];
arr[i + 1] = arr[high];
arr[high] = temp;
return i + 1;
}
- 三数取中,通过选择首、尾、中间元素的中位数作为 pivot,尽量使分区平衡,减少最坏情况发生。
ini
static void quickSort(int[] arr, int low, int high) {
if (low < high) {
int pi = partition(arr, low, high);
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
static int partition(int[] arr, int low, int high) {
// 三数取中
int mid = low + (high - low) / 2;
// 比较 low, mid, high 三个位置的值,找出中位数
int pivotIndex;
if ((arr[low] <= arr[mid] && arr[mid] <= arr[high]) || (arr[high] <= arr[mid] && arr[mid] <= arr[low])) {
pivotIndex = mid;
} else if ((arr[mid] <= arr[low] && arr[low] <= arr[high]) || (arr[high] <= arr[low] && arr[low] <= arr[mid])) {
pivotIndex = low;
} else {
pivotIndex = high;
}
// 将中位数交换到 high
int temp = arr[pivotIndex];
arr[pivotIndex] = arr[high];
arr[high] = temp;
int pivot = arr[high];
int i = low - 1;
for (int j = low; j < high; j++) {
if (arr[j] <= pivot) {
i++;
temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
temp = arr[i + 1];
arr[i + 1] = arr[high];
arr[high] = temp;
return i + 1;
}
- 小数组优化,对小规模子数组(例如长度 ≤ 10),使用插入排序替代快速排序,减少递归开销和函数调用,提高性能。
ini
static void quickSort(int[] arr, int low, int high) {
// 小数组优化,阈值设为 10
if (high - low + 1 <= 10) {
insertionSort(arr, low, high);
return;
}
if (low < high) {
int pi = partition(arr, low, high);
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
static int partition(int[] arr, int low, int high) {
int pivot = arr[high];
int i = low - 1;
for (int j = low; j < high; j++) {
if (arr[j] <= pivot) {
i++;
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
int temp = arr[i + 1];
arr[i + 1] = arr[high];
arr[high] = temp;
return i + 1;
}
// 插入排序实现
static void insertionSort(int[] arr, int low, int high) {
for (int i = low + 1; i <= high; i++) {
int key = arr[i];
int j = i - 1;
while (j >= low && arr[j] > key) {
arr[j + 1] = arr[j];
j--;
}
arr[j + 1] = key;
}
}
- 重复元素优化,优化大量重复元素的情况,避免因 arr[j] <= pivot 导致的不必要交换,减少比较和移动次数。将数组分为三部分,小于 pivot 的元素、等于 pivot 的元素和大于 pivot 的元素。
ini
static void quickSort(int[] arr, int low, int high) {
if (low < high) {
int[] result = partitionThreeWay(arr, low, high);
int lt = result[0], gt = result[1];
quickSort(arr, low, lt - 1);
quickSort(arr, gt + 1, high);
}
}
static int[] partitionThreeWay(int[] arr, int low, int high) {
int pivot = arr[high];
int lt = low; // 小于 pivot 区域右边界
int gt = high; // 大于 pivot 区域左边界
int i = low; // 当前处理元素
while (i <= gt) {
if (arr[i] < pivot) {
int temp = arr[lt];
arr[lt] = arr[i];
arr[i] = temp;
lt++;
i++;
} else if (arr[i] > pivot) {
int temp = arr[gt];
arr[gt] = arr[i];
arr[i] = temp;
gt--;
} else {
i++; // 等于 pivot,直接跳过
}
}
return new int[]{lt, gt}; // 返回小于和大于区域的边界
}
综合以上优化算法,可以得到一个更加鲁棒的快速排序实现:
ini
import java.util.Random;
static void quickSort(int[] arr, int low, int high) {
// 小数组优化
if (high - low + 1 <= 10) {
insertionSort(arr, low, high);
return;
}
if (low < high) {
int[] result = partitionThreeWay(arr, low, high);
int lt = result[0], gt = result[1];
quickSort(arr, low, lt - 1);
quickSort(arr, gt + 1, high);
}
}
static int[] partitionThreeWay(int[] arr, int low, int high) {
// 三数取中
int mid = low + (high - low) / 2;
int pivotIndex;
if ((arr[low] <= arr[mid] && arr[mid] <= arr[high]) || (arr[high] <= arr[mid] && arr[mid] <= arr[low])) {
pivotIndex = mid;
} else if ((arr[mid] <= arr[low] && arr[low] <= arr[high]) || (arr[high] <= arr[low] && arr[low] <= arr[mid])) {
pivotIndex = low;
} else {
pivotIndex = high;
}
// 随机化 pivot(可选,进一步增强鲁棒性)
Random rand = new Random();
if (rand.nextBoolean()) {
pivotIndex = low + rand.nextInt(high - low + 1);
}
// 交换到 high
int temp = arr[pivotIndex];
arr[pivotIndex] = arr[high];
arr[high] = temp;
int pivot = arr[high];
int lt = low, gt = high, i = low;
while (i <= gt) {
if (arr[i] < pivot) {
temp = arr[lt];
arr[lt] = arr[i];
arr[i] = temp;
lt++;
i++;
} else if (arr[i] > pivot) {
temp = arr[gt];
arr[gt] = arr[i];
arr[i] = temp;
gt--;
} else {
i++;
}
}
return new int[]{lt, gt};
}
static void insertionSort(int[] arr, int low, int high) {
for (int i = low + 1; i <= high; i++) {
int key = arr[i];
int j = i - 1;
while (j >= low && arr[j] > key) {
arr[j + 1] = arr[j];
j--;
}
arr[j + 1] = key;
}
}