golang:微服务架构下的日志追踪系统

背景

为了把一个请求中所有的日志串联起来,我在日志中增加一个traceId的信息。《gorm log with traceId 打印带有traceId信息的日志,通过context实现》在这篇文章中,提供了第一个版本的解决方案。这个解决方案主要是解决单体服务架构下的日志串联问题。随着,我们分布式应用的发展,我们面临着需要一个能串联跨服务调用链路的日志系统。因此,我们在第一个版本上做了优化处理。主要是增加了日志的统计维度。

统计维度说明

trace:全局唯一的链路追踪标识,通常是在负载均衡或者网关层生成。

span:当前请求的唯一链路追踪标识,通常是一个服务请求的链路追踪标识。

parent_span:上一级服务请求的唯一链路追踪标识。

caller_service_name:上一级请求服务的名称。在微服务架构下,我们时常会遇到某个错误的传参导致错误,但是,却苦于不知道是哪个服务错误传参,导致增加了bug定位的难度。另外,有了这个维度,我们也可以基于该维度统计各个服务调用频次信息,为我们的服务优化提供了清晰明了的方向。

caller_ip:调用方的ip,由于微服务架构下,一个服务会存在多个实例,有了ip,我们可以更加精准的实现具体pod的定位。

为了更加清楚地知道每一步的执行时间,单个请求的总体耗时,以及网络耗时。增加这些维度重点是要做到心中有数,知道我们服务的边界在哪里,知道我们服务的瓶颈在哪里,为我们的服务优化提供更加清晰明了的方向。

execute_start_time:单步执行开始时间。

execute_duration:单步执行耗时时长。

span_start_time:跨度开始时间,通常是指一个服务接收到请求的时间。

span_duration:跨度耗时时长,通常是指一个服务从接受到请求的时间到发出响应这段总耗时。

network_duration:网络耗时。由于,从上一级服务发生请求到我们服务接收到请求之间有一层网络耗时,因此,我们增加了这个维度。这个统计维度,我们需要在中间件中实现。

解决方案

这个部分我们就直接上代码了。

这个部分是日志中需要用到的常量的定义,主要是日志统计维度key定义。

Go 复制代码
package constant

// 日志上下文中的键
const (
	CONTEXT_KEY_TRACE               = "trace"
	CONTEXT_KEY_SPAN                = "span"
	CONTEXT_KEY_PARENT_SPAN         = "parent_span"
	CONTEXT_KEY_METHOD              = "method"
	CONTEXT_KEY_PATH                = "path"
	CONTEXT_KEY_VERSION             = "version"
	CONTEXT_KEY_PROJECT             = "project"
	CONTEXT_KEY_REQUEST_ID          = "Request-Id"
	CONTEXT_KEY_CALLER_SERVICE_NAME = "caller_service_name"
	CONTEXT_KEY_CALLER_IP           = "caller_ip"
	CONTEXT_KEY_CUSTOMLOGGER        = "customLogger"
	CONTEXT_KEY_EXECUTE_START_TIME  = "execute_start_time"
	CONTEXT_KEY_EXECUTE_DURATION    = "execute_duration" // 单步执行时间
	CONTEXT_KEY_SPAN_START_TIME     = "span_start_time"
	CONTEXT_KEY_SPAN_DURATION       = "span_duration"    // 跨度持续时间
	CONTEXT_KEY_NETWORK_DURATION    = "network_duration" // 网络耗时
)

// 日志http请求头中的键
const (
	HEADER_KEY_TRACE               = "trace"
	HEADER_KEY_SPAN                = "span"
	HEADER_KEY_PARENT_SPAN         = "parent_span"
	HEADER_KEY_PROJECT             = "project"
	HEADER_KEY_CALLER_SERVICE_NAME = "caller_service_name"
	HEADER_KEY_RESPONSE_TIME       = "response_time"      // 响应时间,用于统计网络耗时
	HEADER_KEY_REQUEST_START_TIME  = "request_start_time" // 请求开始时间,用于统计网络耗时,接受前端或者其他服务的请求时间
)

