归并排序
目录
算法思想
通过将当前乱序的数组分成两个部分,分别进行「递归调用 」,利用两个指针将数据元素以此比较,选择相对较小的元素放进「辅助数组 」中,再将辅助数组的数据放回「原数组」
命名由来
归并=递归+合并
算法描述
问题描述
leetcode第148题
给你链表的头结点 head,请将其按 升序 排列并返回排序后的链表。
sortList函数
先看sortList,此函数的目的是对链表进行归并排序。
cpp
ListNode* sortList(ListNode* head) {
if (head == nullptr) // 1
return nullptr;
else if (head->next == nullptr) // 2
return head;
ListNode *slow = head, *fast = head; // 3
ListNode *pre = nullptr;
while (fast != nullptr)
{
pre = slow;
slow = slow->next;
fast = fast->next;
if (fast)
fast = fast->next;
}
ListNode *tmp = pre->next;
pre->next = nullptr; //4
return mergeSort(head, tmp); //5
}
(1) 当链表没有元素的时候不需要排序,直接返回null;
(2) 当链表只有一个元素的时候也不需要排序,返回本身即可;
(3) 我们用快慢指针来找到链表的中间节点,并将链表分为两部分 ,分别是左半部分和右半部分;
(4) 此时我们就完成了对一个链表的切割,左边是以head为头结点的链表,右边则是以tmp指针为头结点的链表 ;
(5) 调用 mergeSort 函数进行合并排序。
mergeSort函数
cpp
ListNode* mergeSort(ListNode* a, ListNode* b)
{
a = sortList(a);
b = sortList(b); // 1
ListNode* head = new ListNode(0);
ListNode* tmp = head; // 2
head->next = nullptr;
while (a || b) // 3
{
if (a == nullptr)
{
tmp->next = b;
break;
}
else if (b == nullptr)
{
tmp->next = a;
break;
}
else if (a->val < b->val)
{
tmp->next = a;
a = a->next;
}
else if (a->val >= b->val)
{
tmp->next = b;
b = b->next;
}
tmp = tmp->next;
tmp->next = nullptr;
}
return head->next; // 4
}
(1) a 和 b 分别表示左边部分和右边部分,将 a 和 b 分别传入 sortList 函数中进行排序(递归调用);
(2) 创建一个新的头节点 head,以及一个临时节点 tmp 用于构建合并后的链表;
(3) 通过比较 a 和 b 的值,逐个选择较小的节点接入到新链表中,直至其中一个链表为空。
(4) 最后,返回合并后链表的头节点(即 head->next),并注意释放之前创建的虚拟头节点。
源代码
cpp
/**
* Definition for singly-linked list.
* struct ListNode {
* int val;
* ListNode *next;
* ListNode() : val(0), next(nullptr) {}
* ListNode(int x) : val(x), next(nullptr) {}
* ListNode(int x, ListNode *next) : val(x), next(next) {}
* };
*/
class Solution {
ListNode* mergeSort(ListNode* a, ListNode* b)
{
a = sortList(a);
b = sortList(b);
ListNode* head = new ListNode(0);
ListNode* tmp = head;
head->next = nullptr;
while (a || b)
{
if (a == nullptr)
{
tmp->next = b;
break;
}
else if (b == nullptr)
{
tmp->next = a;
break;
}
else if (a->val < b->val)
{
tmp->next = a;
a = a->next;
}
else if (a->val >= b->val)
{
tmp->next = b;
b = b->next;
}
tmp = tmp->next;
tmp->next = nullptr;
}
return head->next;
}
public:
ListNode* sortList(ListNode* head) {
if (head == nullptr) //
return nullptr;
else if (head->next == nullptr) //
return head;
ListNode *slow = head, *fast = head, *pre = nullptr;
while (fast != nullptr)
{
pre = slow;
slow = slow->next;
fast = fast->next;
if (fast)
fast = fast->next;
}
ListNode *tmp = pre->next;
pre->next = nullptr;
return mergeSort(head, tmp);
}
};