D007 并查集基础题集(上七题)

题目一览

维护集合大小、集合数量:

成环判断、获取所有集合的代表:

区间上的连通块合并:

细节+代码

1676 Road Construction - CSES

初始化 size 数组,每次有效合并时更新 size 数组和当前最大值。
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 1

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, n):
        self.fa = list(range(n + 1))
        self.size = [1] * (n + 1)
        self.mx = 1
    
    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x
    
    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[ry] = rx
            self.size[rx] += self.size[ry]  # 注意合并方向 rx -> ry
            self.mx = Max(self.mx, self.size[rx])
            return True
        return False
    
    def get_mx(self):
        return self.mx

def main():
    n, m = MII()

    dsu = DSU(n)

    cc = n
    for _ in range(m):
        u, v = MII()
        if dsu.union(u, v):
            cc -= 1
        print(cc, dsu.get_mx())
    
if __name__ == "__main__":
    main()

P2078 朋友 - 洛谷

因为涉及到的元素为负数,所以不能使用数组储存了。传人数组,使用字典初始化 fa 数组。

又小明的公司全是男的,小红的公司全是女的。所以可以配对的数为他(她)们的朋友集合的较小的那个。
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 1

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, a):
        self.fa = {x: x for x in a}
        self.size = {x: 1 for x in a}
    
    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x
    
    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[rx] = ry
            self.size[ry] += self.size[rx]
    
    def get_size(self, x):
        return self.size[self.find(x)]

def main():
    n, m, p, q = MII()

    boy = DSU(range(1, n + 1))
    girl = DSU(range(-m, 0))

    for _ in range(p):
        u, v = MII()
        boy.union(u, v)
    
    for _ in range(q):
        u, v = MII()
        girl.union(u, v)
    
    print(Max(1, Min(boy.get_size(1), girl.get_size(-1))))

if __name__ == "__main__":
    main()

ABC420E Reachability Query - AtCoder

当一个集合内有一个黑色的顶点,这个集合内的就所有顶点都可达到这个黑色顶点。所以我们维护一个集合里是否有黑色的顶点即可。
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 1

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, n):
        self.fa = list(range(n + 1))
        self.col = [0] * (n + 1)
        self.size = [0] * (n + 1)
    
    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x
    
    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[rx] = ry
            self.size[ry] += self.size[rx]  
    
    def charge(self, x):
        self.col[x] ^= 1
        rx = self.find(x)
        self.size[rx] += 1 if self.col[x] == 1 else -1
    
    def ok(self, x):
        rx = self.find(x)  
        return self.size[rx] > 0

def main():
    n, q = MII()

    dsu = DSU(n)

    outs = []
    for _ in range(q):
        o = LII()
        op = o[0]

        if op == 1:
            u, v = o[1:]
            dsu.union(u, v)
        elif op == 2:
            x = o[1]
            dsu.charge(x)
        else:
            x = o[1]
            outs.append("Yes" if dsu.ok(x) else "No")

    print('\n'.join(outs))

if __name__ == "__main__":
    main()

1666 Building Roads - CSES

使 \(K\) 连通块连通至少需要 \(K-1\) 条边,只需要将每个集合的根连起来即可。

遍历每个节点,当 fa[x] == x 时,这个 x 就是这个集合根。
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 1

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, n):
        self.fa = list(range(n + 1))

    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x
    
    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[rx] = ry

    def get_root(self):
        return [x for x in range(1, len(self.fa)) if self.fa[x] == x]
    

def main():
    n, m = MII()

    dsu = DSU(n)

    for _ in range(m):
        u, v = MII()
        dsu.union(u, v)
    
    root = dsu.get_root()

    print(len(root) - 1)

    for i in range(len(root) - 1):
        print(root[i], root[i + 1])

if __name__ == "__main__":
    main()

Roads not only in Berland - CodeForces

一个有 \(n\) 个顶点,\(n-1\) 条边的无向图,每次操作都删去一条边再加上一条边,使其这个无向图连通。