主要是初始化的方法修改。主要是增加了这个方法:

Info方法中增加时长的统计的逻辑。

Go 复制代码
package logger

import (
	"context"
	"errors"
	"fmt"
	"time"

	"{{your module}}/go-core/utils/constant"

	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/utils"
)

// Logger logger for gorm2
type Logger struct {
	log *zap.Logger
	logger.Config
	customFields []func(ctx context.Context) zap.Field
}

// Option logger/recover option
type Option func(l *Logger)

// WithCustomFields optional custom field
func WithCustomFields(fields ...func(ctx context.Context) zap.Field) Option {
	return func(l *Logger) {
		l.customFields = fields
	}
}

// WithConfig optional custom logger.Config
func WithConfig(cfg logger.Config) Option {
	return func(l *Logger) {
		l.Config = cfg
	}
}

// SetGormDBLogger set db logger
func SetGormDBLogger(db *gorm.DB, l logger.Interface) {
	db.Logger = l
}

// New logger form gorm2
func New(zapLogger *zap.Logger, opts ...Option) logger.Interface {
	l := &Logger{
		log: zapLogger,
		Config: logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
			LogLevel:                  logger.Warn,
		},
	}
	for _, opt := range opts {
		opt(l)
	}
	return l
}

// NewDefault new default logger
// 初始化一个默认的 logger
func NewDefault(zapLogger *zap.Logger) logger.Interface {
	return New(zapLogger, WithCustomFields(
		func(ctx context.Context) zap.Field {
			v := ctx.Value("Request-Id")
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String("trace", vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value("method")
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String("method", vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value("path")
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String("path", vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value("version")
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String("version", vv)
			}
			return zap.Skip()
		}),
		WithConfig(logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
			LogLevel:                  logger.Info,
		}))
}

// 用于支持微服务架构下的链路追踪
func NewTracingLogger(zapLogger *zap.Logger) logger.Interface {
	return New(zapLogger, WithCustomFields(
		// trace是链路追踪的唯一标识
		// span是当前请求的唯一标识
		// parent_span是父请求的唯一标识
		func(ctx context.Context) zap.Field {
			v := ctx.Value(constant.CONTEXT_KEY_TRACE)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_TRACE, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value(constant.CONTEXT_KEY_SPAN)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_SPAN, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value(constant.CONTEXT_KEY_PARENT_SPAN)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_PARENT_SPAN, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value(constant.CONTEXT_KEY_METHOD)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_METHOD, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value(constant.CONTEXT_KEY_PATH)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_PATH, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			v := ctx.Value(constant.CONTEXT_KEY_VERSION)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_VERSION, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			// 用于标识调用方服务名
			v := ctx.Value(constant.CONTEXT_KEY_CALLER_SERVICE_NAME)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_CALLER_SERVICE_NAME, vv)
			}
			return zap.Skip()
		}, func(ctx context.Context) zap.Field {
			// 用于标识调用方ip
			v := ctx.Value(constant.CONTEXT_KEY_CALLER_IP)
			if v == nil {
				return zap.Skip()
			}
			if vv, ok := v.(string); ok {
				return zap.String(constant.CONTEXT_KEY_CALLER_IP, vv)
			}
			return zap.Skip()
		}),
		WithConfig(logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
			LogLevel:                  logger.Info,
		}))
}

// LogMode log mode
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {
	newLogger := *l
	newLogger.LogLevel = level
	return &newLogger
}

// Info print info
func (l Logger) Info(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel >= logger.Info {
		//预留10个字段位置
		fields := make([]zap.Field, 0, 10+len(l.customFields))
		fields = append(fields, zap.String("file", utils.FileWithLineNum()))
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		now := time.Now().UnixMilli()
		// 从ctx中获取操作的开始时间
		if v := ctx.Value(constant.CONTEXT_KEY_EXECUTE_START_TIME); v != nil {
			if vv, ok := v.(int64); ok {
				// 计算操作的执行时间,以毫秒为单位
				duration := now - vv
				// 将操作的执行时间放入ctx
				fields = append(fields, zap.Int64(constant.CONTEXT_KEY_EXECUTE_DURATION, duration))
			}
		}
		for _, arg := range args {
			if vv, ok := arg.(zapcore.Field); ok {
				if len(vv.String) > 0 {
					fields = append(fields, zap.String(vv.Key, vv.String))
				} else if vv.Integer > 0 {
					fields = append(fields, zap.Int64(vv.Key, vv.Integer))
				} else {
					fields = append(fields, zap.Any(vv.Key, vv.Interface))
				}
			}
		}
		l.log.Info(msg, fields...)
	}
}

