一、sync.WaitGroup
作用:协程结束后子协程会被立刻销毁,sync.WaitGroup 可以让协程等待子协程执行完,再执行下一步
常见场景:并行处理,初始化资源,多协程结束
主要接口:
go
var wg sync.WaitGroup
wg.Add(3) // 增加三个等待计数
wg.Done() //减少一个等待计数
wg.Wait() // 阻塞,直到等待计数== 0
多协程结束代码:
go
func wg22() {
var wg sync.WaitGroup
wg.Add(3) // 增加三个等待计数
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(404*time.Millisecond))
go func() error {
defer wg.Done() // 减少等待计数,最好用defer,不然中途panic就出问题了
// ps:扩展下,如果这里不加defer的时候panic了,不会执行wg.Done,分情况 recover的时候会死锁。否则panic会导致所有协程结束
time.Sleep(100 * time.Millisecond)
fmt.Println("hello world")
return nil
}()
go func() error {
fmt.Println("hello world22")
time.Sleep(200 * time.Millisecond)
wg.Done()
return fmt.Errorf("err")
}()
go func() error {
fmt.Println("hello world33")
time.Sleep(300 * time.Millisecond)
select {
case <-ctx.Done(): // 超时后触发
fmt.Println("调用成功")
default:
fmt.Println("调用失败")
return fmt.Errorf("err2")
}
wg.Done()
return nil
}()
wg.Wait() // 等待三个子协程返回
}
注意事项:
1.Add 必须在 Goroutine 外部调用,不然可能会出现wgWait()直接返回的情况
- defer wg.Done() 可以防止panic导致没有执行,引发错误
go
wg.Add(1) //步骤1
go func() error {
defer wg.Done() // 步骤2
}
- WaitGroup 不能被复制:
和 sync.Mutex 一样,sync.WaitGroup 是值类型,但不应被复制。在函数间传递时,必须使用指针。
go
func process(wg *sync.WaitGroup) { // 传指针
defer wg.Done()
}
4.复用与清零:
WaitGroup 在计数器归零后可以被再次使用。但是你必须确保它已经wait返回后,再执行wait
二、sync.errgroup
原理 :在 sync.WaitGroup 的基础上,增加了错误传播和上下文(Context)取消功能。如果任一子任务返回错误,它可以自动取消所有其他正在执行的任务(通过 Context)。
除此之外还有控制并发数量的功能(通过SetLimit)
代码示例:
go
func TestErrGroup() {
g, ctx := errgroup.WithContext(context.Background())
urls := []string{"url1", "url2", "bad-url"}
for _, url := range urls {
url := url
g.Go(func() error {
fmt.Printf("Fetching %s...\n", url)
if url == "bad-ur1l" { //模拟一个错误 返回error
return errors.New("failed to fetch: bad-url")
}
select {
case <-time.After(100 * time.Millisecond): // 模拟工作
fmt.Printf("Success: %s\n", url)
return nil
case <-ctx.Done(): // 子协程报错后,errgroup 会向这个通道发送消息,从而结束子协程
fmt.Printf("Cancelled fetching %s due to error elsewhere\n", url)
return ctx.Err() // 返回取消原因,或者直接返回nil也可
}
})
}
// Wait 会阻塞,直到所有 goroutine 都完成。
// 它会返回第一个非空的错误(如果有的话),如果所有都成功则返回nil。
// 只要有一个错误,就会立刻返回,并给ctx发送消息,结束其他子线程
if err := g.Wait(); err != nil {
fmt.Println("Overall error:", err)
} else {
fmt.Println("All successes!")
}
}
源码解析1: 遇到error的取消原理
go
// 这个函数会生成一个 errgroup 和一个可取消的上下文
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancelCause(ctx)
return &Group{cancel: cancel}, ctx
}
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil { // 任务结束之后,这里会在检查一遍,并再次取消
g.cancel(g.err)
}
return g.err
}
func (g *Group) Go(f func() error) {
...
g.wg.Add(1)
go func() {
defer g.done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err) // 执行该函数,取消上下文
}
})
}
}()
}
// 然后就是业务代码调用
select {
case <-ctx.Done(): // 通道接受消息,结束阻塞
...
对于 取消上下文不太了解的,可以看我之前写的 go资深之路笔记(一) Context第3点
源码解析2: 限制并发协程数
errgroup 的 SetLimit 函数,可以限制其同时并发数量,原理其实就是用有缓存的chan,来控制协程发送给通道数据得数量,从而控制并发数量
go
func (g *Group) SetLimit(n int) {
if n < 0 {
g.sem = nil
return
}
if len(g.sem) != 0 {
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
}
g.sem = make(chan token, n) // 有缓存的chan
}
func (g *Group) done() {
if g.sem != nil {
<-g.sem // 释放消息
}
g.wg.Done()
}
func (g *Group) Go(f func() error) {
if g.sem != nil {
g.sem <- token{} //// 写入消息,如果已经满了,等待。从而实现并发数量控制
}
}
三、 multierror
作用:errgroup只能返回一个错误,这种适用于快速返回错误的情况。但有时候可能需要收集全部错误,这个时候 multierror就派上用场了
代码示例:
go
func TestMultierror() {
var wg sync.WaitGroup
var mu sync.Mutex // 保护merr的并发安全
var merr *multierror.Error
tasks := []string{"task1", "task2", "task3_error", "task4_error"}
for _, task := range tasks {
wg.Add(1)
go func(t string) {
defer wg.Done()
err := doWork(t)
if err != nil {
mu.Lock()
merr = multierror.Append(merr, err) // 增加error
mu.Unlock()
}
}(task)
}
wg.Wait() // 等待所有任务完成
if merr != nil {
merr.ErrorFormat = func(errors []error) string {
// 可以自定义错误输出的格式
return fmt.Sprintf("All errors: %v", errors)
}
fmt.Println(merr.Error())
// 输出: All errors: [error in task3_error error in task4_error]
}
}
func doWork(task string) error {
if task == "task3_error" || task == "task4_error" {
return errors.New("error in " + task)
}
return nil
}
总结:
这个没啥好讲的,只是封装了一个 []error 和一个格式化函数而已。当工具用就行。