01 TopK 问题
Top-K问题简单来说就是求数据集合中前 K 个最大的元素或者最小的元素,一般情况下数据量都比较大。这个问题在我们日常生活中非常常见,比如说:游戏中活跃度前十的玩家,世界五百强企业等等。
解决这个问题常见的思路就是遍历或者排序,但是当数据量较大时这种方法就并不适用了。这时我们就需要建堆来处理。
02 解决方法
① 用前 K 个数建立一个 K 个数的小堆(求前 K 个最大的数就建小堆,前 K 个最小的数就建大堆)。
② 剩下的 N - K 个数,依次跟堆顶元素比较,如果比堆顶元素大,就进行替换,再向下调整。
③ 最后堆里面的 K 个数就是最大的 K 个数。
这里为什么使用小堆而不使用大堆?
最大的前 K 个数一定比其他数要大,用小堆的话,最大的数进去后一定会沉到最下面,所以不会出现大的数堵在堆顶导致某个数进不去堆的情况,数越大越在下面。对应的,如果使用大堆就会出现一个大的数堵在堆顶,导致剩下比它小的数全部进不去,最后只能选出最大的。
03 代码实现
堆的实现
Heap.h:
cpp
#pragma once
#include <stdio.h>
#include <assert.h>
#include <stdlib.h>
#include <stdbool.h>
typedef int HPDataType;
typedef struct Heap {
HPDataType* array;
int size;
int capacity;
} HP;
/* 堆的初始化 */
void HeapInit(HP* php);
/* 堆的销毁 */
void HeapDestroy(HP* php);
/* 堆的打印 */
void HeapPrint(HP* php);
/* 判断堆是否为空 */
bool HeapIsEmpty(HP* hp);
/* 堆的插入 */
void HeapPush(HP* php, HPDataType x);
/* 检查容量 */
void HeapCheckCapacity(HP* php);
/* 交换函数 */
void Swap(HPDataType* px, HPDataType* py);
/* 大根堆上调 */
void BigAdjustUp(int* arr, int child);
/* 小根堆上调 */
void SmallAdjustUp(int* arr, int child);
/* 堆的删除 */
void HeapPop(HP* php);
/* 大根堆下调 */
void BigAdjustDown(int* arr, int n, int parent);
/* 小根堆下调 */
void SmallAdjustDown(int* arr, int n, int parent);
/* 返回堆顶数据*/
HPDataType HeapTop(HP* php);
/* 统计堆的个数 */
int HeapSize(HP* php);
Heap.c:
cpp
#include "Heap.h"
/* 堆的初始化 */
void HeapInit(HP* php) {
assert(php);
php->array = NULL;
php->size = php->capacity = 0;
}
/* 堆的销毁 */
void HeapDestroy(HP* php) {
assert(php);
free(php->array);
php->size = php->capacity = 0;
}
/* 堆的打印 */
void HeapPrint(HP* php) {
for (int i = 0; i < php->size; ++i) {
printf("%d ", php->array[i]);
}
printf("\n");
}
/* 判断堆是否为空 */
bool HeapIsEmpty(HP* php) {
assert(php);
return php->size == 0;
}
/* 检查容量 */
void HeapCheckCapacity(HP* php) {
if (php->size == php->capacity) {
int new_capacity = php->capacity == 0 ? 4 : php->capacity * 2;
HPDataType* tmp_array = (HPDataType*)realloc(php->array, sizeof(HPDataType) * new_capacity);
if (tmp_array == NULL) {
printf("realloc failed");
exit(-1);
}
php->array = tmp_array;
php->capacity = new_capacity;
}
}
void Swap(HPDataType* px, HPDataType* py) {
HPDataType tmp = *px;
*px = *py;
*py = tmp;
}
/* 大根堆上调 */
void BigAdjustUp(int* arr, int child) {
assert(arr);
// 根据公式算出父亲的下标
int father = (child - 1) / 2;
// 最坏情况:调到根,child == father当 child 为根节点时结束(根节点永远是0)
while (child > 0) {
if (arr[child] > arr[father]) {
//HPDataType tmp = arr[child];
//arr[child] = arr[father];
//arr[father] = tmp;
Swap(&arr[child], &arr[father]);
// 往上走
child = father;
father = (child - 1) / 2;
}
else {
break;
}
}
}
/* 小根堆上调 */
void SmallAdjustUp(int* arr, int child) {
assert(arr);
// 根据公式算出父亲的下标
int father = (child - 1) / 2;
// 最坏情况:调到根,child == father当 child 为根节点时结束(根节点永远是0)
while (child > 0) {
if (arr[child] < arr[father]) {
Swap(&arr[child], &arr[father]);
// 往上走
child = father;
father = (child - 1) / 2;
}
else {
break;
}
}
}
/* 堆的插入 */
void HeapPush(HP* php, HPDataType x) {
assert(php);
// 检查是否需要扩容
HeapCheckCapacity(php);
// 插入数据
php->array[php->size] = x;
php->size++;
// 向上调整
SmallAdjustUp(php->array, php->size - 1);
}
/* 大根堆下调 */
void AdjustDown(int* arr, int n, int parent) {
// 默认为左孩子
int child = parent * 2 + 1;
while (child < n) {
if (child + 1 > n && arr[child + 1] > arr[child]) {
child = child + 1;
}
if (arr[child] > arr[parent]) {
Swap(&arr[child], &arr[parent]);
parent = child;
child = parent * 2 + 1;
}
else {
break;
}
}
}
/* 小根堆下调 */
void SmallAdjustDown(int* arr, int n, int parent) {
// 默认为左孩子
int child = parent * 2 + 1;
while (child < n) {
if (child + 1 < n && arr[child + 1] < arr[child]) {
child = child + 1;
}
if (arr[child] < arr[parent]) {
Swap(&arr[child], &arr[parent]);
parent = child;
child = parent * 2 + 1;
}
else {
break;
}
}
}
/* 堆的删除 */
void HeapPop(HP* php) {
assert(php);
assert(!HeapIsEmpty(php));
Swap(&php->array[0], &php->array[php->size - 1]);
php->size--;
SmallAdjustDown(php->array, php->size, 0);
}
/* 返回堆顶数据*/
HPDataType HeapTop(HP* php) {
assert(php);
assert(!HeapIsEmpty(php));
return php->array[0];
}
/* 统计堆的个数 */
int HeapSize(HP* php) {
assert(php);
return php->size;
}
TopK 实现
cpp
#include "Heap.h"
/* 在N个数中找出最大的前K个 */
void PrintTopK(int* arr, int N, int K) {
// 初始化堆
HP hp;
HeapInit(&hp);
// 创建一个 K 个数的小堆
for (int i = 0; i < K; ++i) {
HeapPush(&hp, arr[i]);
}
// 剩下的 N - K 个数依次和堆顶比较
for (int i = K; i < N; ++i) {
if (arr[i] > HeapTop(&hp)) {
HeapPop(&hp);
HeapPush(&hp, arr[i]);
}
}
HeapPrint(&hp);
HeapDestroy(&hp);
}
void TestTopK() {
int N = 1000000;
int* arr = (int*)malloc(sizeof(int) * N);
srand(time(0));
for (size_t i = 0; i < N; ++i) {
arr[i] = rand() % 1000000;
}
arr[5] = 1000000 + 1;
arr[1231] = 1000000 + 2;
arr[5355] = 1000000 + 3;
arr[51] = 1000000 + 4;
arr[15] = 1000000 + 5;
arr[2335] = 1000000 + 6;
arr[9999] = 1000000 + 7;
arr[76] = 1000000 + 8;
arr[423] = 1000000 + 9;
arr[3144] = 1000000 + 10;
PrintTopK(arr, N, 10);
}
int main() {
TestTopK();
return 0;
}
运行结果如下:
