Go WaitGroup 源码解析

Go WaitGroup

WaitGroup 是 Go 用于等待一组 goroutine 完成的同步原语。

结构体说明

Go 1.21+ 后 WaitGroup 的核心结构体

go 复制代码
type WaitGroup struct {
	noCopy noCopy

	// Bits (high to low):
	//   bits[0:32]  counter
	//   bits[32]    flag: synctest bubble membership
	//   bits[33:64] wait count
	state atomic.Uint64
	sema  uint32
}

noCopy noCopy

  • 作用 :这是一个「编译期检查工具」,不是运行时字段,目的是禁止 WaitGroup 被值拷贝
  • 底层逻辑noCopy 是 sync 包内部的空结构体(type noCopy struct{}),Go 编译器的 vet 工具会检测到这个字段,若代码中出现 wg2 := wg1 这类值拷贝操作,执行 go vet 时会抛出警告
  • 禁止拷贝原因 : waitgroup 依赖state(原子变量)和sema(信号量)的内存地址实现同步,如果拷贝,新的waitgroup会拥有独立的statesema,原有的同步逻辑会失效,导致程序异常。

state atomic.Uint64
核心状态字段 ,通过 64 位无符号整数的「位分段」,把 3 个逻辑状态打包到一个原子变量中,既节省内存,又能通过原子操作保证并发安全。

位范围 长度 含义
bits[0:32] 32 counter: 未完成的goroutine数目
bits[32] 1 flag:仅用于 Go 内部的 synctest 测试框架
bits[33:64] 31 wait count:等待counter 归0的goroutine数量
  • 原子操作stateatomic.Uint64 类型,意味着对它的读写(比如修改 counter、wait count)都是原子的,无需额外加锁,保证多 goroutine 并发操作时的状态一致性

sema uint32

这是 WaitGroup 依赖的信号量(semaphore) ,用于实现 goroutine 的阻塞和唤醒

  • 底层逻辑 : Go 运行时(runtime)提供了基于信号量的两个核心函数:
    • runtime_Semacquire(&wg.sema):阻塞当前 goroutine,直到信号量可用;
    • runtime_Semrelease(&wg.sema, false, 0):释放信号量,唤醒阻塞的 goroutine
  • 使用场景 :
    • 当调用 wg.Wait() 且 counter > 0 时,当前 goroutine 会通过 runtime_Semacquire 阻塞在 sema 上
    • 当 counter 归 0 时,WaitGroup 会调用 runtime_Semrelease 释放信号量,唤醒所有等待的 goroutine

基本用法

go 复制代码
    func (wg *WaitGroup) Add(delta int)
    func (wg *WaitGroup) Done()
    func (wg *WaitGroup) Wait()
  • 方法说明
    • Add(n int):设置需要等待的 goroutine 数量(必须在 Wait() 前调用)
    • Done():等价于 Add(-1),标记一个 goroutine 完成;
    • Wait():阻塞当前 goroutine,直到所有标记的 goroutine 都调用了 Done()
  • 规则:
    • Add() 必须在 Wait() 前调用,且 Done() 调用次数 = Add() 的正数次数
  • 使用场景
    • 任务编排启动多个 goroutine 执行任务,主 goroutine 需要等待子 goroutine 都完成后才继续执行

使用示例

go 复制代码
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	// 1. 初始化 WaitGroup
	var wg sync.WaitGroup

	// 2. 设置需要等待的 goroutine 数量(这里是 3)
	wg.Add(3)

	// 3. 启动多个 goroutine
	for i := 1; i <= 3; i++ {
		go func(id int) {
			// 4. 延迟调用 Done(确保 goroutine 执行完后标记完成)
			defer wg.Done()

			fmt.Printf("goroutine %d 开始执行\n", id)
			// 模拟耗时操作(比如业务逻辑)
			time.Sleep(time.Second * 1)
			fmt.Printf("goroutine %d 执行完成\n", id)
		}(i)
	}

	fmt.Println("主 goroutine 等待所有工作 goroutine 完成...")
	// 5. 阻塞主 goroutine,直到所有 Done 被调用
	wg.Wait()

	fmt.Println("所有 goroutine 执行完毕,主 goroutine 继续执行")
}

源码解读

ps : 竞态检测 和 synctest 测试框架相关可以忽略

Add

Add 方法的核心是原子修改 counter,并在 counter 归 0 且有等待者时唤醒所有等待的 goroutine

