深度讲解Go源码-sync.WaitGroup

🐉大家好,我是gopher_looklook,现任某独角兽企业Go语言工程师,喜欢钻研Go源码,发掘各项技术在大型Go微服务项目中的最佳实践,期待与各位小伙伴多多交流,共同进步!

概念

sync.WaitGroup 是Go语言中用于协调多个goroutine同步的核心工具,用于等待一组goroutine完成它们的任务。

简单示例

go 复制代码
package main

import (
    "fmt"
    "sync"
    "time"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    time.Sleep(1 * time.Second)
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    numWorkers := 3
    wg.Add(numWorkers)

    for i := 0; i < numWorkers; i++ {
       go worker(i, &wg)
    }
    wg.Wait() // 最常见的用法,此时只有一个等待者
    fmt.Println("All workers have finished")
}
  • 输出
  • 分析

可以看到,当调用wg.Add(numWorkers)时,表示我们要执行numWorkers组子goroutine,执行wg.Wait()则会阻塞当前goroutine。只有当全部子goroutine全部执行完成,最后一个子goroutine执行完wg.Done()后,当前被wg.Wait()阻塞的goroutine才能继续往下执行。

背景知识-Go源码中信号量的实现

Go语言提供了两个函数用于实现信号量的控制runtime_Semreleaseruntime_Semacquire

go 复制代码
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)
 
func runtime_Semacquire(s *uint32)
  1. runtime_Semrelease 函数的主要作用是释放信号量。
  • 当一个 goroutine 使用完资源后,会调用 runtime_Semrelease 将信号量的值加 1,表示释放了一个资源。
  1. runtime_Semacquire 函数的主要作用是尝试获取信号量。
  • 如果当前信号量的值大于 0,runtime_Semacquire 会将信号量的值减 1,表示成功获取了一个资源,然后立即返回,调用该函数的 goroutine 可以继续执行后续操作。
  • 如果当前信号量的值为 0,说明没有可用资源,调用 runtime_Semacquire 的 goroutine 会被阻塞。该 goroutine 会被放入等待队列中,直到有其他 goroutine 调用 runtime_Semrelease 释放信号量,才有可能被唤醒继续执行。

sync.WaitGroup正是使用了获取和释放信号量的操作,实现了等待一组子goroutine完成任务,并通知到正在等待中的goroutine的功能。

源码解读

  • go源码版本:1.23.0

  • 源码

go 复制代码
package sync

import (
    "sync/atomic"
    "unsafe"
)

type WaitGroup struct {
    noCopy noCopy

    state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
    sema  uint32
}

func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
       return
    }
    // This goroutine has set counter to 0 when waiters > 0.
    // Now there can't be concurrent mutations of state:
    // - Adds must not happen concurrently with Wait,
    // - Wait does not increment waiters if it sees counter == 0.
    // Still do a cheap sanity check to detect WaitGroup misuse.
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // Reset waiters count to 0.
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

// Done decrements the [WaitGroup] counter by one.
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

// Wait blocks until the [WaitGroup] counter is zero.
func (wg *WaitGroup) Wait() {
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)

       // Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }

          return
       }
    }
}

这里我删减了部分与竞态检测相关的代码。可以看到,sync.WaitGroup的核心功能集中体现在3个方法(Add/Done/Wait)上。 这3个方法互相依赖,相辅相成,需要一起配合使用。

WaitGroup.Add

go 复制代码
func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
       return
    }
    // This goroutine has set counter to 0 when waiters > 0.
    // Now there can't be concurrent mutations of state:
    // - Adds must not happen concurrently with Wait,
    // - Wait does not increment waiters if it sees counter == 0.
    // Still do a cheap sanity check to detect WaitGroup misuse.
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // Reset waiters count to 0.
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

首先是Add 方法。Add 方法会将64位int型的整数 wg.state 拆分成了高32位和低32位,分别用于存储不同含义的值:v 和 w。

  • v:计数器,代表有多少个子goroutine未完成任务。
  • w:等待者数量,代表有多少个正在等待所有任务完成的goroutine的数量。

在我们上述的例子中,有1个主goroutine在等待3个子goroutine完成任务。因此v=3,w=1,但是w的值需要等到调用Wait方法时才设置。

刚开始调用Add 方法时,记录了要等待完成任务的子goroutine数量,存储在wg.state字段的高32位。赋值完成后,v等于3,w等于0(还没赋值)。经过几个异常判断的if条件检验。程序会在以下代码跳出Add函数。

go 复制代码
func (wg *WaitGroup) Add(delta int) {
    ...... 
    if v > 0 || w == 0 {
        return
    }
    ......
}

WiatGroup.Wait

go 复制代码
func (wg *WaitGroup) Wait() {
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)

       // Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }

          return
       }
    }
}