// Warn print warn messages
func (l Logger) Warn(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel >= logger.Warn {
		//预留10个字段位置
		fields := make([]zap.Field, 0, 10+len(l.customFields))
		fields = append(fields, zap.String("file", utils.FileWithLineNum()))
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		for _, arg := range args {
			if vv, ok := arg.(zapcore.Field); ok {
				if len(vv.String) > 0 {
					fields = append(fields, zap.String(vv.Key, vv.String))
				} else if vv.Integer > 0 {
					fields = append(fields, zap.Int64(vv.Key, vv.Integer))
				} else {
					fields = append(fields, zap.Any(vv.Key, vv.Interface))
				}
			}
		}
		l.log.Warn(msg, fields...)
	}
}

// Error print error messages
func (l Logger) Error(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel >= logger.Error {
		//预留10个字段位置
		fields := make([]zap.Field, 0, 10+len(l.customFields))
		fields = append(fields, zap.String("file", utils.FileWithLineNum()))
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}

		for _, arg := range args {
			if vv, ok := arg.(zapcore.Field); ok {
				if len(vv.String) > 0 {
					fields = append(fields, zap.String(vv.Key, vv.String))
				} else if vv.Integer > 0 {
					fields = append(fields, zap.Int64(vv.Key, vv.Integer))
				} else {
					fields = append(fields, zap.Any(vv.Key, vv.Interface))
				}
			}
		}
		l.log.Error(msg, fields...)
	}
}

// Trace print sql message
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= logger.Silent {
		return
	}
	fields := make([]zap.Field, 0, 6+len(l.customFields))
	elapsed := time.Since(begin)
	switch {
	case err != nil && l.LogLevel >= logger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		fields = append(fields,
			zap.Error(err),
			zap.String("file", utils.FileWithLineNum()),
			zap.Duration("latency", elapsed),
		)

		sql, rows := fc()
		if rows == -1 {
			fields = append(fields, zap.String("rows", "-"))
		} else {
			fields = append(fields, zap.Int64("rows", rows))
		}
		fields = append(fields, zap.String("sql", sql))
		l.log.Error("", fields...)
	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		fields = append(fields,
			zap.Error(err),
			zap.String("file", utils.FileWithLineNum()),
			zap.String("slow!!!", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)),
			zap.Duration("latency", elapsed),
		)

		sql, rows := fc()
		if rows == -1 {
			fields = append(fields, zap.String("rows", "-"))
		} else {
			fields = append(fields, zap.Int64("rows", rows))
		}
		fields = append(fields, zap.String("sql", sql))
		l.log.Warn("", fields...)
	case l.LogLevel == logger.Info:
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		fields = append(fields,
			zap.Error(err),
			zap.String("file", utils.FileWithLineNum()),
			zap.Duration("latency", elapsed),
		)

		sql, rows := fc()
		if rows == -1 {
			fields = append(fields, zap.String("rows", "-"))
		} else {
			fields = append(fields, zap.Int64("rows", rows))
		}
		fields = append(fields, zap.String("sql", sql))
		l.log.Info("", fields...)
	}
}

// Immutable custom immutable field
// Deprecated: use Any instead
func Immutable(key string, value interface{}) func(ctx context.Context) zap.Field {
	return Any(key, value)
}

