等待组(waitgroup)

在之前的文章里曾经简单的说过等待组的使用。这一章节会深入的来介绍。

在之前聊 Golang GMP一.GMP调度器 时,我和大家讨论了 goroutine 的调度方式分为主动让渡和被动调度. 其中触发被动调度的常见方式包括通道 channel 和单机锁 sync.Mutex. 在此之上,今天再补充另一种可能触发 goroutine 被动调度的工具------并发等待组 sync.WaitGroup

在并发编程中,如何使得异步运行的 goroutine 之间建立一种默契的协作关系,这是一个非常关键的话题. channel 是达成这个目标的一种实现方式,不同 goroutine 之间可以通过并发通道 channel 完成信息的传递,从而促成协作关系. (如果想展开这个话题,可以阅读我之前发表的文章------2.通道(chan)

当 goroutine 之间需要建立明确的层级关系. 倘若父 goroutine 希望持有子 goroutine 的生杀大权,并且保证父 goroutine 消亡时能连带回收其创建的所有子 goroutine ,此时可以使用到 Golang 上下文工具 context,完成父 goroutine 对 子 goroutine 的生命周期控制(如果想展开这个话题,可以阅读我之前发表的文章------4.上下文(context)

在具备了以上知识之后,今天我们百尺竿头更进一步,探讨一种新的 goroutine 协作机制------等待聚合模式.

在这种模式下,父 goroutine 在创建一系列子 goroutine 后,可以选择在一个合适的时机对所有子 goroutine 的执行结果进行等待聚合,直到所有子 goroutine 都执行完成之后,父 goroutine 才会继续往前推进. 要达成这种协作模式最合适的工具,就是我们今天要聊的主角------Golang 中的并发等待组工具 sync.WaitGroup.

一.等待组

1.1 介绍

首先先介绍等待组是什么?

Go语言中的等待组(sync.WaitGroup)是一种用于协调多个 goroutine 同步执行的机制。

它通过计数器来跟踪正在执行的 goroutine 数量,允许主 goroutine 等待所有子 goroutine 完成后再继续执行。

为什么要使用等待组?

  1. 防止主 goroutine 提前退出
    默认情况下,主 goroutine 不会等待子 goroutine 完成。若不使用同步机制,子 goroutine 可能未执行完毕,程序就退出了。等待组通过 Wait() 阻塞主流程,确保所有子任务完成。
  2. 简化多 goroutine 同步
    相比通过 channel 逐个通知完成,等待组更简洁。尤其当 goroutine 数量动态变化时,无需手动管理多个信号。
  3. 避免竞态条件
    内部通过原子操作或锁保证计数器线程安全,开发者无需自行处理并发问题。

与channel对比

实际上使用channel也可以实现上述操作,但是为什么不直接使用channel呢?

复制代码
done := make(chan bool)
for i := 0; i < 3; i++ {
    go func() {
        // 任务代码
        done <- true
    }()
}
for i := 0; i < 3; i++ { <-done }

你会发现这样必须手动控制,比较繁琐

  • 主 goroutine 需要在一开始就明确启动的子 goroutine 数量,从而建立好对应容量的 channel,以及设定执行 for 循环接收信号量的次数. 这样的设定不够灵活,也就是子协程无法添加和减少。

1.2 操作

sync.WaitGroup有以下三个方法:

|----------------------------------|------------|
| 方法名 | 功能 |
| (wg * WaitGroup) Add(delta int) | 计数器+delta |
| (wg *WaitGroup) Done() | 计数器-1 |
| (wg *WaitGroup) Wait() | 阻塞直到计数器变为0 |

sync.WaitGroup内部维护着一个计数器,计数器的值可以增加和减少。

例如当我们启动了N 个并发任务时,就将计数器值增加N。每个任务完成时通过调用Done()方法将计数器减1。

通过调用Wait()来等待并发任务执行完,当计数器值为0时,表示所有并发任务已经完成。

复制代码
package main

import (
	"fmt"
	"strconv"
	"sync"
)

var wg sync.WaitGroup //只定义不需要赋值

func main() { //主线程
	for i := 1; i <= 5; i++ {
		wg.Add(1) //协程开始时加1
		go func(n int) {
            defer wg.Done() //协程执行完-1
			fmt.Println("结果为:" + strconv.Itoa(n))
		}(i) // 要开启协程在前面加个go
	}

	wg.Wait() //这里就是相当于一个阻塞,什么时候上面的计数器为0的时候,就结束
}

使用注意事项

  1. **Add()**调用时机
    必须在启动 goroutine 调用 Add(),避免竞态条件(如主流程先执行 Wait(),而计数器尚未增加)。
  2. 传递 WaitGroup****的指针
    方法接收者应为指针(如 func worker(wg *sync.WaitGroup)),否则会导致副本操作,计数器无法正确归零。
  3. 使用 defer****调用 Done()
    确保即使 goroutine 发生 panic,计数器仍能正确递减,避免死锁。
  4. 避免重用
    一个 WaitGroup 完成一次同步后,若需复用,应创建新实例。重置或重用可能导致不可预知的行为。

二.具体实现

2.1 版本一

这里的具体场景是指waitGroup+channel完成的数据聚合。会举小徐先生的例子,抛开过多的杂话。

来看一段简单的代码,实现的是一个异步向数组添加数据的操作,使用等待组确认数据都加入,最后关闭打印这个数组.

复制代码
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	tasksNum := 10
	dataCh := make(chan interface{})
	resp := make([]interface{}, 0, tasksNum)
	// 启动读 goroutine
	go func() {
		for data := range dataCh {
			resp = append(resp, data)
		}
	}()

	// 保证获取到所有数据后,通过 channel 传递到读协程手中
	var wg sync.WaitGroup
	for i := 0; i < tasksNum; i++ {
		wg.Add(1)
		go func(ch chan<- interface{}) {
			defer wg.Done()
			ch <- time.Now().UnixNano()
		}(dataCh)
	}
	// 确保所有取数据的协程都完成了工作,才关闭 ch
	wg.Wait()
	close(dataCh)

	fmt.Println("resp: %+v", resp)
}

下面考验一下大家对并发的敏感度,看完上述代码之后,大家有没有找到其中存在的并发的问题呢?

这里就不卖关子了,就直接来说一下吧:

这里的主要问题就是主 goroutine 在通过 WaitGroup.Wait 方法确保子 goroutine 都完成任务后,会关闭 dataCh ,并直接获取 resp slice 进行打印. 此时 dataCh 虽然关闭了,但是由于异步的不确定性,读 goroutine 可能还没来得及将所有数据都聚合到 resp slice 当中,因此主 goroutine 拿到的 resp slice 可能存在数据缺失.

大家可以运行一下看看结果

2.2 版本二

之前存在的问题是,主 goroutine 可能在读 goroutine 完成数据聚合前,就已经取用了 resp slice. 那么我们就额外启用一个用于标识读 goroutine 是否执行结束的 channel:stopCh 即可. 具体步骤包括:

  • 主 goroutine 关闭 dataCh 之后,不是立即取用 resp slice,而是会先尝试从 stopCh 中读取信号,读取成功后,才继续往下
  • 读 goroutine 在退出前,往 stopCh 中塞入信号量,让主 goroutine 能够感知到读 goroutine 处理完成这一事件

这样处理之后,逻辑是严谨的,主 goroutine 能够保证取得的 resp slice 所拥有的完整数据。

也就是通过一个额外的channel来判断是否完成了这个任务,使用这个chan进行一个阻塞。

复制代码
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	tasksNum := 10

	dataCh := make(chan interface{})
	resp := make([]interface{}, 0, tasksNum)
	stopCh := make(chan struct{}, 1)
	// 启动读 goroutine
	go func() {
		for data := range dataCh {
			resp = append(resp, data)
		}
		stopCh <- struct{}{}
	}()

	// 保证获取到所有数据后,通过 channel 传递到读协程手中
	var wg sync.WaitGroup
	for i := 0; i < tasksNum; i++ {
		wg.Add(1)
		go func(ch chan<- interface{}) {
			defer wg.Done()
			ch <- time.Now().UnixNano()
		}(dataCh)
	}
	// 确保所有取数据的协程都完成了工作,才关闭 ch
	wg.Wait()
	close(dataCh)

	// 确保读协程处理完成
	<-stopCh

	fmt.Println("resp: %+v", resp)
}