\(n\) 个顶点 \(n-1\) 条边的连通图是一颗树。原图一定存在若干个环。我们利用并查集判断两个点是否已经在一个集合中,如果已经在一个集合中了,那么这条将要加入的边就是形成环的重边,需要删掉这条。

然后需要加的边就跟上一题一样了,将所有根连起来。
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 1

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, n):
        self.n = n
        self.fa = list(range(n + 1))

    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x
    
    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[rx] = ry
            return True
        return False
    
    def get_root(self):
        return [x for x in range(1, self.n + 1) if self.fa[x] == x]
        

def main():
    n = II()

    dsu = DSU(n)

    de = []
    for _ in range(n - 1):
        u, v = MII()
        if not dsu.union(u, v):
            de.append((u, v))  #! 成环的边一定要删

    k = len(de)

    print(k)
    
    ae = []
    fa = dsu.get_root()
    for i in range(len(fa) - 1):
        ae.append((fa[i], fa[i + 1]))

    for i in range(k):
        print(*de[i], *ae[i])

if __name__ == "__main__":
    main()

Cycle Graph? - AtCoder

判断单环的条件是:

  1. 只有一个连通分量,所有节点都应该在这个连通分量里。
  2. 度数都为 \(2\) 。

第一个条件使用并查集判断,枚举 [1, n]1 进行判断 ,如果有节点没有在同一个连通分量里或度数不为 \(0\) 就输出 No
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 1

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, n):
        self.fa = list(range(n + 1))

    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[rx] = ry

    def is_same(self, x, y):
        return self.find(x) == self.find(y)

def main():
    n, m = MII()

    deg = [0] * (n + 1)

    dsu = DSU(n)

    for _ in range(m):
        u, v = MII()
        deg[u] += 1
        deg[v] += 1
        dsu.union(u, v)

    for i in range(1, n + 1):
        if not dsu.is_same(1, i) or deg[i] != 2:
            print("No")
            return

    print("Yes")

if __name__ == "__main__":
    main()

小苯的蓄水池(hard)- 牛客

使用并查集将 [l, r] 的水池合并,并维护水池的大小和个数。不过在合并过程中不能直接使用 for 一个一个合并,数据很大会超时。
点击查看代码

复制代码
import sys

Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y

input_type = 0

if input_type:
    inp = lambda: sys.stdin.readline().strip()

    II = lambda: int(inp())
    MII = lambda: map(int, inp().split())
    LII = lambda: list(MII())

else:
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    
    II = lambda: int(next(it))
    SI = lambda: next(it)
    
    if not input_data:
        sys.exit()

class DSU:
    def __init__(self, n, a):
        self.fa = list(range(n + 1))
        self.size = [1] * (n + 1)
        self.water = [0] * (n + 1)
        for i in range(n):
            self.water[i + 1] = a[i]
    
    def find(self, x):
        while self.fa[x] != x:
            self.fa[x] = self.fa[self.fa[x]]
            x = self.fa[x]
        return x
    
    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.fa[rx] = ry
            self.size[ry] += self.size[rx]
            self.water[ry] += self.water[rx]
    
    def cal(self, x):
        rx = self.find(x)
        res = self.water[rx] / self.size[rx]
        return f"{res:.10f}"

def main():
    n  = II()
    q = II()
    a = [II() for _ in range(n)]

    dsu = DSU(n, a)

    outs = []
    for _ in range(q):
        op = II()
        if op == 1:
            l, r = II(), II()
            cur = dsu.find(l)  
            while cur < r:
                dsu.union(cur, cur + 1)  # 要有方向的合并
                cur = dsu.find(cur)

        elif op == 2:
            x = II()
            ret = dsu.cal(x)
            outs.append(str(ret))

    print('\n'.join(outs))

if __name__ == "__main__":
    main()
相关推荐
程序员何未来1 年前
AcWing算法基础课-790数的三次方根-Java题解
java·数据结构·算法·算法竞赛·算法题解·acwing算法基础课·计算机算法