// Any custom immutable any field
func Any(key string, value interface{}) func(ctx context.Context) zap.Field {
	field := zap.Any(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// String custom immutable string field
func String(key string, value string) func(ctx context.Context) zap.Field {
	field := zap.String(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// Int64 custom immutable int64 field
func Int64(key string, value int64) func(ctx context.Context) zap.Field {
	field := zap.Int64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// Uint64 custom immutable uint64 field
func Uint64(key string, value uint64) func(ctx context.Context) zap.Field {
	field := zap.Uint64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// Float64 custom immutable float32 field
func Float64(key string, value float64) func(ctx context.Context) zap.Field {
	field := zap.Float64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

测试用例:

Go 复制代码
package logger

import (
	"context"
	"fmt"
	"testing"
	"time"

	"{{your module}}/go-core/utils/constant"
	"go.uber.org/zap"
)


// 测试微服务模式下的日志记录
func TestTracingLog(t *testing.T) {
	// 创建一个 zap logger 实例
	zapLogger, _ := zap.NewProduction()
	defer zapLogger.Sync() // 确保日志被刷新

	// 创建一个带有自定义字段和配置的 Logger 实例
	customLogger := NewTracingLogger(zapLogger)
	ctx := context.Background()
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_SPAN, "123456")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_PARENT_SPAN, "parent_span_123456")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_TRACE, "trace_id_123456")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_METHOD, "POST")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_PATH, "/api/v1/users")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_VERSION, "v1.0.0")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_CALLER_SERVICE_NAME, "user-service")
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_CALLER_IP, "172.0.0.3")

	// 测试 Info 方法
	customLogger.Info(ctx, "This is an info message")

	// 测试 Warn 方法
	customLogger.Warn(ctx, "This is a warning message")

	// 测试 Error 方法
	customLogger.Error(ctx, "This is an error message")

	// 测试 Trace 方法,模拟一个慢查询
	slowQueryBegin := time.Now()
	slowQueryFunc := func() (string, int64) {
		return "SELECT * FROM users", 100
	}
	time.Sleep(2 * time.Second) // 模拟一个慢查询
	customLogger.Trace(ctx, slowQueryBegin, slowQueryFunc, nil)

	// 测试 Trace 方法,模拟一个错误查询
	errorQueryBegin := time.Now()
	errorQueryFunc := func() (string, int64) {
		return "SELECT * FROM non_existent_table", 0
	}
	customLogger.Trace(ctx, errorQueryBegin, errorQueryFunc, fmt.Errorf("table not found"))

	// 由于日志是异步的,我们需要在测试结束时等待一段时间以确保所有日志都被输出
	time.Sleep(500 * time.Millisecond)
}

// 测试执行时间记录
func TestExecuteDuration(t *testing.T) {
	// 创建一个 zap logger 实例
	zapLogger, _ := zap.NewProduction()
	defer zapLogger.Sync() // 确保日志被刷新

	// 创建一个带有自定义字段和配置的 Logger 实例
	customLogger := NewDefault(zapLogger)
	ctx := context.Background()
	now := time.Now().UnixMilli()
	ctx = context.WithValue(ctx, constant.CONTEXT_KEY_EXECUTE_START_TIME, now)
	// 模拟一个耗时操作
	time.Sleep(1670 * time.Millisecond)
	// 测试 ExecuteDuration 方法
	customLogger.Info(ctx, "This is an info message")
}

中间件中增加网络耗时时间统计。

通过接口实现配置注入。

Go 复制代码
package ginmiddleware

import (
	"github.com/bwmarrin/snowflake"
	"gorm.io/gorm/logger"
)

type GinMiddlewareParamsProvider interface {
	GetRouterPrefix() string
	GetSnowflakeNode() *snowflake.Node
	GetLogger() logger.Interface
	GetVersion() string
}
Go 复制代码
package ginmiddleware