Wait() 函数的执行时间在Done() 之前,因此优先分析Wait函数。Wait函数 写了一个for循环。在for循环里,首先将wg.state 的高32位和低32位拆分开,分别赋值给v和w,分别代表计数器和等待者数量。计数器不需要更新,而等待者数量w每调用一次Wait函数都需要自增1,也就是下面这行代码:

go 复制代码
 wg.state.CompareAndSwap(state, state+1)

之后会调用runtime_Semacquire函 数,确保调用Wait函数的地方都会阻塞当前goroutine的进一步执行。

go 复制代码
runtime_Semacquire(&wg.sema)

WaitGroup.Done

go 复制代码
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

通过以上对Add函数Wait函数 分析,我们知道了在等待所有子goroutine完成任务之前,外层等待的goroutine都会被阻塞,阻塞的原因是由于Wait函数在for循环中尝试获取信号量,但是并没有可用的信号量可以获取。

Done函数 实际调用了Add函数 ,因此只需要分析Add函数的执行过程即可。在最后一个子goroutine执行Done()之前,每次调用Done函数,都会将wg.state 字段的高32位减1,即将计数器减1。并在以下这行代码跳出Add函数。

go 复制代码
func (wg *WaitGroup) Add(delta int) {
    ...... 
    if v > 0 || w == 0 {
        return
    }
    ......
}

最后一个子goroutine调用Done函数时,wg.state字段的32位减到0(计数器归0) ,w>0(存在等待者)。此时会将wg.state重新设置为0,并执行到下面这段代码,用于释放w个信号量,唤醒w个阻塞的等待者。之后退出Add函数 ,也即正常退出Done函数

go 复制代码
func (wg *WaitGroup) Add(delta int) {
    for ; w != 0; w-- {
        runtime_Semrelease(&wg.sema, false, 0)
    }
 
}

处于等待的goroutine(在我们的例子当中是调用了wg.Wait函数的main goroutine)由于runtime_Semrelease释放了信号,for循环尝试获取信号量成功,wg.state也已经被重新设置为0,执行return跳出for循环,程序不再阻塞。

go 复制代码
func (wg *WaitGroup) Wait() {
    for {
        。。。。。。
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          return
       }
    }
}

工作流程

为了方便理解,我们可以利用流程图来演示程序执行的不同时刻,各个goroutine的执行情况。这里以上述的简单示例代码为例。

常见误区

大部分情况下,我们在使用WaitGroup时只会调用一次wg.Wait() 函数。通过上述对源码的分析,我们知道WaitGroup其实是允许有多个goroutine等待操作完成的,例如下面这段代码。

go 复制代码
package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

func main() {
    var wg sync.WaitGroup

    // 设置等待组的计数器为 3,代表有 3 个任务要完成
    wg.Add(3)
    var num atomic.Int32

    // 模拟 3 个子goroutine完成
    for i := 0; i < 3; i++ {
       _i := i
       go func() {
          time.Sleep(2 * time.Second)
          num.Add(1)
          fmt.Printf("sub goroutine %d is working. \n", _i)
          wg.Done()
       }()
    }

    // 创建 10 个等待的 goroutine
    for i := 0; i < 10; i++ {
       go func(id int) {
          // 进入等待状态,直到等待组的计数器变为 0
          wg.Wait()
          fmt.Printf("The waiting goroutine %d has been awakened, get num: %d\n", id, num.Load())
       }(i)
    }

    // 等待一段时间,确保所有输出都能显示
    time.Sleep(4 * time.Second)
    fmt.Println("All goroutines have been awakened.")
}

总结

在本篇文章中,我们通过一段常见的示例代码,演示了sync.WaitGroup的基础用法。之后按照程序执行时间线逐步分析源码,探究了sync.WaitGroup可以等待一组子goroutine执行完成的原因。

以上便是我对WaitGroup源码的分析和总结,如果这篇文章对屏幕前的你有帮助的话,欢迎点赞+关注,你的支持是我创作的最大动力!

相关推荐
Snow_Dragon_L44 分钟前
【MySQL】语言连接
android·数据库·后端·sql·mysql·ubuntu·adb
wn5312 小时前
【Go - 小顶堆/大顶堆】
开发语言·后端·golang
安清h2 小时前
【基于SprintBoot+Mybatis+Mysql】电脑商城项目之修改密码和个人资料
数据库·后端·mysql·spring·mybatis
uhakadotcom3 小时前
Java反序列化漏洞利用进阶:绕过WAF和EDR,实现隐蔽攻击
后端·架构·github
黑兔子3 小时前
Java|导出Excel文件
java·后端
二闹3 小时前
Java抽象工厂模式的面试题目及其答案
java·后端·面试
傲娇的萌3 小时前
mac彻底删除goland
后端
uhakadotcom5 小时前
macOS 内核扩展 Fuzzing 指南:用户空间 + IDA + TinyInst
后端·架构·github
uhakadotcom5 小时前
震惊!Google的AI驱动的OSS-Fuzz工具在开源项目中发现大量漏洞!
后端·架构·github