2.3 版本三

版本 2.0 需要额外引入一个 stopCh,用于主 goroutine 和读 goroutine 之间的通信交互,看起来总觉得不够优雅. 下面我们就较真一下,针对于如何省去这个小小的 channel,进行版本 3.0 的方案探讨.

  • 同样创建一个无缓冲的 dataCh,用于聚合数据的传递

  • 异步启动一个总览写流程的写 goroutine,在这个写 goroutine 中,基于 WaitGroup 使用模式,让写 goroutine 中进一步启动的子 goroutine 在完成工作后,将数据发送到 dataCh 当中

  • 写 goroutine 基于 WaitGroup.Wait 操作,在确保所有子 goroutine 完成工作后,关闭 dataCh

  • 接下来,让主 goroutine 同时扮演读 goroutine 的角色,通过 for range 的方式持续遍历接收 dataCh 当中的数据,将其填充到 resp slice

  • 当写 goroutine 关闭 dataCh 后,主 goroutine 才能结束遍历流程,从而确保能够取得完整的 resp 数据

    package main

    import (
    "fmt"
    "sync"
    "time"
    )

    func main() {
    tasksNum := 10

    复制代码
      dataCh := make(chan interface{})
      // 启动写 goroutine,推进并发获取数据进程,将获取到的数据聚合到 channel 中
      go func() {
      	// 保证获取到所有数据后,通过 channel 传递到读协程中
      	var wg sync.WaitGroup
      	for i := 0; i < tasksNum; i++ {
      		wg.Add(1)
      		go func(ch chan<- interface{}) {
      			defer wg.Done()
      			ch <- time.Now().UnixNano()
      		}(dataCh)
      	}
      	// 确保所有取数据的协程都完成了工作,才关闭 ch
      	wg.Wait()
      	close(dataCh)
      }()
    
      resp := make([]interface{}, 0, tasksNum)
      // 主协程作为读协程,持续读取数据,直到所有写协程完成任务,chan 被关闭后才会往下
      for data := range dataCh {
      	resp = append(resp, data)
      }
      fmt.Println("resp: %+v", resp)

    }