import (
	"bytes"
	"io"
	"strconv"
	"strings"
	"time"

	"{{your module}}/go-core/utils/constant"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

func RequestLog(provider GinMiddlewareParamsProvider) gin.HandlerFunc {
	return func(c *gin.Context) {
		if c.Request.Body != nil {
			bodyBytes, _ := io.ReadAll(c.Request.Body)
			defer c.Request.Body.Close()
			if len(bodyBytes) > 0 {
				bodyStr := string(bodyBytes)
				//去除空格、去除换行等转义字符
				bodyStr = strings.ReplaceAll(bodyStr, " ", "")
				bodyStr = strings.ReplaceAll(bodyStr, "\n", "")
				bodyStr = strings.ReplaceAll(bodyStr, "\t", "")
				bodyStr = strings.ReplaceAll(bodyStr, "\r", "")
				// 获取当前时间戳,毫秒时间戳
				now := time.Now().UnixMilli()
				duration := int64(0)
				// 请求开始时间,用于计算网络耗时
				requestStartTime := c.GetHeader(constant.HEADER_KEY_REQUEST_START_TIME)
				if requestStartTime != "" {
					startTime, _ := strconv.ParseInt(requestStartTime, 10, 64)
					duration = now - startTime
				}
				provider.GetLogger().Info(c.Request.Context(), "", zap.String("request", bodyStr), zap.Int64(constant.CONTEXT_KEY_NETWORK_DURATION, duration))
			}
			// 创建新的reader,使用bytes.NewReader
			// 恢复r.Body,以便可以多次读取
			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}
		c.Next()
	}
}

测试用例

Go 复制代码
package ginmiddleware

import (
	"bytes"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"{{your module}}/go-core/utils/constant"
	customlogger "{{your module}}/go-core/utils/logger"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

func TestRequestLog(t *testing.T) {
	// 创建一个 zap logger 实例
	zapLogger, _ := zap.NewProduction()
	defer zapLogger.Sync() // 确保日志被刷新
	// 创建一个带有自定义字段和配置的 Logger 实例
	gloablLogger = customlogger.NewTracingLogger(zapLogger)
	// 初始化 gin 引擎
	r := gin.New()

	// 创建 MockProvider 实例
	mockProvider := &MockProvider{}
	// 使用 RequestLog 中间件
	r.Use(RequestLog(mockProvider))

	// 创建一个简单的路由
	r.POST("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "新年快乐呀!")
	})
	requestStartTime := time.Now().UnixMilli()
	time.Sleep(1 * time.Second)
	// 模拟一个 HTTP 请求
	w := httptest.NewRecorder()
	reqBody := `{"test": "Happy New Year!"}`
	req, _ := http.NewRequest(http.MethodPost, "/test", bytes.NewBufferString(reqBody))
	// 设置请求头
	req.Header.Set(constant.HEADER_KEY_REQUEST_START_TIME, fmt.Sprint(requestStartTime))
	// 执行请求
	r.ServeHTTP(w, req)
}

日志相关的上下文在中间件中实现。

Go 复制代码
package ginmiddleware

import (
	"context"
	"time"

	"{{your module}}/go-core/utils/constant"
	"github.com/gin-gonic/gin"
)

func CommTrace(provider GinMiddlewareParamsProvider) gin.HandlerFunc {
	return func(c *gin.Context) {
		// 微服务场景下,我们需要一个全局唯一的 traceId,通常情况下在网关层生成或者LB层生成
		traceId := c.GetHeader(constant.HEADER_KEY_TRACE)
		//这个最好加个空格 方便分词检索
		//我们现在没有用es 而是使用阿里sls
		//它需要分词检索
		snowflakeNode := provider.GetSnowflakeNode()
		if snowflakeNode == nil {
			panic("snowflake node is nil")
		}
		// span是当前请求的唯一标识
		spanId := provider.GetRouterPrefix() + " " + snowflakeNode.Generate().String()
		//将requestId放入上下文 //兼容之前单体版本的日志
		c.Set("requestId", spanId)
		// parent_span是父请求的唯一标识
		parentSpanId := c.GetHeader(constant.HEADER_KEY_PARENT_SPAN)
		//将sapn_id放入header
		c.Header(constant.HEADER_KEY_SPAN, spanId)
		//将log对象放入上下文
		c.Set(constant.CONTEXT_KEY_CUSTOMLOGGER, provider.GetLogger())
		// 从当前请求的上下文中获取旧的上下文
		oldCtx := c.Request.Context()

		// 获取当前时间戳,毫秒时间戳
		now := time.Now().UnixMilli()
		// 使用context.WithValue来创建一个新的上下文,并设置值
		newCtx := context.WithValue(oldCtx, constant.CONTEXT_KEY_TRACE, traceId)
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_SPAN, spanId)
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_REQUEST_ID, spanId) //兼容之前单体版本的日志
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_PARENT_SPAN, parentSpanId)
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_METHOD, c.Request.Method)
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_PATH, c.Request.URL.Path)
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_VERSION, provider.GetVersion())
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_PROJECT, c.GetHeader(constant.HEADER_KEY_PROJECT))
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_CALLER_SERVICE_NAME, c.GetHeader(constant.HEADER_KEY_CALLER_SERVICE_NAME))
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_CALLER_IP, c.ClientIP())
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_EXECUTE_START_TIME, now)
		newCtx = context.WithValue(newCtx, constant.CONTEXT_KEY_SPAN_START_TIME, now)
		// 更新请求的上下文
		c.Request = c.Request.WithContext(newCtx)

		c.Next()
	}
}

