🐉大家好,我是gopher_looklook,现任某独角兽企业Go语言工程师,喜欢钻研Go源码,发掘各项技术在大型Go微服务项目中的最佳实践,期待与各位小伙伴多多交流,共同进步!
概念
sync.WaitGroup
是Go语言中用于协调多个goroutine同步的核心工具,用于等待一组goroutine完成它们的任务。
简单示例
go
package main
import (
"fmt"
"sync"
"time"
)
func worker(id int, wg *sync.WaitGroup) {
defer wg.Done()
fmt.Printf("Worker %d starting\n", id)
time.Sleep(1 * time.Second)
fmt.Printf("Worker %d done\n", id)
}
func main() {
var wg sync.WaitGroup
numWorkers := 3
wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
go worker(i, &wg)
}
wg.Wait() // 最常见的用法,此时只有一个等待者
fmt.Println("All workers have finished")
}
- 输出
- 分析
可以看到,当调用wg.Add(numWorkers)
时,表示我们要执行numWorkers组子goroutine,执行wg.Wait()
则会阻塞当前goroutine。只有当全部子goroutine全部执行完成,最后一个子goroutine执行完wg.Done()
后,当前被wg.Wait()
阻塞的goroutine才能继续往下执行。
背景知识-Go源码中信号量的实现
Go语言提供了两个函数用于实现信号量的控制runtime_Semrelease
和runtime_Semacquire
。
go
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)
func runtime_Semacquire(s *uint32)
- runtime_Semrelease 函数的主要作用是释放信号量。
- 当一个 goroutine 使用完资源后,会调用 runtime_Semrelease 将信号量的值加 1,表示释放了一个资源。
- runtime_Semacquire 函数的主要作用是尝试获取信号量。
- 如果当前信号量的值大于 0,runtime_Semacquire 会将信号量的值减 1,表示成功获取了一个资源,然后立即返回,调用该函数的 goroutine 可以继续执行后续操作。
- 如果当前信号量的值为 0,说明没有可用资源,调用 runtime_Semacquire 的 goroutine 会被阻塞。该 goroutine 会被放入等待队列中,直到有其他 goroutine 调用 runtime_Semrelease 释放信号量,才有可能被唤醒继续执行。
sync.WaitGroup
正是使用了获取和释放信号量的操作,实现了等待一组子goroutine完成任务,并通知到正在等待中的goroutine的功能。
源码解读
-
go源码版本:1.23.0
-
源码
go
package sync
import (
"sync/atomic"
"unsafe"
)
type WaitGroup struct {
noCopy noCopy
state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}
func (wg *WaitGroup) Add(delta int) {
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}
// Done decrements the [WaitGroup] counter by one.
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
// Wait blocks until the [WaitGroup] counter is zero.
func (wg *WaitGroup) Wait() {
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}
这里我删减了部分与竞态检测相关的代码。可以看到,sync.WaitGroup
的核心功能集中体现在3个方法(Add/Done/Wait)上。 这3个方法互相依赖,相辅相成,需要一起配合使用。
WaitGroup.Add
go
func (wg *WaitGroup) Add(delta int) {
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}
首先是Add 方法。Add 方法会将64位int型的整数 wg.state 拆分成了高32位和低32位,分别用于存储不同含义的值:v 和 w。
- v:计数器,代表有多少个子goroutine未完成任务。
- w:等待者数量,代表有多少个正在等待所有任务完成的goroutine的数量。
在我们上述的例子中,有1个主goroutine在等待3个子goroutine完成任务。因此v=3,w=1,但是w的值需要等到调用Wait方法时才设置。
刚开始调用Add 方法时,记录了要等待完成任务的子goroutine数量,存储在wg.state字段的高32位。赋值完成后,v等于3,w等于0(还没赋值)。经过几个异常判断的if条件检验。程序会在以下代码跳出Add函数。
go
func (wg *WaitGroup) Add(delta int) {
......
if v > 0 || w == 0 {
return
}
......
}
WiatGroup.Wait
go
func (wg *WaitGroup) Wait() {
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}
Wait() 函数的执行时间在Done() 之前,因此优先分析Wait函数。Wait函数 写了一个for循环。在for循环里,首先将wg.state 的高32位和低32位拆分开,分别赋值给v和w,分别代表计数器和等待者数量。计数器不需要更新,而等待者数量w每调用一次Wait函数都需要自增1,也就是下面这行代码:
go
wg.state.CompareAndSwap(state, state+1)
之后会调用runtime_Semacquire函 数,确保调用Wait函数的地方都会阻塞当前goroutine的进一步执行。
go
runtime_Semacquire(&wg.sema)
WaitGroup.Done
go
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
通过以上对Add函数 和Wait函数 分析,我们知道了在等待所有子goroutine完成任务之前,外层等待的goroutine都会被阻塞,阻塞的原因是由于Wait函数在for循环中尝试获取信号量,但是并没有可用的信号量可以获取。
Done函数 实际调用了Add函数 ,因此只需要分析Add函数的执行过程即可。在最后一个子goroutine执行Done()之前,每次调用Done函数,都会将wg.state 字段的高32位减1,即将计数器减1。并在以下这行代码跳出Add函数。
go
func (wg *WaitGroup) Add(delta int) {
......
if v > 0 || w == 0 {
return
}
......
}
最后一个子goroutine调用Done函数时,wg.state字段的32位减到0(计数器归0) ,w>0(存在等待者)。此时会将wg.state重新设置为0,并执行到下面这段代码,用于释放w个信号量,唤醒w个阻塞的等待者。之后退出Add函数 ,也即正常退出Done函数。
go
func (wg *WaitGroup) Add(delta int) {
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}
处于等待的goroutine(在我们的例子当中是调用了wg.Wait函数的main goroutine)由于runtime_Semrelease释放了信号,for循环尝试获取信号量成功,wg.state也已经被重新设置为0,执行return跳出for循环,程序不再阻塞。
go
func (wg *WaitGroup) Wait() {
for {
。。。。。。
if wg.state.CompareAndSwap(state, state+1) {
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}
工作流程
为了方便理解,我们可以利用流程图来演示程序执行的不同时刻,各个goroutine的执行情况。这里以上述的简单示例代码为例。
常见误区
大部分情况下,我们在使用WaitGroup时只会调用一次wg.Wait() 函数。通过上述对源码的分析,我们知道WaitGroup其实是允许有多个goroutine等待操作完成的,例如下面这段代码。
go
package main
import (
"fmt"
"sync"
"sync/atomic"
"time"
)
func main() {
var wg sync.WaitGroup
// 设置等待组的计数器为 3,代表有 3 个任务要完成
wg.Add(3)
var num atomic.Int32
// 模拟 3 个子goroutine完成
for i := 0; i < 3; i++ {
_i := i
go func() {
time.Sleep(2 * time.Second)
num.Add(1)
fmt.Printf("sub goroutine %d is working. \n", _i)
wg.Done()
}()
}
// 创建 10 个等待的 goroutine
for i := 0; i < 10; i++ {
go func(id int) {
// 进入等待状态,直到等待组的计数器变为 0
wg.Wait()
fmt.Printf("The waiting goroutine %d has been awakened, get num: %d\n", id, num.Load())
}(i)
}
// 等待一段时间,确保所有输出都能显示
time.Sleep(4 * time.Second)
fmt.Println("All goroutines have been awakened.")
}
总结
在本篇文章中,我们通过一段常见的示例代码,演示了sync.WaitGroup的基础用法。之后按照程序执行时间线逐步分析源码,探究了sync.WaitGroup可以等待一组子goroutine执行完成的原因。
以上便是我对WaitGroup源码的分析和总结,如果这篇文章对屏幕前的你有帮助的话,欢迎点赞+关注,你的支持是我创作的最大动力!