go 复制代码
func (wg *WaitGroup) Add(delta int) {
    // 竞态检测
    // 运行时候指定 -race 就会启用
	if race.Enabled {
		// 辅助检测 WaitGroup 的并发使用是否存在竞态问题
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	bubbled := false
	// synctest 测试框架的 bubble 检测 
	if synctest.IsInBubble() {
		// If Add is called from within a bubble, then all Add calls must be made
		// from the same bubble.
		switch synctest.Associate(wg) {
		case synctest.Unbubbled:
		case synctest.OtherBubble:
			// wg is already associated with a different bubble.
			fatal("sync: WaitGroup.Add called from multiple synctest bubbles")
		case synctest.CurrentBubble:
			bubbled = true
			state := wg.state.Or(waitGroupBubbleFlag)
			if state != 0 && state&waitGroupBubbleFlag == 0 {
				// Add has been called from outside this bubble.
				fatal("sync: WaitGroup.Add called from inside and outside synctest bubble")
			}
		}
	}
	//  原子修改 counter
	state := wg.state.Add(uint64(delta) << 32) 
	//  低 32 位是 counter,这里把 delta 左移 32 位(<<32),刚好对应 counter 的位置
	if state&waitGroupBubbleFlag != 0 && !bubbled {
		// Add has been called from within a synctest bubble (and we aren't in one).
		fatal("sync: WaitGroup.Add called from inside and outside synctest bubble")
	}
	v := int32(state >> 32)  	// counter(未完成数)
	// v: 把 state 右移 32 位,提取出低 32 位的 counter(转成 int32 是因为 counter 可正可负)
	w := uint32(state & 0x7fffffff) // wait count(等待数)
	// 用 0x7fffffff(二进制是 31 个 1)做按位与,提取出 state 中 33-64 位的 wait count(仅保留 31 位,因为第 32 位是 bubble flag)
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(&wg.sema))
	}
	// 核心校验 - counter 不能为负
	// WaitGroup 的核心约束:未完成的 goroutine 计数不能为负
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	// 核心校验 - 禁止 Add 和 Wait 并发调用
	// w != 0 : 等待数不为0,已有 goroutine 调用了 Wait() 并等待
	// delta > 0:当前是新增 counter
	// v == int32(delta):表示这是第一次新增 counter(且和 Wait 并发)
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// v > 0 : counter 不为 0 ,不需要唤醒wait
	// w ==0 : 没有 goroutine 在等待,即使 counter 归 0 也无需处理
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	// 二次校验并发修改
	// 验证 state 没有被其他 goroutine 并发修改(比如 Wait 同时修改了 wait count)
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	// 重置状态,支持 WaitGroup 复用
	// 此时counter 已经是 0,且有等待者wait count 即将被处理(唤醒等待者)
	wg.state.Store(0)
	if bubbled {
		// Adds must not happen concurrently with wait when counter is 0,
		// so we can safely disassociate wg from its current bubble.
		synctest.Disassociate(wg)
	}
	// 唤醒所有等待的 goroutine
	// 循环 w 次(等待的 goroutine 数量),每次释放一个信号量,唤醒一个等待的 goroutine
	// 当 counter 归 0 时,唤醒所有调用 Wait() 的 goroutine。
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

Done

go 复制代码
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait

Wait() 的核心是「自旋检测 + CAS 原子操作 + 信号量阻塞」,无锁设计保证高性能

  • for { ... } 是无锁自旋设计,目的是不断获取最新的 state,处理并发修改场景(比如多个 goroutine 同时调用 Wait()/Add()),避免加锁带来的性能开销
  • CompareAndSwap(state, state+1) 是原子操作,仅当 state 未被并发修改时,才把 wait count 加 1;CAS 失败说明有其他 goroutine 抢修改了 state,需重新自旋
go 复制代码
func (wg *WaitGroup) Wait() {
	// 竞态检测
	if race.Enabled {
		race.Disable()
	}
	// 自旋检测 state 状态,直到 counter 归 0
	for {
		// 原子加载当前 state(64位,包含 counter + wait count + flag)
		state := wg.state.Load()
		// 拆解 state 为核心变量
		v := int32(state >> 32)  // counter(未完成 goroutine 数)
		w := uint32(state & 0x7fffffff) // wait count(等待的 goroutine 数)
		if v == 0 {
			// Counter is 0, no need to wait. counter 为 0 ,无需等待
			// 竞态检测恢复(仅调试用)
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			// synctest 测试框架的 bubble 清理
			if w == 0 && state&waitGroupBubbleFlag != 0 && synctest.IsAssociated(wg) {
				// Adds must not happen concurrently with wait when counter is 0,
				// so we can disassociate wg from its current bubble.
				if wg.state.CompareAndSwap(state, 0) {
					synctest.Disassociate(wg)
				}
			}
			return  // // 直接返回,不阻塞
		}
		// Increment waiters count. CAS 
		// state + 1(注意:+1 只会影响 wait count 部分
		// state 的低 32 位是 counter,33-64 位是 wait count
		// 给 64 位的 state 加 1,等价于给「wait count」加 1(因为 counter 在高 32 位,+1 不会影响)
		if wg.state.CompareAndSwap(state, state+1) {
		    // 竞态检测的同步逻辑(仅调试用)
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(&wg.sema))
			}
			// synctest 测试框架的 bubble 关联检测
			synctestDurable := false
			if state&waitGroupBubbleFlag != 0 && synctest.IsInBubble() {
				if race.Enabled {
					race.Enable()
				}
				if synctest.IsAssociated(wg) {
					// Add was called within the current bubble,
					// so this Wait is durably blocking.
					synctestDurable = true
				}
				if race.Enabled {
					race.Disable()
				}
			}
			// 阻塞当前 goroutine,直到信号量被释放
			runtime_SemacquireWaitGroup(&wg.sema, synctestDurable)
			// 正常流程下,被唤醒的 Wait() 看到的 state 一定是 0
			// state != 0 说明:在 Wait() 被唤醒但还没返回的间隙,有其他 goroutine 调用了 Add() 修改了 state(也就是「复用」了 WaitGroup)
			// WaitGroup 不能在 Wait 未返回时复用
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			// 竞态检测恢复(仅调试用)
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return // // 被唤醒后返回
		}
	}
}