测试用例

Go 复制代码
package ginmiddleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"{{your module}}/go-core/utils/constant"
	customlogger "{{your module}}/go-core/utils/logger"
	"github.com/bwmarrin/snowflake"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"gorm.io/gorm/logger"
)

type MockProvider struct{}

var gloablLogger logger.Interface

func (m *MockProvider) GetRouterPrefix() string {
	return "mock-router-prefix"
}
func (m *MockProvider) GetSnowflakeNode() *snowflake.Node {
	node, _ := snowflake.NewNode(1)
	return node
}
func (m *MockProvider) GetLogger() logger.Interface {
	return gloablLogger
}
func (m *MockProvider) GetVersion() string {
	return "mock-version"
}

func TestCommTrace(t *testing.T) {
	// 创建一个 zap logger 实例
	zapLogger, _ := zap.NewProduction()
	defer zapLogger.Sync() // 确保日志被刷新
	// 创建一个带有自定义字段和配置的 Logger 实例
	gloablLogger = customlogger.NewTracingLogger(zapLogger)
	// 初始化 gin 引擎
	r := gin.New()

	// 创建 MockProvider 实例
	mockProvider := &MockProvider{}
	// 使用 CommTrace 中间件
	r.Use(CommTrace(mockProvider))

	// 创建一个简单的路由
	r.GET("/test", func(c *gin.Context) {
		gloablLogger.Info(c.Request.Context(), "Hello, World!")
		c.String(http.StatusOK, "Hello, World!")
	})

	// 模拟一个 HTTP 请求
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/test", nil)

	// 设置请求头
	req.Header.Set(constant.HEADER_KEY_TRACE, "testTraceId")
	req.Header.Set(constant.HEADER_KEY_PARENT_SPAN, "testParentSpanId")
	req.Header.Set(constant.HEADER_KEY_PROJECT, "testProject")
	req.Header.Set(constant.HEADER_KEY_CALLER_SERVICE_NAME, "testService")

	// 执行请求
	r.ServeHTTP(w, req)
}
相关推荐
tatasix9 小时前
Redis 实现分布式锁
数据库·redis·分布式
杰克逊的日记12 小时前
Spark的原理以及使用
大数据·分布式·spark
编程、小哥哥17 小时前
Spring Boot项目中分布式锁实现方案:Redisson
spring boot·分布式·后端
飞火流星0202717 小时前
Kraft模式安装Kafka(含常规、容器两种安装方式)
分布式·容器·kafka·k8s·kraft模式
西瓜味儿的小志17 小时前
Kafka为什么快(高性能的原因)
分布式·中间件·kafka
默辨18 小时前
浅谈分布式共识算法
分布式·区块链·共识算法
kikyo哎哟喂1 天前
分布式锁常见实现方案总结
分布式
武子康1 天前
大数据-266 实时数仓 - Canal 对接 Kafka 客户端测试
java·大数据·数据仓库·分布式·kafka