这里对比一下版本一和版本三的区别吧:

在版本一里面是以写协程作为主协程,就会导致数据写完之后主协程直接关闭了chan,导致部分数据还没被读入slice就结束了,所有会出现数据丢失的问题

在版本三中,是以读协程为主协程,启动写协程并发聚合数据,之后才会关闭管道,在还没关闭之前,在主协程中管道是阻塞的,写入一个便读取一个不会关闭,直到写协程关闭才会关闭。

三.源码走读

复制代码
type WaitGroup struct {
    // 防止值拷贝标记
	noCopy noCopy

    // 64 个 bit 组成的状态值,
    //高 32 位标识了当前需要等待多少个 goroutine 执行了 WaitGroup.Add,还没执行 WaitGroup.Done
    // 低 32 位表示了当前多少 goroutine 执行了 WaitGroup.Wait 操作陷入阻塞中了
	state atomic.Uint64 
    
    // // 用于将 goroutine 阻塞和唤醒的信号量
	sema  uint32
}
  • noCopy:这是防拷贝标识,标记了 WaitGroup 不应该用于值传递
  • state1:这是 WaitGroup 的核心字段,是一个无符号的64位整数,高32位是 WaitGroup 中并发计数器的数值,即当前 WaitGroup.Add 与 WaitGroup.Done 之间的差值;低 32 位标识了,当前有多少 goroutine 因 WaitGroup.Wait 操作而处于阻塞态,陷入阻塞态的原因是因为计数器的值没有清零,即 state1 字段高 32 位是一个正值
  • state2:用于阻塞和唤醒 goroutine 的信号量

2.1 Add函数

这看源码之前呢,里面会涉及race包,如果想要具体了解的话,可以先看0.race包这篇文章。

会简单说一下用到的函数:

复制代码
// .....  

// 标识是否启用竞态检测(竞态检测(Race Detection)是一种用于发现程序中 数据竞争(Data Race) 的技术)
const Enabled = true

func Acquire(addr unsafe.Pointer) {
	runtime.RaceAcquire(addr)
}

// race.ReleaseMerge(unsafe.Pointer(wg)) 的作用是 
// 向竞态检测器(Race Detector)标记同步事件的合并,
// 确保它能正确理解 WaitGroup 的计数器递减(Done())与等待(Wait())之间的同步关系
// 也就是合并 Done() 的同步事件,确保 Wait() 正确同步。
func ReleaseMerge(addr unsafe.Pointer) {
	runtime.RaceReleaseMerge(addr)
}

