11.5Merge k Sorted Lists 优先队列
优先队列描述
优先队列可以在O(1)的时间内获得最大值,并且可以在O(log n)的时间内取出最大值或插入任意值。
优先队列常常用堆来实现,堆是一个完全二叉树,每个节点的值总是大于等于子节点的值,实现堆的时候,我们通常用一个数组而不是用指针建立一个树。因为堆是完全二叉树,所以用数组表示时,位置i的节点的父节点位置一定为i/2,而他的两个子节点的位置有一定分别为2i和2i+1。
题目描述
给定k个增序的链表,试着将他们合并成一条增序链表
输入输出样例
Input :
1-\>4-\>5,1-\>3-\>4,2-\>6
Output:1->1->2->3->4->4->5->6
题解
本题可以有很多种解法,比如类似于归并排序进行两两合并。我们这里展示一个速度比较快的方法,即把所有的链表存储在一个优先队列中,每次提取所有链表头部节点值最小的那个节点,直到所有链表都被提取完为止。注意因为 Comp 函数默认是对最大堆进行比较并维持递增关系,如果我们想要获取最小的节点值,则我们需要实现一个最小堆,因此比较函数应该维持递减关系,所以 operator () 中返回时用大于号而不是等增关系时的小于号进行比较。
cpp
#include <iostream>
#include <vector>
#include <queue> // 优先队列需要这个头文件
using namespace std;
// 1. 必须自己定义链表节点结构(C++ 标准库没有 ListNode)
struct ListNode {
int val;
ListNode *next;
ListNode() : val(0), next(nullptr) {}
ListNode(int x) : val(x), next(nullptr) {}
ListNode(int x, ListNode *next) : val(x), next(next) {}
};
// 2. 小根堆比较器
struct Comp {
bool operator() (ListNode* l1, ListNode* l2) {
return l1->val > l2->val; // 小根堆:堆顶是最小值
}
};
// 3. 合并 K 个升序链表
ListNode* mergeKLists(vector<ListNode*>& lists) {
if (lists.empty()) return nullptr;
// 优先队列(小根堆)
priority_queue<ListNode*, vector<ListNode*>, Comp> q;
// 把所有链表的头节点放入堆
for (ListNode* list : lists) {
if (list) {
q.push(list);
}
}
// 哑节点,方便构造结果链表
ListNode* dummy = new ListNode(0);
ListNode* cur = dummy;
// 每次取最小节点
while (!q.empty()) {
ListNode* minNode = q.top();
q.pop();
cur->next = minNode;
cur = cur->next;
// 把下一个节点放回堆
if (minNode->next) {
q.push(minNode->next);
}
}
return dummy->next;
}
// ------------------- 辅助函数 -------------------
// 打印链表
void printList(ListNode* head) {
while (head) {
cout << head->val << " ";
head = head->next;
}
cout << endl;
}
// 根据数组创建链表
ListNode* createList(const vector<int>& nums) {
if (nums.empty()) return nullptr;
ListNode* dummy = new ListNode(0);
ListNode* cur = dummy;
for (int num : nums) {
cur->next = new ListNode(num);
cur = cur->next;
}
return dummy->next;
}
// ------------------- 主函数测试 -------------------
int main() {
// 测试用例:3 个升序链表
vector<int> l1 = {1, 4, 5};
vector<int> l2 = {1, 3, 4};
vector<int> l3 = {2, 6};
// 构建链表
ListNode* list1 = createList(l1);
ListNode* list2 = createList(l2);
ListNode* list3 = createList(l3);
// 放入数组
vector<ListNode*> lists = {list1, list2, list3};
// 合并
ListNode* mergedList = mergeKLists(lists);
// 输出结果
cout << "合并后的有序链表:";
printList(mergedList);
return 0;
}