并查集的使用是如下的场景
1.一开始每个元素都拥有自己的集合,在自己的集合里只有这个元素自己。
2.find(i):查找i所在集合的代表元素,代表元素来代表i所在的集合。
3.bool isSameSet(a, b):判断a和b在不在一个集合里。
4.void union(a, b):a所在集合所有元素 与 b所在集合所有元素合并成一个集合。
5.各种操作单次调用的均摊时间复杂度为O(1)。
并查集的两个优化
1.扁平化(一定要做)
2.小挂大(可以不做,原论文中是秩的概念,可以理解为粗略高度或者大小)
并查集的小扩展
可以定制信息:并查集目前有多少个集合,以及给每个集合打上标签信息。
并查集时间复杂度的理解:作为如此简单、小巧的结构,感性理解单次调用的均摊时间复杂度为O(1)即可,其实为α(n),阿克曼函数。当n=10^80次方即可探明宇宙原子量,α(n)的返回值也不超过6,那就可以认为是O(1)。并查集的发明者Bernard A. Galler和Michael J. Fischer,从1964年证明到1989年才证明完毕,建议记住即可,理解证明难度很大!
下面我们讲解并查集的实现。
并查集认为,同一集合的元素具有相同的代表元素,即在同一集合里面的元素的代表元素是相同的,而在最开始每一个元素以自己为一个集合,即最开始每一个元素的代表元素都是自己。find方法查找集合的代表元素就很好理解;判断a和b在不在一个集合里,只需判断a和b的代表元素是否相同;a所在集合的所有元素与b所在集合的所有元素合并成一个集合,只需将a或者b所在集合的所有元素的代表元素全部改为b或a所在集合的代表元素。所以我们通过一个father数组保存每一个元素的指向元素,即用下标代表此元素,下标的值代表指向的元素,则father数组在初始化的时候下标和下标的值是相等的。而并查集的两个优化,第一,扁平化,是指当我们查一个元素的代表元素时,有可能它是经过很多次指向才指到了代表元素,所以当得到代表元素时,将这个指向路径上的所有元素的指向元素全部改为代表元素;第二,小挂大,是指合并集合的时候将容量小的集合指向容量大的集合。
下面通过几个题加深理解。
题目一
测试链接:https://www.nowcoder.com/practice/e7ed657974934a30b2010046536a5372
分析:这道题就是一个标准的并查集模板。代码如下。
cpp
#include <iostream>
#define MAXN 1000002
using namespace std;
int father[MAXN];
int Size[MAXN];
int N;
int M;
int find(int i) {
int behalf = i;
int temp;
while (father[behalf] != behalf) {
behalf = father[behalf];
}
while (father[i] != behalf) {
temp = i;
i = father[i];
father[temp] = behalf;
}
return behalf;
}
bool isSameSet(int a, int b) {
return find(a) == find(b);
}
void Union(int a, int b) {
int behalf_a = find(a);
int behalf_b = find(b);
if (behalf_a != behalf_b) {
if (Size[behalf_a] < Size[behalf_b]) {
father[behalf_a] = behalf_b;
Size[behalf_b] += Size[behalf_a];
} else {
father[behalf_b] = behalf_a;
Size[behalf_a] += Size[behalf_b];
}
}
}
int main(void) {
int opt, x, y;
scanf("%d%d", &N, &M);
for (int i = 1; i <= N; ++i) {
father[i] = i;
Size[i] = 1;
}
for (int i = 0; i < M; ++i) {
scanf("%d%d%d", &opt, &x, &y);
switch (opt) {
case 1:
if (isSameSet(x, y)) {
printf("Yes\n");
} else {
printf("No\n");
}
break;
case 2:
Union(x, y);
break;
}
}
}
其中,扁平化采用while循环的方式,后面会展示递归的方式;Union方法采用小挂大的方式。
题目二
测试链接:https://leetcode.cn/problems/couples-holding-hands/
分析:这道题我们可以根据情侣的ID判断出这是第几对情侣,同时需要明白一个现象,一个集合中,如果有n对情侣没有并肩坐在一起,那么需要n-1次交换,才能使这n对情侣并肩坐在一起。那么,遍历数组每两个情侣ID并合并集合,最终可以得到几个集合,每个集合中有若干对情侣。这时候统计每个集合需要的交换次数相加即可。这里有一个小优化,最开始每一对情侣都是一个集合,最后合并完之后,得到一个集合数,因为每一个集合需要交换的次数是集合中的对数减1,那总对数是知道的,也就是最开始集合的个数,那么,只需要将最开始集合的个数减去现在集合的个数就是需要交换的次数。代码如下。
cpp
class Solution {
public:
int father[31];
int Size[31];
int number;
int find(int i){
int behalf = i;
int temp;
while (father[behalf] != behalf)
{
behalf = father[behalf];
}
while (father[i] != behalf)
{
temp = i;
i = father[i];
father[temp] = behalf;
}
return behalf;
}
bool isSameSet(int a, int b){
return find(a) == find(b);
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
if(Size[behalf_a] < Size[behalf_b]){
father[behalf_a] = behalf_b;
Size[behalf_b] += Size[behalf_a];
}else{
father[behalf_b] = behalf_a;
Size[behalf_a] += Size[behalf_b];
}
--number;
}
}
int minSwapsCouples(vector<int>& row) {
int n = row.size() / 2;
number = n;
for(int i = 0;i < number;++i){
father[i] = i;
Size[i] = 1;
}
for(int i = 0;i < row.size();i += 2){
Union(row[i]/2, row[i+1]/2);
}
return n - number;
}
};
其中,n为对数;Union方法每合并一次集合数减1。
题目三
测试链接:https://leetcode.cn/problems/similar-string-groups/
分析:这道题的思路还是比较明确的,就是如果这两个相似则合并,所以重点就在如何判断相似。我们可以首先得到两个字符串的长度,如果长度不相等,则这两个字符串不相似;如果长度相等,则开始遍历字符串,如果不同的个数为0或2代表相似,其余个数代表不相似。代码如下。
cpp
class Solution {
public:
int father[301];
int Size[301];
int number;
int find(int i){
int behalf = i;
int temp;
while (father[behalf] != behalf)
{
behalf = father[behalf];
}
while (father[i] != behalf)
{
temp = i;
i = father[i];
father[temp] = behalf;
}
return behalf;
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
if(Size[behalf_a] < Size[behalf_b]){
father[behalf_a] = behalf_b;
Size[behalf_b] += Size[behalf_a];
}else{
father[behalf_b] = behalf_a;
Size[behalf_a] += Size[behalf_b];
}
--number;
}
}
bool isSame(string s1, string s2){
int length_s1 = s1.size();
int length_s2 = s2.size();
int diff = 0;
if(length_s1 != length_s2){
return false;
}
for(int i = 0;i < length_s1 && diff < 3;++i){
if(s1[i] != s2[i]){
++diff;
}
}
return diff == 0 || diff == 2;
}
int numSimilarGroups(vector<string>& strs) {
int length = strs.size();
number = length;
for(int i = 0;i < number;++i){
father[i] = i;
Size[i] = 1;
}
for(int i = 0;i < length;++i){
for(int j = i+1;j < length;++j){
if(isSame(strs[i], strs[j])){
Union(i, j);
}
}
}
return number;
}
};
其中,最开始每个字符串为一个集合;相似则合并且集合数减1,最后返回还剩的集合数就是相似字符串组数。
题目四
测试链接:https://leetcode.cn/problems/number-of-islands/
分析:这道题将每个位置作为一个集合,从上到下,从左到右,遍历二维网格,如果位置值为0,则集合数减1;如果位置为1,则判定上边和左边是否为1,如果为1,合并。合并方法中每次合并集合数减1。遍历完数组剩余的集合数,即为岛屿个数。代码如下。
cpp
class Solution {
public:
int father[90001];
int Size[90001];
int number;
int find(int i){
int behalf = i;
int temp;
while (father[behalf] != behalf)
{
behalf = father[behalf];
}
while (father[i] != behalf)
{
temp = i;
i = father[i];
father[temp] = behalf;
}
return behalf;
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
if(Size[behalf_a] < Size[behalf_b]){
father[behalf_a] = behalf_b;
Size[behalf_b] += Size[behalf_a];
}else{
father[behalf_b] = behalf_a;
Size[behalf_a] += Size[behalf_b];
}
--number;
}
}
int numIslands(vector<vector<char>>& grid) {
int row = grid.size();
int column = grid[0].size();
number = row * column;
for(int i = 0;i < number;++i){
father[i] = i;
Size[i] = 1;
}
for(int i = 0;i < row;++i){
for(int j = 0;j < column;++j){
if(grid[i][j] == '0'){
--number;
}else{
if(j > 0 && grid[i][j-1] == '1'){
Union(i * column + j, i * column + (j-1));
}
if(i > 0 && grid[i-1][j] == '1'){
Union(i * column + j, (i-1) * column + j);
}
}
}
}
return number;
}
};
其中,因为网格为二维数组,father和Size数组一维数组,所以这里将二维下标转化为一维下标再操作。
题目五
测试链接:https://leetcode.cn/problems/most-stones-removed-with-same-row-or-column/
分析:这道题思路还是比较清晰的,将同行或者同列的石头看作为一个集合,遍历完石头数组过后,剩余的集合数即代表移除完石头剩余的石头数。将石头总数减去剩余石头数就有是可以移除的石子的最大数量。这里合并石子并不需要通过双重for循环,这样会超时,我们可以利用两个map,一个行map,一个列map,用来查询遍历到的石头的行和列是否已经存在。如果遍历到了一个石子之后,如果行map中已经有这个行的石子存在,则直接合并;如果没有,则将这个石子插入行map。对列map操作相同。这样只需一重for循环遍历数组即可。代码如下。
cpp
class Solution {
public:
map<int, int> rowFirst;
map<int, int> columnFirst;
int father[1000];
int Size[1000];
int number;
int find(int i){
int behalf = i;
int temp;
while (father[behalf] != behalf)
{
behalf = father[behalf];
}
while (father[i] != behalf)
{
temp = i;
i = father[i];
father[temp] = behalf;
}
return behalf;
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
if(Size[behalf_a] < Size[behalf_b]){
father[behalf_a] = behalf_b;
Size[behalf_b] += Size[behalf_a];
}else{
father[behalf_b] = behalf_a;
Size[behalf_a] += Size[behalf_b];
}
--number;
}
}
int removeStones(vector<vector<int>>& stones) {
int length = stones.size();
int row, column;
number = length;
for(int i = 0;i < length;++i){
father[i] = i;
Size[i] = 1;
}
for(int i = 0;i < length;++i){
row = stones[i][0];
column = stones[i][1];
if(rowFirst.count(row) == 0){
rowFirst.insert(make_pair(row, i));
}else{
Union(i, (rowFirst.find(row))->second);
}
if(columnFirst.count(column) == 0){
columnFirst.insert(make_pair(column, i));
}else{
Union(i, (columnFirst.find(column))->second);
}
}
return length - number;
}
};
其中,map中第一个数代表行或列,第二个数代表石头的下标。
题目六
测试链接:https://leetcode.cn/problems/find-all-people-with-secret/
分析:这道题我们通过知晓秘密和不知晓秘密两个状态,可以想到尝试并查集。首先对给的会议数组排序,以时间从小到大排序。通过滑动窗口得到相同时间会议数组的下标范围,对这些会议的参与专家合并。对一个时间的所有会议合并完成后,必然会有两个集合,一个是知晓秘密的集合,一个是不知晓秘密的集合,这时需要将不知晓秘密集合中的专家重新初始化,也就是不知晓秘密集合中的专家,每一个专家自己为一个集合。然后在下一个时刻继续重复操作。最后,遍历专家,知晓秘密的专家插入答案数组。代码如下。
cpp
class Solution {
public:
class MyCompare
{
public:
bool operator()(vector<int> v1, vector<int> v2){
return v1[2] < v2[2];
}
};
int father[100002];
bool attr[100002];
int find(int i){
if(i != father[i]){
father[i] = find(father[i]);
}
return father[i];
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
father[behalf_a] = behalf_b;
attr[behalf_b] |= attr[behalf_a];
}
}
vector<int> findAllPeople(int n, vector<vector<int>>& meetings, int firstPerson) {
int length = meetings.size();
sort(meetings.begin(), meetings.end(), MyCompare());
for(int i = 0;i < n;++i){
father[i] = i;
attr[i] = false;
}
attr[0] = true;
father[firstPerson] = 0;
for(int left = 0, right;left < length;){
right = left;
while (right + 1 < length && meetings[right+1][2] == meetings[left][2])
{
++right;
}
for(int i = left;i <= right;++i){
Union(meetings[i][0], meetings[i][1]);
}
for(int i = left, a, b;i <= right;++i){
a = meetings[i][0];
b = meetings[i][1];
if(!attr[find(a)]){
father[a] = a;
}
if(!attr[find(b)]){
father[b] = b;
}
}
left = right + 1;
}
vector<int> ans;
for(int i = 0;i < n;++i){
if(attr[find(i)]){
ans.push_back(i);
}
}
return ans;
}
};
其中,attr数组代表是否知晓秘密,true为知晓,false为不知晓;每次同一时间的下标范围通过滑动窗口找出;find方法中的扁平化处理采用递归方式。
题目七
测试链接:https://leetcode.cn/problems/number-of-good-paths/
分析:首先,根据题目有多少个节点,最开始就有多少个好路径。然后我们可以考虑将每条边按这条边连接的两个节点中的最大值从小到大的排序。从小到大是为了不遗漏好路径,也就是从好路径两端从1开始。对于排好序的路径,我们依次遍历路径连接的两个节点所在的集合的最大值。如果相同代表存在好路径,而增加好路径的个数,是两个集合最大值的个数相乘,最后将连接的两个节点合并;如果连接的两个节点所在的集合的最大值不同,则直接将这个边连接的两个节点合并。遍历完边数即可得到好路径的个数。代码如下。
cpp
class Solution {
public:
int father[30001];
int NumOfMaxValueOfset[30001];
int find(int i){
if(i != father[i]){
father[i] = find(father[i]);
}
return father[i];
}
int Union(int a, int b, vector<int>& vals){
int behalf_a = find(a);
int behalf_b = find(b);
int path = 0;
if(vals[behalf_a] > vals[behalf_b]){
father[behalf_b] = behalf_a;
}else if(vals[behalf_a] < vals[behalf_b]){
father[behalf_a] = behalf_b;
}else{
path = NumOfMaxValueOfset[behalf_a] * NumOfMaxValueOfset[behalf_b];
father[behalf_b] = behalf_a;
NumOfMaxValueOfset[behalf_a] += NumOfMaxValueOfset[behalf_b];
}
return path;
}
int numberOfGoodPaths(vector<int>& vals, vector<vector<int>>& edges) {
int num = vals.size();
int length = edges.size();
for(int i = 0;i < num;++i){
father[i] = i;
NumOfMaxValueOfset[i] = 1;
}
// auto cmp = [&vals](const vector<int>& v1, const vector<int>& v2){
// int max1 = max(vals[v1[0]], vals[v1[1]]);
// int max2 = max(vals[v2[0]], vals[v2[1]]);
// return max1 < max2;
// };
// sort(edges.begin(), edges.end(), cmp);
sort(edges.begin(), edges.end(), [&vals](const vector<int>& v1, const vector<int>& v2)->bool{
return (vals[v1[0]] > vals[v1[1]] ? vals[v1[0]] : vals[v1[1]]) <
(vals[v2[0]] > vals[v2[1]] ? vals[v2[0]] : vals[v2[1]]);
});
for(int i = 0;i < length;++i){
num += Union(edges[i][0], edges[i][1], vals);
}
return num;
}
};
其中,合并时没有采用小挂大,而是哪个点所在集合的最大值大就被挂,这样可以确保合并后集合的最大值不会出错;同时,对时间复杂度进行分析可以看出,代码中最耗时的部分是排序,这里的比较函数提供了两种写法,一个没注释掉的使用lambda表达式且里面使用三目运算符,一个是被注释掉的auto写法且里面采用库函数max,两个都可以过,速度差不多;最后需要强调的是,一定在传参数的时候,传容器类型一定一定一定要按引用传递,不然会超时。
题目八
测试链接:https://leetcode.cn/problems/minimize-malware-spread-ii/
分析:题目中包含感染和未感染两个状态,可以尝试使用并查集。一个比较容易想到的思路是对于initial数组中的每一个节点依次删除,然后统计M最小的值时索引最小的节点。
cpp
class Solution {
public:
int father[301];
bool infect[301];
int find(int i){
if(father[i] != i){
father[i] = find(father[i]);
}
return father[i];
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
father[behalf_a] = behalf_b;
infect[behalf_b] |= infect[behalf_a];
}
}
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
int n = graph.size();
int len_initial = initial.size();
int ans = initial[0];
int M = -((1 << 31) + 1);
int temp_M;
int delete_node;
for(int i = 0;i < len_initial;++i){
temp_M = 0;
delete_node = initial[i];
for(int j = 0;j < n;++j){
if(j == delete_node){
continue;
}
father[j] = j;
infect[j] = false;
}
for(int j = 0;j < len_initial;++j){
if(j == i){
continue;
}
infect[initial[j]] = true;
}
for(int j = 0;j < n;++j){
if(j == delete_node){
continue;
}
for(int k = j+1;k < n;++k){
if(k == delete_node){
continue;
}
if(graph[j][k]){
Union(j, k);
}
}
}
for(int j = 0;j < n;++j){
if(j == delete_node){
continue;
}
if(infect[find(j)]){
++temp_M;
}
}
if(temp_M < M){
M = temp_M;
ans = initial[i];
}else if(temp_M == M){
ans = ans < initial[i] ? ans : initial[i];
}
}
return ans;
}
};
其中,infec数组代表是否被感染,true为感染,false为未感染;M初始化为int最小值;遍历感染节点数组循环中的流程是:初始化father和infec数组->把感染节点的infect数组更新->合并节点->统计有多少节点感染->更新M和ans。
另外一种更快的思路就是统计每一个感染节点连接了多少普通节点,删去连接最多普通节点的感染节点即可。所以,对普通节点有一个数组统计每一个普通节点连接的是哪一个感染节点,当如果一个普通节点连接的感染节点数超过一个,则这个普通节点一定会被感染。代码如下。
cpp
#pragma GCC optimize(3, "Ofast", "inline")
auto init_ = [] {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
return 0;
}();
class Solution {
public:
int cnt[301];
bool virus[301];
int father[301];
int infect[301];
int Size[301];
int find(int i){
if(father[i] != i){
father[i] = find(father[i]);
}
return father[i];
}
void Union(int a, int b){
int behalf_a = find(a);
int behalf_b = find(b);
if(behalf_a != behalf_b){
father[behalf_a] = behalf_b;
Size[behalf_b] += Size[behalf_a];
}
}
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
int n = graph.size();
int len_initial = initial.size();
int ans = initial[0];
int num = 0;
for(int i = 0;i < n;++i){
virus[i] = false;
infect[i] = -1;
Size[i] = 1;
father[i] = i;
cnt[i] = 0;
}
for(int i = 0;i < len_initial;++i){
virus[initial[i]] = true;
}
for(int i = 0;i < n;++i){
for(int j = i+1;j < n;++j){
if(graph[i][j] == 1 && !virus[i] && !virus[j]){
Union(i, j);
}
}
}
for(int i = 0;i < len_initial;++i){
for(int j = 0;j < n;++j){
if(initial[i] != j && !virus[j] && graph[initial[i]][j] == 1){
int behalf = find(j);
if(infect[behalf] == -1){
infect[behalf] = initial[i];
}else if(infect[behalf] != -2 && infect[behalf] != initial[i]){
infect[behalf] = -2;
}
}
}
}
for(int i = 0;i < n;++i){
if(i == find(i) && infect[i] >= 0){
cnt[infect[i]] += Size[i];
}
}
for(int i = 0;i < len_initial;++i){
if(cnt[initial[i]] > num){
num = cnt[initial[i]];
ans = initial[i];
}else if(cnt[initial[i]] == num){
ans = ans < initial[i] ? ans : initial[i];
}
}
return ans;
}
};
其中,开头的代码是开启了O3优化和取消了一些输入输出同步,可以一定程度上加快代码速度,但是在比赛时不推荐使用;cnt数组存储每一个感染节点连接的普通节点个数;virus数组代表节点是否感染;infect数组代表普通节点连接的感染节点的下标,最开始为-1,第一次更新之后大于等于0,如果有第二次更新代表连接两个感染节点,更新为-2;主要流程是:初始化初virus数组的其他数组->初始化virus数组->合并相邻的普通结点->对每个普通结点集合更新连接的感染结点的下标->统计每个感染结点连接的集合的结点数->遍历cnt数组更新ans。