背景
为了把一个请求中所有的日志串联起来,我在日志中增加一个traceId的信息。《gorm log with traceId 打印带有traceId信息的日志,通过context实现》在这篇文章中,提供了第一个版本的解决方案。这个解决方案主要是解决单体服务架构下的日志串联问题。随着,我们分布式应用的发展,我们面临着需要一个能串联跨服务调用链路的日志系统。因此,我们在第一个版本上做了优化处理。主要是增加了日志的统计维度。
统计维度说明
trace:全局唯一的链路追踪标识,通常是在负载均衡或者网关层生成。
span:当前请求的唯一链路追踪标识,通常是一个服务请求的链路追踪标识。
parent_span:上一级服务请求的唯一链路追踪标识。
caller_service_name:上一级请求服务的名称。在微服务架构下,我们时常会遇到某个错误的传参导致错误,但是,却苦于不知道是哪个服务错误传参,导致增加了bug定位的难度。另外,有了这个维度,我们也可以基于该维度统计各个服务调用频次信息,为我们的服务优化提供了清晰明了的方向。
caller_ip:调用方的ip,由于微服务架构下,一个服务会存在多个实例,有了ip,我们可以更加精准的实现具体pod的定位。
为了更加清楚地知道每一步的执行时间,单个请求的总体耗时,以及网络耗时。增加这些维度重点是要做到心中有数,知道我们服务的边界在哪里,知道我们服务的瓶颈在哪里,为我们的服务优化提供更加清晰明了的方向。
execute_start_time:单步执行开始时间。
execute_duration:单步执行耗时时长。
span_start_time:跨度开始时间,通常是指一个服务接收到请求的时间。
span_duration:跨度耗时时长,通常是指一个服务从接受到请求的时间到发出响应这段总耗时。
network_duration:网络耗时。由于,从上一级服务发生请求到我们服务接收到请求之间有一层网络耗时,因此,我们增加了这个维度。这个统计维度,我们需要在中间件中实现。
解决方案
这个部分我们就直接上代码了。
这个部分是日志中需要用到的常量的定义,主要是日志统计维度key定义。
Go
package constant
// 日志上下文中的键
const (
CONTEXT_KEY_TRACE = "trace"
CONTEXT_KEY_SPAN = "span"
CONTEXT_KEY_PARENT_SPAN = "parent_span"
CONTEXT_KEY_METHOD = "method"
CONTEXT_KEY_PATH = "path"
CONTEXT_KEY_VERSION = "version"
CONTEXT_KEY_PROJECT = "project"
CONTEXT_KEY_REQUEST_ID = "Request-Id"
CONTEXT_KEY_CALLER_SERVICE_NAME = "caller_service_name"
CONTEXT_KEY_CALLER_IP = "caller_ip"
CONTEXT_KEY_CUSTOMLOGGER = "customLogger"
CONTEXT_KEY_EXECUTE_START_TIME = "execute_start_time"
CONTEXT_KEY_EXECUTE_DURATION = "execute_duration" // 单步执行时间
CONTEXT_KEY_SPAN_START_TIME = "span_start_time"
CONTEXT_KEY_SPAN_DURATION = "span_duration" // 跨度持续时间
CONTEXT_KEY_NETWORK_DURATION = "network_duration" // 网络耗时
)
// 日志http请求头中的键
const (
HEADER_KEY_TRACE = "trace"
HEADER_KEY_SPAN = "span"
HEADER_KEY_PARENT_SPAN = "parent_span"
HEADER_KEY_PROJECT = "project"
HEADER_KEY_CALLER_SERVICE_NAME = "caller_service_name"
HEADER_KEY_RESPONSE_TIME = "response_time" // 响应时间,用于统计网络耗时
HEADER_KEY_REQUEST_START_TIME = "request_start_time" // 请求开始时间,用于统计网络耗时,接受前端或者其他服务的请求时间
)
主要是初始化的方法修改。主要是增加了这个方法:
Info方法中增加时长的统计的逻辑。
Go
package logger
import (
"context"
"errors"
"fmt"
"time"
"{{your module}}/go-core/utils/constant"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"
)
// Logger logger for gorm2
type Logger struct {
log *zap.Logger
logger.Config
customFields []func(ctx context.Context) zap.Field
}
// Option logger/recover option
type Option func(l *Logger)
// WithCustomFields optional custom field
func WithCustomFields(fields ...func(ctx context.Context) zap.Field) Option {
return func(l *Logger) {
l.customFields = fields
}
}
// WithConfig optional custom logger.Config
func WithConfig(cfg logger.Config) Option {
return func(l *Logger) {
l.Config = cfg
}
}
// SetGormDBLogger set db logger
func SetGormDBLogger(db *gorm.DB, l logger.Interface) {
db.Logger = l
}
// New logger form gorm2
func New(zapLogger *zap.Logger, opts ...Option) logger.Interface {
l := &Logger{
log: zapLogger,
Config: logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: false,
LogLevel: logger.Warn,
},
}
for _, opt := range opts {
opt(l)
}
return l
}
// NewDefault new default logger
// 初始化一个默认的 logger
func NewDefault(zapLogger *zap.Logger) logger.Interface {
return New(zapLogger, WithCustomFields(
func(ctx context.Context) zap.Field {
v := ctx.Value("Request-Id")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("trace", vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value("method")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("method", vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value("path")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("path", vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value("version")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("version", vv)
}
return zap.Skip()
}),
WithConfig(logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: false,
LogLevel: logger.Info,
}))
}
// 用于支持微服务架构下的链路追踪
func NewTracingLogger(zapLogger *zap.Logger) logger.Interface {
return New(zapLogger, WithCustomFields(
// trace是链路追踪的唯一标识
// span是当前请求的唯一标识
// parent_span是父请求的唯一标识
func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_TRACE)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_TRACE, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_SPAN)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_SPAN, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_PARENT_SPAN)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_PARENT_SPAN, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_METHOD)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_METHOD, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_PATH)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_PATH, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_VERSION)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_VERSION, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
// 用于标识调用方服务名
v := ctx.Value(constant.CONTEXT_KEY_CALLER_SERVICE_NAME)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_CALLER_SERVICE_NAME, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
// 用于标识调用方ip
v := ctx.Value(constant.CONTEXT_KEY_CALLER_IP)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_CALLER_IP, vv)
}
return zap.Skip()
}),
WithConfig(logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: false,
LogLevel: logger.Info,
}))
}
// LogMode log mode
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {
newLogger := *l
newLogger.LogLevel = level
return &newLogger
}
// Info print info
func (l Logger) Info(ctx context.Context, msg string, args ...interface{}) {
if l.LogLevel >= logger.Info {
//预留10个字段位置
fields := make([]zap.Field, 0, 10+len(l.customFields))
fields = append(fields, zap.String("file", utils.FileWithLineNum()))
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
now := time.Now().UnixMilli()
// 从ctx中获取操作的开始时间
if v := ctx.Value(constant.CONTEXT_KEY_EXECUTE_START_TIME); v != nil {
if vv, ok := v.(int64); ok {
// 计算操作的执行时间,以毫秒为单位
duration := now - vv
// 将操作的执行时间放入ctx
fields = append(fields, zap.Int64(constant.CONTEXT_KEY_EXECUTE_DURATION, duration))
}
}
for _, arg := range args {
if vv, ok := arg.(zapcore.Field); ok {
if len(vv.String) > 0 {
fields = append(fields, zap.String(vv.Key, vv.String))
} else if vv.Integer > 0 {
fields = append(fields, zap.Int64(vv.Key, vv.Integer))
} else {
fields = append(fields, zap.Any(vv.Key, vv.Interface))
}
}
}
l.log.Info(msg, fields...)
}
}
// Warn print warn messages
func (l Logger) Warn(ctx context.Context, msg string, args ...interface{}) {
if l.LogLevel >= logger.Warn {
//预留10个字段位置
fields := make([]zap.Field, 0, 10+len(l.customFields))
fields = append(fields, zap.String("file", utils.FileWithLineNum()))
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
for _, arg := range args {
if vv, ok := arg.(zapcore.Field); ok {
if len(vv.String) > 0 {
fields = append(fields, zap.String(vv.Key, vv.String))
} else if vv.Integer > 0 {
fields = append(fields, zap.Int64(vv.Key, vv.Integer))
} else {
fields = append(fields, zap.Any(vv.Key, vv.Interface))
}
}
}
l.log.Warn(msg, fields...)
}
}
// Error print error messages
func (l Logger) Error(ctx context.Context, msg string, args ...interface{}) {
if l.LogLevel >= logger.Error {
//预留10个字段位置
fields := make([]zap.Field, 0, 10+len(l.customFields))
fields = append(fields, zap.String("file", utils.FileWithLineNum()))
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
for _, arg := range args {
if vv, ok := arg.(zapcore.Field); ok {
if len(vv.String) > 0 {
fields = append(fields, zap.String(vv.Key, vv.String))
} else if vv.Integer > 0 {
fields = append(fields, zap.Int64(vv.Key, vv.Integer))
} else {
fields = append(fields, zap.Any(vv.Key, vv.Interface))
}
}
}
l.log.Error(msg, fields...)
}
}
// Trace print sql message
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= logger.Silent {
return
}
fields := make([]zap.Field, 0, 6+len(l.customFields))
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= logger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
fields = append(fields,
zap.Error(err),
zap.String("file", utils.FileWithLineNum()),
zap.Duration("latency", elapsed),
)
sql, rows := fc()
if rows == -1 {
fields = append(fields, zap.String("rows", "-"))
} else {
fields = append(fields, zap.Int64("rows", rows))
}
fields = append(fields, zap.String("sql", sql))
l.log.Error("", fields...)
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
fields = append(fields,
zap.Error(err),
zap.String("file", utils.FileWithLineNum()),
zap.String("slow!!!", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)),
zap.Duration("latency", elapsed),
)
sql, rows := fc()
if rows == -1 {
fields = append(fields, zap.String("rows", "-"))
} else {
fields = append(fields, zap.Int64("rows", rows))
}
fields = append(fields, zap.String("sql", sql))
l.log.Warn("", fields...)
case l.LogLevel == logger.Info:
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
fields = append(fields,
zap.Error(err),
zap.String("file", utils.FileWithLineNum()),
zap.Duration("latency", elapsed),
)
sql, rows := fc()
if rows == -1 {
fields = append(fields, zap.String("rows", "-"))
} else {
fields = append(fields, zap.Int64("rows", rows))
}
fields = append(fields, zap.String("sql", sql))
l.log.Info("", fields...)
}
}
// Immutable custom immutable field
// Deprecated: use Any instead
func Immutable(key string, value interface{}) func(ctx context.Context) zap.Field {
return Any(key, value)
}
// Any custom immutable any field
func Any(key string, value interface{}) func(ctx context.Context) zap.Field {
field := zap.Any(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// String custom immutable string field
func String(key string, value string) func(ctx context.Context) zap.Field {
field := zap.String(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// Int64 custom immutable int64 field
func Int64(key string, value int64) func(ctx context.Context) zap.Field {
field := zap.Int64(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// Uint64 custom immutable uint64 field
func Uint64(key string, value uint64) func(ctx context.Context) zap.Field {
field := zap.Uint64(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// Float64 custom immutable float32 field
func Float64(key string, value float64) func(ctx context.Context) zap.Field {
field := zap.Float64(key, value)
return func(ctx context.Context) zap.Field { return field }
}
测试用例:
Go
package logger
import (
"context"
"fmt"
"testing"
"time"
"{{your module}}/go-core/utils/constant"
"go.uber.org/zap"
)
// 测试微服务模式下的日志记录
func TestTracingLog(t *testing.T) {
// 创建一个 zap logger 实例
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync() // 确保日志被刷新
// 创建一个带有自定义字段和配置的 Logger 实例
customLogger := NewTracingLogger(zapLogger)
ctx := context.Background()
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_SPAN, "123456")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_PARENT_SPAN, "parent_span_123456")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_TRACE, "trace_id_123456")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_METHOD, "POST")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_PATH, "/api/v1/users")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_VERSION, "v1.0.0")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_CALLER_SERVICE_NAME, "user-service")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_CALLER_IP, "172.0.0.3")
// 测试 Info 方法
customLogger.Info(ctx, "This is an info message")
// 测试 Warn 方法
customLogger.Warn(ctx, "This is a warning message")
// 测试 Error 方法
customLogger.Error(ctx, "This is an error message")
// 测试 Trace 方法,模拟一个慢查询
slowQueryBegin := time.Now()
slowQueryFunc := func() (string, int64) {
return "SELECT * FROM users", 100
}
time.Sleep(2 * time.Second) // 模拟一个慢查询
customLogger.Trace(ctx, slowQueryBegin, slowQueryFunc, nil)
// 测试 Trace 方法,模拟一个错误查询
errorQueryBegin := time.Now()
errorQueryFunc := func() (string, int64) {
return "SELECT * FROM non_existent_table", 0
}
customLogger.Trace(ctx, errorQueryBegin, errorQueryFunc, fmt.Errorf("table not found"))
// 由于日志是异步的,我们需要在测试结束时等待一段时间以确保所有日志都被输出
time.Sleep(500 * time.Millisecond)
}
// 测试执行时间记录
func TestExecuteDuration(t *testing.T) {
// 创建一个 zap logger 实例
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync() // 确保日志被刷新
// 创建一个带有自定义字段和配置的 Logger 实例
customLogger := NewDefault(zapLogger)
ctx := context.Background()
now := time.Now().UnixMilli()
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_EXECUTE_START_TIME, now)
// 模拟一个耗时操作
time.Sleep(1670 * time.Millisecond)
// 测试 ExecuteDuration 方法
customLogger.Info(ctx, "This is an info message")
}
中间件中增加网络耗时时间统计。
通过接口实现配置注入。
Go
package ginmiddleware
import (
"github.com/bwmarrin/snowflake"
"gorm.io/gorm/logger"
)
type GinMiddlewareParamsProvider interface {
GetRouterPrefix() string
GetSnowflakeNode() *snowflake.Node
GetLogger() logger.Interface
GetVersion() string
}
Go
package ginmiddleware
import (
"bytes"
"io"
"strconv"
"strings"
"time"
"{{your module}}/go-core/utils/constant"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func RequestLog(provider GinMiddlewareParamsProvider) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.Body != nil {
bodyBytes, _ := io.ReadAll(c.Request.Body)
defer c.Request.Body.Close()
if len(bodyBytes) > 0 {
bodyStr := string(bodyBytes)
//去除空格、去除换行等转义字符
bodyStr = strings.ReplaceAll(bodyStr, " ", "")
bodyStr = strings.ReplaceAll(bodyStr, "\n", "")
bodyStr = strings.ReplaceAll(bodyStr, "\t", "")
bodyStr = strings.ReplaceAll(bodyStr, "\r", "")
// 获取当前时间戳,毫秒时间戳
now := time.Now().UnixMilli()
duration := int64(0)
// 请求开始时间,用于计算网络耗时
requestStartTime := c.GetHeader(constant.HEADER_KEY_REQUEST_START_TIME)
if requestStartTime != "" {
startTime, _ := strconv.ParseInt(requestStartTime, 10, 64)
duration = now - startTime
}
provider.GetLogger().Info(c.Request.Context(), "", zap.String("request", bodyStr), zap.Int64(constant.CONTEXT_KEY_NETWORK_DURATION, duration))
}
// 创建新的reader,使用bytes.NewReader
// 恢复r.Body,以便可以多次读取
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
c.Next()
}
}
测试用例
Go
package ginmiddleware
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"{{your module}}/go-core/utils/constant"
customlogger "{{your module}}/go-core/utils/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func TestRequestLog(t *testing.T) {
// 创建一个 zap logger 实例
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync() // 确保日志被刷新
// 创建一个带有自定义字段和配置的 Logger 实例
gloablLogger = customlogger.NewTracingLogger(zapLogger)
// 初始化 gin 引擎
r := gin.New()
// 创建 MockProvider 实例
mockProvider := &MockProvider{}
// 使用 RequestLog 中间件
r.Use(RequestLog(mockProvider))
// 创建一个简单的路由
r.POST("/test", func(c *gin.Context) {
c.String(http.StatusOK, "新年快乐呀!")
})
requestStartTime := time.Now().UnixMilli()
time.Sleep(1 * time.Second)
// 模拟一个 HTTP 请求
w := httptest.NewRecorder()
reqBody := `{"test": "Happy New Year!"}`
req, _ := http.NewRequest(http.MethodPost, "/test", bytes.NewBufferString(reqBody))
// 设置请求头
req.Header.Set(constant.HEADER_KEY_REQUEST_START_TIME, fmt.Sprint(requestStartTime))
// 执行请求
r.ServeHTTP(w, req)
}
日志相关的上下文在中间件中实现。
Go
package ginmiddleware
import (
"context"
"time"
"{{your module}}/go-core/utils/constant"
"github.com/gin-gonic/gin"
)
func CommTrace(provider GinMiddlewareParamsProvider) gin.HandlerFunc {
return func(c *gin.Context) {
// 微服务场景下,我们需要一个全局唯一的 traceId,通常情况下在网关层生成或者LB层生成
traceId := c.GetHeader(constant.HEADER_KEY_TRACE)
//这个最好加个空格 方便分词检索
//我们现在没有用es 而是使用阿里sls
//它需要分词检索
snowflakeNode := provider.GetSnowflakeNode()
if snowflakeNode == nil {
panic("snowflake node is nil")
}
// span是当前请求的唯一标识
spanId := provider.GetRouterPrefix() + " " + snowflakeNode.Generate().String()
//将requestId放入上下文 //兼容之前单体版本的日志
c.Set("requestId", spanId)
// parent_span是父请求的唯一标识
parentSpanId := c.GetHeader(constant.HEADER_KEY_PARENT_SPAN)
//将sapn_id放入header
c.Header(constant.HEADER_KEY_SPAN, spanId)
//将log对象放入上下文
c.Set(constant.CONTEXT_KEY_CUSTOMLOGGER, provider.GetLogger())
// 从当前请求的上下文中获取旧的上下文
oldCtx := c.Request.Context()
// 获取当前时间戳,毫秒时间戳
now := time.Now().UnixMilli()
// 使用context.WithValue来创建一个新的上下文,并设置值
newCtx := context.WithValue(oldCtx, constant.CONTEXT_KEY_TRACE, traceId)
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_SPAN, spanId)
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_REQUEST_ID, spanId) //兼容之前单体版本的日志
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_PARENT_SPAN, parentSpanId)
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_METHOD, c.Request.Method)
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_PATH, c.Request.URL.Path)
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_VERSION, provider.GetVersion())
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_PROJECT, c.GetHeader(constant.HEADER_KEY_PROJECT))
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_CALLER_SERVICE_NAME, c.GetHeader(constant.HEADER_KEY_CALLER_SERVICE_NAME))
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_CALLER_IP, c.ClientIP())
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_EXECUTE_START_TIME, now)
newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_SPAN_START_TIME, now)
// 更新请求的上下文
c.Request = c.Request.WithContext(newCtx)
c.Next()
}
}
测试用例
Go
package ginmiddleware
import (
"net/http"
"net/http/httptest"
"testing"
"{{your module}}/go-core/utils/constant"
customlogger "{{your module}}/go-core/utils/logger"
"github.com/bwmarrin/snowflake"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm/logger"
)
type MockProvider struct{}
var gloablLogger logger.Interface
func (m *MockProvider) GetRouterPrefix() string {
return "mock-router-prefix"
}
func (m *MockProvider) GetSnowflakeNode() *snowflake.Node {
node, _ := snowflake.NewNode(1)
return node
}
func (m *MockProvider) GetLogger() logger.Interface {
return gloablLogger
}
func (m *MockProvider) GetVersion() string {
return "mock-version"
}
func TestCommTrace(t *testing.T) {
// 创建一个 zap logger 实例
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync() // 确保日志被刷新
// 创建一个带有自定义字段和配置的 Logger 实例
gloablLogger = customlogger.NewTracingLogger(zapLogger)
// 初始化 gin 引擎
r := gin.New()
// 创建 MockProvider 实例
mockProvider := &MockProvider{}
// 使用 CommTrace 中间件
r.Use(CommTrace(mockProvider))
// 创建一个简单的路由
r.GET("/test", func(c *gin.Context) {
gloablLogger.Info(c.Request.Context(), "Hello, World!")
c.String(http.StatusOK, "Hello, World!")
})
// 模拟一个 HTTP 请求
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
// 设置请求头
req.Header.Set(constant.HEADER_KEY_TRACE, "testTraceId")
req.Header.Set(constant.HEADER_KEY_PARENT_SPAN, "testParentSpanId")
req.Header.Set(constant.HEADER_KEY_PROJECT, "testProject")
req.Header.Set(constant.HEADER_KEY_CALLER_SERVICE_NAME, "testService")
// 执行请求
r.ServeHTTP(w, req)
}