// 临时关闭竞态检测
func Disable() {
	runtime.RaceDisable()
}
// 打开竞态检测
func Enable() {
	runtime.RaceEnable()
}

func Read(addr unsafe.Pointer) {
	runtime.RaceRead(addr)
}

func Write(addr unsafe.Pointer) {
	runtime.RaceWrite(addr)
}

func (wg *WaitGroup) Add(delta int) {
    // 
    if race.Enabled {
        if delta < 0 {
            // Synchronize decrements with Wait.
            race.ReleaseMerge(unsafe.Pointer(wg))
        }
        race.Disable()
        defer race.Enable()
    }
    // state1 高 32 位加 1,标识执行任务数量加 1 
    state := wg.state.Add(uint64(delta) << 32)

    // 取的是 state 高 32 位的值,代表有多少个 goroutine 在执行任务
    v := int32(state >> 32)

    // w 取的是 state 低 32 位的值,代表有多少个 goroutine 执行了 WaitGroup.Wait 在阻塞等待
    w := uint32(state)
    
    if race.Enabled && delta > 0 && v == int32(delta) {
        race.Read(unsafe.Pointer(&wg.sema))
    }
    
    // 不能出现负值的执行任务计数器
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    // 倘若存在 goroutine 在阻塞等待 WaitGroup.Wait,
    // 但是在执行 WaitGroup.Add 前,执行任务计数器的值为 0
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 倘若当前没有 goroutine 在 Wait,或者任务执行计数器仍大于 0,则直接返回
    if v > 0 || w == 0 {
        return
    }
    // 在执行过 WaitGroup.Wait 操作的情况下,
    // WaitGroup.Add 操作不应该并发执行,否则可能导致 panic
    if wg.state.Load() != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 将 state1 计数器置为 0,然后依次唤醒执行过 Wait 的 waiters
    wg.state.Store(0)
    for ; w != 0; w-- {
        runtime_Semrelease(&wg.sema, false, 0)
    }
}

2.2 Done

复制代码
// Done decrements the [WaitGroup] counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

这个就很简单就是-1呗。

2.3 Wait

复制代码
func (wg *WaitGroup) Wait() {
    // 关闭竞态检测
	if race.Enabled {
		race.Disable()
	}
	for {
		state := wg.state.Load()
        // 取的是 state 高 32 位的值,代表有多少个 goroutine 在执行任务
		v := int32(state >> 32)
        // w 取的是 state 低 32 位的值,代表有多少个 goroutine 执行了 WaitGroup.Wait 在阻塞等待
		w := uint32(state)
		if v == 0 {
			// 倘若当前需要等待完成任务的计数器值为 0,则无需 wait 直接返回
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// wait 阻塞等待 waitGroup 的计数器加一,然后陷入阻塞
		if wg.state.CompareAndSwap(state, state+1) {
			if race.Enabled && w == 0 {
                
				race.Write(unsafe.Pointer(&wg.sema))
			}
			runtime_Semacquire(&wg.sema)
            // 从阻塞中回复,倘若前一轮 wait 操作还没结束,waitGroup 又被使用了,则会 panic
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}
相关推荐
weixin_377634845 分钟前
【python异步多线程】异步多线程爬虫代码示例
开发语言·爬虫·python
struggle202520 分钟前
PennyLane 是一个用于量子计算、量子机器学习和量子化学的跨平台 Python 库。由研究人员构建,用于研究
python·量子计算
扑克中的黑桃A20 分钟前
Python-素数
python
kymjs张涛21 分钟前
前沿技术周刊 2025-06-09
android·前端·ios
扑克中的黑桃A21 分钟前
Python学习的自我理解和想法(4)
python
前端康师傅25 分钟前
JavaScript 变量详解
前端·javascript
Sun_light26 分钟前
队列:先进先出的线性数据结构及其应用
前端·javascript·算法
Data_Adventure29 分钟前
如何在本地测试自己开发的 npm 包
前端·vue.js·svg
扑克中的黑桃A30 分钟前
Python-打印杨辉三角(进阶版)
python
萌萌哒草头将军40 分钟前
⚓️ Oxlint 1.0 版本发布,比 ESLint 快50 到 100 倍!🚀🚀🚀
前端·javascript·vue.js