CompareAndSwap (简称 CAS)是原子操作 (不可被中断)

比较并交换」:如果变量的当前值等于预期值,就把它改成新值;否则不修改,返回操作是否成功

注意事项

1. Add () 必须在 Wait () 之前调用(禁止并发)

Add(n) 必须在所有 goroutine 启动前 /Wait() 调用前执行,禁止在 Wait() 执行期间(阻塞时)调用 Add(),

否则会触发 panic:sync: WaitGroup misuse: Add called concurrently with Wait

错误用法:

go 复制代码
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup

	// 错误:Add 放在 goroutine 内部,可能和 Wait 并发
	go func() {
		wg.Add(1) // ❌ 主 goroutine 可能先执行 Wait(),或 Wait 已阻塞时执行 Add
		defer wg.Done()
		fmt.Println("goroutine 执行")
	}()

	wg.Wait() // 直接退出(counter=0)
}

修正

go 复制代码
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		fmt.Println("goroutine 执行")
	}()

	wg.Wait() // 直接退出(counter=0)
}

2.Done () 调用次数必须等于 Add () 的正数次数(禁止 counter 为负)

Done() 是 Add(-1) 的别名,调用次数不能超过 Add() 传入的正数总和

否则会触发 panic:sync: negative WaitGroup counter

正确用法: 预先确定好 WaitGroup 的计数值,然后调用相同次数的 Done 完成相应的任务

3. 禁止拷贝 WaitGroup

WaitGroup 包含 noCopy 字段,值拷贝会导致:新的 WaitGroup 拥有独立的 state 和 sema,原同步逻辑完全失效(go vet 会检测到拷贝并告警)。

原因:

WaitGroup 的同步依赖 state(原子变量)和 sema(信号量)的内存地址,拷贝后新对象的地址不同,原对象的 counter 永远无法归 0。

4. Wait () 未返回时,禁止复用 WaitGroup

必须等 Wait() 完全返回后,才能重新调用 Add() 复用 WaitGroup,否则会触发 panic:sync: WaitGroup is reused before previous Wait has returned

5. 必须保证 Done () 一定会被调用(避免永久阻塞)

goroutine 内部若提前返回、panic 或分支遗漏,导致 Done() 未调用,会使 counter 无法归 0,Wait() 永久阻塞

6. WaitGroup 无需手动初始化(零值可用)

WaitGroup 的零值(var wg sync.WaitGroup)是合法的,无需调用 new(sync.WaitGroup) 或手动初始化,直接使用即可

相关推荐
人间打气筒(Ada)6 小时前
如何基于 Go-kit 开发 Web 应用:从接口层到业务层再到数据层
开发语言·后端·golang
想搞艺术的程序员11 小时前
Go RWMutex 源码分析:一个计数器,如何把“读多写少”做得又快又稳
开发语言·redis·golang
喵了几个咪12 小时前
GoWind Content Hub|风行,开箱即用的企业级前后端一体内容中台
vue.js·golang·react·taro
人间打气筒(Ada)13 小时前
go实战案例:如何基于 Conul 给微服务添加服务注册与发现?
开发语言·微服务·zookeeper·golang·kubernetes·etcd·consul
superantwmhsxx13 小时前
[golang][MAC]Go环境搭建+VsCode配置
vscode·macos·golang
Cocktail_py14 小时前
Windows直接部署crawlab
windows·python·golang
人间打气筒(Ada)15 小时前
go实战案例:如何在 Go-kit 和 Service Meh 中进行服务注册与发现?
开发语言·后端·golang·istio·go-kit
小白的代码日记15 小时前
区块链分叉检测与回扫系统(Go语言)
人工智能·golang·区块链
brucelee18615 小时前
Windows 11 安装 Go(Golang)教程
开发语言·windows·golang