1.背景
在使用 GORM 做事务管理时,常见的写法是直接调用:
go
db.Transaction(func(tx *gorm.DB) error {
// 事务逻辑
return nil
})
但是,这种方式有两个问题:
- 嵌套事务问题 :如果在一个事务中再次调用
Transaction
,GORM 会新建一个事务,而不是复用已有事务,可能造成事务混乱。 - 调试困难:事务调用链嵌套时,很难知道当前处于第几层事务,调试不方便。
为了解决这些问题,我们封装了一个 支持嵌套事务 + 事务计数器 + 日志开关 的中间件插件。
2. 核心思路
-
使用
context.Context
存储事务信息- 每次开启事务时,将
*gorm.DB
和当前事务层级存入context
。
- 每次开启事务时,将
-
检测已有事务
- 如果
context
中已有事务,就直接复用,而不是新开。
- 如果
-
事务计数器
- 层级
level
从 1 开始,内层每进入一层ExecTx
就加 1。
- 层级
-
日志开关
- 通过
debugLog
控制是否输出事务进入/退出的调试日志。
- 通过
3. 插件实现
go
package transaction
import (
"context"
"fmt"
"gorm.io/gorm"
)
// 存储事务信息
type txContext struct {
tx *gorm.DB
level int
}
type contextTxKey struct{}
type TransactionPlugin struct {
db *gorm.DB
debugLog bool // 日志开关
}
// New 创建事务插件
func New(debug bool) *TransactionPlugin {
return &TransactionPlugin{
debugLog: debug,
}
}
func (p *TransactionPlugin) Name() string {
return "transaction_plugin"
}
func (p *TransactionPlugin) Initialize(db *gorm.DB) error {
p.db = db
return nil
}
// ExecTx 在事务中执行函数(支持嵌套事务 + 事务计数器)
func (p *TransactionPlugin) ExecTx(ctx context.Context, fn func(ctx context.Context) error) error {
// 检查是否已有事务
if txData, ok := ctx.Value(contextTxKey{}).(txContext); ok {
newCtx := context.WithValue(ctx, contextTxKey{}, txContext{
tx: txData.tx,
level: txData.level + 1,
})
p.logf("[事务插件] 进入嵌套事务,层级: %d", txData.level+1)
return fn(newCtx)
}
// 开启新事务
return p.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
newCtx := context.WithValue(ctx, contextTxKey{}, txContext{
tx: tx,
level: 1,
})
p.logf("[事务插件] 开启事务,层级: 1")
err := fn(newCtx)
if err != nil {
p.logf("[事务插件] 事务回滚,层级: 1")
} else {
p.logf("[事务插件] 事务提交,层级: 1")
}
return err
})
}
// GetDB 根据 ctx 获取事务 DB
func GetDB(ctx context.Context, fallback *gorm.DB) *gorm.DB {
if txData, ok := ctx.Value(contextTxKey{}).(txContext); ok {
return txData.tx
}
return fallback.Session(&gorm.Session{})
}
// logf 日志输出(根据 debugLog 控制)
func (p *TransactionPlugin) logf(format string, args ...interface{}) {
if p.debugLog {
fmt.Printf(format+"\n", args...)
}
}
4.使用示例
go
package main
import (
"context"
"fmt"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"your_project/transaction"
)
type User struct {
ID uint
Name string
}
func main() {
dsn := "root:password@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local"
db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})
// 创建事务插件(true=调试模式,false=生产模式)
txPlugin := transaction.New(true)
db.Use(txPlugin)
err := txPlugin.ExecTx(context.Background(), func(ctx context.Context) error {
tx := transaction.GetDB(ctx, db)
if err := tx.Create(&User{Name: "Tom"}).Error; err != nil {
return err
}
// 嵌套事务
return txPlugin.ExecTx(ctx, func(ctx context.Context) error {
tx2 := transaction.GetDB(ctx, db)
if err := tx2.Create(&User{Name: "Jerry"}).Error; err != nil {
return err
}
// 再嵌套一层
return txPlugin.ExecTx(ctx, func(ctx context.Context) error {
tx3 := transaction.GetDB(ctx, db)
return tx3.Create(&User{Name: "Spike"}).Error
})
})
})
if err != nil {
fmt.Println("事务失败:", err)
} else {
fmt.Println("事务成功")
}
}
5.运行效果
开启调试模式(debug=true
)
ini
[事务插件] 开启事务,层级: 1
[事务插件] 进入嵌套事务,层级: 2
[事务插件] 进入嵌套事务,层级: 3
[事务插件] 事务提交,层级: 1
事务成功
关闭调试模式(debug=false
)
事务成功
6.执行流程图
ini
ExecTx(level=1) ──> 开启事务
├── ExecTx(level=2) ──> 复用事务
│ └── ExecTx(level=3) ──> 复用事务
└── 提交 / 回滚(只在 level=1 处理)