Architecture Design and Implementation Highlights
WebSocket Long-Lived Connection Gateway
The real-time moderation system is built on a Hub-Worker model. A central Hub node manages all client connections and message routing, while a WorkerPool processes moderation tasks asynchronously. The system sustains 500+ concurrent client connections and keeps throughput high through dynamic resource allocation.
Dynamic Batching Algorithm
A multi-dimensional merging strategy groups streaming text by parameters such as client priority, text length, and timeout thresholds. The merging algorithm packs multiple requests into a single LLM API call; the measured merge ratio reaches 10:1, substantially reducing API cost.
Key Technical Solutions
Ordering Consistency Guarantee
Each WebSocket connection maintains a monotonically increasing sequence number (seq) and a pending-response map (pendingResponses). When a batch finishes moderation, results are streamed back strictly in the original sequence order, so clients receive responses in exactly the order they sent the requests.
Resource Optimization Mechanisms
- Connection pooling: GuardianPool reuses moderation client instances, avoiding repeated creation and teardown
- Elastic worker pool: the ShardCount setting controls the number of worker shards and thereby the size of the goroutine pool
- Real-time monitoring: Prometheus metrics expose goroutine count, queue depth, and other key indicators (a minimal sketch follows this list)
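A minimal sketch of the monitoring hook, assuming github.com/prometheus/client_golang is available; the metric names, the Expose helper, and the /metrics endpoint are illustrative assumptions, and queueDepth would typically report len(hub.batchChan).

package metrics

import (
	"net/http"
	"runtime"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// Expose serves goroutine count and batch queue depth as Prometheus gauges on addr.
func Expose(addr string, queueDepth func() int) error {
	reg := prometheus.NewRegistry()
	reg.MustRegister(prometheus.NewGaugeFunc(
		prometheus.GaugeOpts{Name: "ws_gateway_goroutines", Help: "Active goroutines."},
		func() float64 { return float64(runtime.NumGoroutine()) },
	))
	reg.MustRegister(prometheus.NewGaugeFunc(
		prometheus.GaugeOpts{Name: "ws_gateway_batch_queue_depth", Help: "Messages waiting in batchChan."},
		func() float64 { return float64(queueDepth()) },
	))
	http.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
	return http.ListenAndServe(addr, nil)
}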
Production Features
High Availability
- Connection health checks based on WebSocket Ping/Pong
- Graceful shutdown support, so in-flight moderation tasks are not lost
- Eureka service registration integrated with AWS ELB for automatic scale-out and scale-in
Performance Optimization Formula
Merge efficiency model for dynamic batching
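One simple way to state the model, assuming merge efficiency is measured as messages per upstream API call and that a batch closes on whichever limit is reached first:

R = \frac{N_{\text{msg}}}{N_{\text{call}}}, \qquad
\mathbb{E}[R] \approx \min\!\left(\text{MaxClientsPerBatch},\ \frac{\text{MaxBatchDataLen}}{\bar{L}},\ \lambda \cdot \text{BatchTimeout}\right), \qquad
\text{API cost reduction} \approx 1 - \frac{1}{R}

where N_msg is the number of client messages, N_call the number of LLM API calls, \bar{L} the average message length, and \lambda the aggregate message arrival rate. With the measured R of about 10, upstream call volume drops by roughly 90%.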
Deployment and Scaling
Cloud-Native Deployment
- Containerized deployment on an AWS ECS cluster
Horizontal Scaling Design
- Hub nodes support stateless horizontal scaling
package main
import (
//"context""encoding/json" "errors" "fmt" "os" "os/signal" "runtime" "sync/atomic" "syscall" pb "test/proto" "test/servers" "time" tal "/microservices_framework" "/microservices_framework/util" "github.com/astaxie/beego" "github.com/astaxie/beego/logs" "github.com/gorilla/websocket" "go-micro.dev/v4/config" "sync")
const (
	BatchTimeout        = 2000 * time.Millisecond // batch window timeout
	MaxClientsPerBatch  = 20                      // max clients per batch
	MaxBatchDataLen     = 4500                    // max combined text length per batch
	ShardCount          = 500                     // number of worker shards
	MaxGuardianPoolSize = 1000                    // max capacity of the guardian object pool
	ReadTimeout         = 10 * time.Second
	WriteTimeout        = 10 * time.Second
	SendChanBuffer      = 2000 // send channel buffer size
	PingInterval        = 3 * time.Second
)
func monitorGoroutines() {
for {
fmt.Println("Active goroutines:", runtime.NumGoroutine())
time.Sleep(10 * time.Second)
}
}

// WSController handles WebSocket upgrade requests.
type WSController struct {
beego.Controller
}

// Client represents one managed WebSocket connection.
type Client struct {
	conn             *websocket.Conn
	active           int32            // connection state (1 = active, 0 = closed)
	sendChan         chan interface{} // outbound message channel
	seqMu            sync.Mutex
	seq              uint64 // next sequence number to assign
	pendingResponses map[uint64]pendingResponse
	nextExpectedSeq  uint64 // next sequence number to flush to the client
	curPos           uint64 // running character position in the stream
}
type batchGroup struct {
	messages  map[*Client]Message // per-client merged message
	totalLen  int                 // total text length of the group
	timer     *time.Timer         // group flush timer
	isPending int                 // 0: not yet submitted, 1: under moderation
}

type pendingResponse struct {
code pb.Errs
msg string
data string
isEnd bool
requestID string
curPos uint64
}

// Message is a single chunk of streaming text from a client.
type Message struct {
	Data      string  `json:"data"`
	Client    *Client `json:"-"`
	RequestID string  `json:"-"`
	IsEnd     bool    `json:"-"`
}

func (m Message) DeepCopy() Message {
copyMsg := Message{
Data: m.Data,
Client: m.Client,
RequestID: m.RequestID,
IsEnd: m.IsEnd,
	}
	return copyMsg
}
type GuardianRequest = servers.GuardianRequest
type Response struct {
	Code int64  `json:"code"`
	Msg  string `json:"msg"`
	Data struct {
		Result string `json:"result"`
	} `json:"data"`
	RequestID string `json:"requestId"`
}

type WorkerPool struct {
workChan chan func()
	activeTasks *atomic.Int32 // active task count (queued + running)
}

// NewWorkerPool starts size worker goroutines draining a shared work channel.
func NewWorkerPool(size int) *WorkerPool {
pool := &WorkerPool{
workChan: make(chan func(), 10000),
activeTasks: new(atomic.Int32),
}
for i := 0; i < size; i++ {
go pool.worker()
}
return pool
}

func (p *WorkerPool) worker() {
for task := range p.workChan {
task()
p.activeTasks.Add(-1)
}
}

// Submit enqueues a task, returning false (and rolling back the counter) if the work channel is full.
func (p *WorkerPool) Submit(task func()) bool {
	p.activeTasks.Add(1)
	select {
	case p.workChan <- task:
		return true
	default:
		p.activeTasks.Add(-1)
		return false
	}
}
func (p *WorkerPool) ActiveTasks() int {
return int(p.activeTasks.Load())
}

// Hub is the global connection and batching manager.
type Hub struct {
workerPool *WorkerPool
clients map[*Client]bool
batchChan chan Message
register chan *Client
unregister chan *Client
guardianPool *servers.GuardianPool // pool of reusable moderation clients
conf config.Config
mu sync.RWMutex
groups []*batchGroup // all open batch groups
groupProcessChan chan *batchGroup
quit chan struct{}
}

var hub *Hub
func NewHub(conf config.Config) *Hub {
return &Hub{
workerPool: NewWorkerPool(conf.Get("ShardCount").Int(ShardCount)),
clients: make(map[*Client]bool),
batchChan: make(chan Message, 20000),
register: make(chan *Client),
unregister: make(chan *Client, 1000),
guardianPool: servers.NewGuardianPool(conf, conf.Get("MaxGuardianPoolSize").Int(MaxGuardianPoolSize)),
conf: conf,
groups: make([]*batchGroup, 0),
groupProcessChan: make(chan *batchGroup, 2000), // buffered channel for expired groups
quit: make(chan struct{}),
}
}

// Get upgrades the HTTP request to a WebSocket connection and registers the client.
func (c *WSController) Get() {
requestId := util.GetContext(c.Ctx).Get("requestId")
conn, err := websocket.Upgrade(c.Ctx.ResponseWriter, c.Ctx.Request, nil, 0, 0)
if err != nil {
logs.Error("requestId : %s websocket upgrade error: %v", requestId, err)
return
	}
	client := newClient(conn)
	hub.register <- client
	// start the read and write goroutines
	go client.writePump(requestId)
	go client.readPump(requestId)
}
func newClient(conn *websocket.Conn) *Client {
c := &Client{
conn: conn,
active: 1, // atomic flag; 1 means active
sendChan: make(chan interface{}, SendChanBuffer),
pendingResponses: make(map[uint64]pendingResponse),
seq: 1,
nextExpectedSeq: 1,
curPos: 1,
}
return c
}

// shutdown marks the client inactive, resets its state, unregisters it, and closes the connection.
func (c *Client) shutdown() {
	logs.Info("set active false and close sendChan, client=%p", c)
	atomic.StoreInt32(&c.active, 0) // mark as inactive
	close(c.sendChan)
	c.seqMu.Lock()
	c.pendingResponses = nil // release memory
	c.seq = 1
	c.nextExpectedSeq = 1
	c.curPos = 1
	c.seqMu.Unlock()
	select {
	case hub.unregister <- c:
	default:
		logs.Warn("unregister channel full, client %p not unregistered", c)
	}
	c.conn.Close()
}
func (c *Client) cacheResponse(code pb.Errs, msg *Message, seq uint64, curPos uint64) {
if atomic.LoadInt32(&c.active) == 0 {
logs.Info("cacheResponse, client %p disconnect", c)
		return // client already disconnected
	}
	c.seqMu.Lock()
	defer c.seqMu.Unlock()
	logs.Info("cacheResponse client=%p, seq=%d, curPos=%d, msg.IsEnd=%t", c, seq, curPos, msg.IsEnd)
	// cache the response keyed by its sequence number
	c.pendingResponses[seq] = pendingResponse{
		code:      code,
		msg:       util.GetMsg(code.String()),
		requestID: msg.RequestID,
		data:      msg.Data,
		isEnd:     msg.IsEnd,
		curPos:    curPos,
	}
	c.flushPendingResponsesLocked()
}
func (c *Client) flushPendingResponsesLocked() {
	// collect the consecutive responses that are ready to send
var toSend []pendingResponse
for {
logs.Info("flushPendingResponsesLocked client=%p,expectSeq=%d", c, c.nextExpectedSeq)
resp, ok := c.pendingResponses[c.nextExpectedSeq]
if !ok {
break
}
//logs.Info("flushPendingResponsesLocked data=%s", resp.data)
toSend = append(toSend, resp)
delete(c.pendingResponses, c.nextExpectedSeq)
c.nextExpectedSeq++
	}
	// send asynchronously to avoid blocking while holding the lock
	if len(toSend) > 0 {
		go c.sendResponses(toSend)
	}
}
func (c *Client) sendResponses(responses []pendingResponse) {
for _, resp := range responses {
if atomic.LoadInt32(&c.active) == 0 {
logs.Error("sendResponses, client active is 0")
return
		}
		msg := map[string]interface{}{
			"code":      resp.code,
			"msg":       util.GetMsg(resp.code.String()),
			"requestId": resp.requestID,
			"data":      resp.data,
			"end":       resp.isEnd,
			"curPos":    resp.curPos,
		}
		select {
		case c.sendChan <- msg:
		default:
			logs.Error("Send buffer full, message dropped")
		}
	}
}
func (c *Client) ReadMessage(requestId string) ([]byte, error) {
for {
c.conn.SetReadDeadline(time.Now().Add(ReadTimeout))
t, body, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err) {
logs.Info("RequestID:%s client: %p closed connection,err : %v", requestId, c, err)
} else {
logs.Error("RequestID:%s ,client: %p read error: %v", requestId, c, err)
}
return nil, err
		}
		switch t {
		case websocket.TextMessage:
			return body, nil
		case websocket.BinaryMessage:
			logs.Debug("requestId : %s received binary message", requestId)
			return body, nil
		case websocket.PingMessage:
			c.conn.WriteMessage(websocket.PongMessage, []byte("pong"))
		case websocket.CloseMessage:
			logs.Info("RequestID:%s received close frame", requestId)
			return nil, errors.New("received close frame")
		}
	}
}
// readPump reads messages from the client and feeds them into the batching channel.
func (c *Client) readPump(requestId string) {
	defer c.shutdown()
	for {
		msgBytes, err := c.ReadMessage(requestId)
		if err != nil {
			logs.Error("requestId : %s read message : %s", requestId, err.Error())
			return
		}
		var msgObj struct {
			Data string `json:"data"`
			End  bool   `json:"end"`
		}
		if err := json.Unmarshal(msgBytes, &msgObj); err != nil {
			logs.Error("Invalid message format:", err)
			continue
		}
		select {
		case hub.batchChan <- Message{Data: msgObj.Data, Client: c, RequestID: requestId, IsEnd: msgObj.End}:
		default:
			logs.Error("requestId : %s batchChan is full", requestId)
			return
		}
	}
}
// writePump writes queued messages and periodic pings to the client.
func (c *Client) writePump(requestId string) {
	ticker := time.NewTicker(PingInterval)
	defer ticker.Stop()
	if atomic.LoadInt32(&c.active) == 0 {
		logs.Error("requestId:%s client:%p writePump(): has disconnected", requestId, c)
		return
	}
	defer func() {
		atomic.StoreInt32(&c.active, 0) // mark as inactive
	}()
	for {
		select {
		case msg, ok := <-c.sendChan:
			if !ok {
				return
			}
			// write the queued message
			c.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
			if err := c.conn.WriteJSON(msg); err != nil {
				logs.Error("requestId:%s Write error: %v", requestId, err)
				return
			}
			if msgMap, ok := msg.(map[string]interface{}); ok {
				dataStr, _ := msgMap["data"].(string)
				end, _ := msgMap["end"].(bool)
				charCount := len([]rune(dataStr))
				logs.Debug("requestId:%s client:%p, wrote %d chars, isEnd:%t", msgMap["requestId"], c, charCount, end)
			}
		case <-ticker.C:
			// send a ping frame
			c.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				logs.Error("Write ping error: %v", err)
				return
			}
		}
	}
}
func combineStrings(s1, s2 string) string {
n := len(s1) / 2
	return s1[n:] + s2[n:] // concatenate the tails of both strings, each cut at half of s1's length
}

// mergeMessage1 folds msg into group, refusing groups that are already pending moderation.
func mergeMessage1(group *batchGroup, msg Message) bool {
if group.isPending == 1 {
logs.Error("merge new msg into a pending group:%p", group)
return false
}
if _, exists := group.messages[msg.Client]; exists {
oldMsg := group.messages[msg.Client]
		// append the new Data to the client's existing message
oldMsg.Data += msg.Data
newLen := len([]rune(msg.Data))
oldMsg.IsEnd = msg.IsEnd
group.messages[msg.Client] = oldMsg
group.totalLen = group.totalLen + newLen
} else {
group.messages[msg.Client] = msg
group.totalLen = group.totalLen + len([]rune(msg.Data))
}
return true
}

// Stop signals the hub to shut down.
func (h *Hub) Stop() {
	close(h.quit)
}

func (h *Hub) setGroupPending(group *batchGroup) {
	group.isPending = 1
}

// processGroup triggers batch moderation for one group.
func (h *Hub) processGroup(group *batchGroup) {
if len(group.messages) == 0 {
return
	}
	h.setGroupPending(group)
	// convert the client-to-message map into a slice
	batch := make([]Message, 0, len(group.messages))
	for _, msg := range group.messages {
		batch = append(batch, msg)
	}
	currentBatch := make([]Message, len(batch))
	copy(currentBatch, batch)
	logs.Debug("processGroup: groups=%d, group=%p, clients=%d, totalLen=%d", len(h.groups), group, len(group.messages), group.totalLen)
	if !h.workerPool.Submit(func() { h.processBatch(currentBatch) }) {
		h.processBatch(currentBatch)
	}
}
func (h *Hub) delBatchGroup(group *batchGroup) bool {
	// remove the group from the open-group list
for i, grp := range h.groups {
if grp == group {
h.groups = append(h.groups[:i], h.groups[i+1:]...)
return true
}
}
return false
}

// newBatchGroup creates a group whose timer flushes it after BatchTimeout.
func (h *Hub) newBatchGroup() *batchGroup {
group := &batchGroup{
messages: make(map[*Client]Message),
totalLen: 0,
timer: time.NewTimer(BatchTimeout),
isPending: 0,
	}
	h.workerPool.Submit(func() {
		g := group // capture a local copy
		<-g.timer.C
		select {
		case h.groupProcessChan <- g:
			// timer expired; hand the group to the processing channel
			logs.Debug("batch timer fired, group queued for processing")
		default:
			// channel is full
			logs.Info("groupProcessChan is full")
		}
	})
	return group
}
func (h *Hub) delBatchGroupClient(client *Client) {
for _, group := range h.groups {
if _, exists := group.messages[client]; exists {
logs.Debug("delBatchGroupClient,delete client:%p from BatchGroup:%p ", client, group)
msg_len := len([]rune(group.messages[client].Data))
delete(group.messages, client)
if len(group.messages) == 0 {
group.messages = make(map[*Client]Message)
}
if group.totalLen > msg_len {
group.totalLen -= msg_len
}
}
}
}

// getBatchGroup finds an open group that can accept msg, or creates a new one.
func (h *Hub) getBatchGroup(msg Message) *batchGroup {
	// A new msg must go either into the group that already holds this client or into a
	// group that comes later in the slice. Earlier groups have earlier timers, so this
	// guarantees that a client's messages are batch-moderated in arrival order.
start_index := 0
for index, group := range h.groups {
if group.isPending == 1 {
continue
}
if _, exists := group.messages[msg.Client]; exists {
start_index = index
if group.totalLen+len([]rune(msg.Data)) < h.conf.Get("MaxBatchDataLen").Int(MaxBatchDataLen) {
return group
}
}
	}
	for index, group := range h.groups {
		if index < start_index {
			continue
		}
		if group.isPending == 0 &&
			len(group.messages) < h.conf.Get("MaxClientsPerBatch").Int(MaxClientsPerBatch) &&
			group.totalLen+len([]rune(msg.Data)) < h.conf.Get("MaxBatchDataLen").Int(MaxBatchDataLen) {
			return group
		}
	}
	activeGroup := h.newBatchGroup()
	h.groups = append(h.groups, activeGroup)
	return activeGroup
}
// Run is the hub's main event loop: it registers clients, merges messages into groups, and dispatches expired groups.
func (h *Hub) Run() {
for {
select {
case client, ok := <-h.register:
if ok {
h.mu.Lock()
logs.Info("register client %p", client)
h.clients[client] = true
h.mu.Unlock()
}
case client, ok := <-h.unregister:
if ok {
h.mu.Lock()
logs.Info("unregister client %p", client)
h.delBatchGroupClient(client)
delete(h.clients, client)
h.mu.Unlock()
}
case msg, ok := <-h.batchChan:
if !ok {
logs.Debug("batchChan closed,hub will down")
continue
}
group := h.getBatchGroup(msg)
merge_ok := mergeMessage1(group, msg)
if !merge_ok {
group = h.getBatchGroup(msg)
mergeMessage1(group, msg)
}
case group := <-h.groupProcessChan:
logs.Debug("gggggggg")
h.processGroup(group)
h.delBatchGroup(group)
		case <-h.quit:
			logs.Info("Hub exiting...")
			// flush any remaining groups
			for _, group := range h.groups {
				group.timer.Stop()
				h.processGroup(group)
			}
			close(h.register)
			// close all clients
			h.mu.Lock()
			for client := range h.clients {
				if len(client.pendingResponses) > 0 {
					time.Sleep(time.Millisecond * 20)
					continue
				}
				client.shutdown()
			}
			h.clients = make(map[*Client]bool)
			h.mu.Unlock()
			close(h.batchChan)
			close(h.unregister)
			close(h.groupProcessChan)
			// shut down the worker pool
			close(h.workerPool.workChan)
			h.guardianPool.Close()
			return
		}
	}
}
// processBatch moderates a batch of messages with a single merged request, falling back to per-message moderation on failure.
func (h *Hub) processBatch(batch []Message) {
logs.Info("call processBatch batch len:%d", len(batch))localBatch := make([]Message, len(batch)) // 仅分配空间 for i := range batch { localBatch[i] = batch[i].DeepCopy() } responses := make(map[*Client][]struct { seq uint64 curPos uint64 msg Message }) for _, msg := range localBatch { client := msg.Client // 为每个消息分配序列号 client.seqMu.Lock() seq := client.seq client.seq++ curPos := client.curPos client.curPos += uint64(len([]rune(msg.Data))) client.seqMu.Unlock() responses[client] = append(responses[client], struct { seq uint64 curPos uint64 msg Message }{seq, curPos, msg}) } if len(responses) == 1 { for _, msgs := range responses { for _, item := range msgs { taskMsg := item.msg taskSeq := item.seq taskCurPos := item.curPos if !h.workerPool.Submit(func() { h.processSingle(taskMsg, taskSeq, taskCurPos) }) { logs.Warn("工作池满") } } } return } // 合并消息内容 var combined string for _, msg := range localBatch { combined += msg.Data + "\n" } // 截断处理 if runes := []rune(combined); len(runes) > MaxBatchDataLen { combined = string(runes[len(runes)-MaxBatchDataLen:]) } batchGuardian := h.guardianPool.Acquire() defer hub.guardianPool.Release(batchGuardian) anotherIdx := len(localBatch) - 1 combRequestId := combineStrings(localBatch[0].RequestID, localBatch[anotherIdx].RequestID) subRequestId := fmt.Sprintf("%s%d", combRequestId, localBatch[0].Client.seq) // 创建批量审核请求 req := &GuardianRequest{ AccessKeyID: h.conf.Get("appKey").String(""), Data: combined, RequestID: subRequestId, } // 发送批量审核 resp, err := batchGuardian.Request(req) if err == nil && resp.Data.IsNormal { // 批量审核成功,为每个消息缓存结果 for client, msgs := range responses { for _, item := range msgs { client.cacheResponse(pb.Errs_OK, &item.msg, item.seq, item.curPos) } } } else { maxInterval := 400 * time.Millisecond baseInterval := 20 * time.Millisecond interval := baseInterval ticker := time.NewTicker(interval) defer ticker.Stop() for _, msgs := range responses { for _, item := range msgs { <-ticker.C activeTasks := h.workerPool.ActiveTasks() poolCapacity := cap(h.workerPool.workChan) // 工作池总容量 if activeTasks > 30 { newInterval := interval * 2 if newInterval > maxInterval { newInterval = maxInterval } if newInterval != interval { interval = newInterval ticker.Reset(interval) logs.Info("高负载,间隔增加至%v", interval) } } logs.Debug("99999999 activeTasks:%d,poolCapacity:%d", activeTasks, poolCapacity) taskMsg := item.msg taskSeq := item.seq taskCurPos := item.curPos if !h.workerPool.Submit(func() { h.processSingle(taskMsg, taskSeq, taskCurPos) }) { logs.Warn("工作池满,直接处理单条审核") } } } }}
// processSingle moderates one message and caches the result on its client.
func (h *Hub) processSingle(msg Message, seq uint64, curPos uint64) {
singleGuardian := h.guardianPool.Acquire()
	defer h.guardianPool.Release(singleGuardian) // return the client to the pool
	requestId := msg.RequestID
	logs.Info("call processSingle requestId:%s, seq:%d", requestId, seq)
	subRequestId := fmt.Sprintf("%s%d", msg.RequestID, seq)
	req := &GuardianRequest{
		AccessKeyID: h.conf.Get("appKey").String(""),
		Data:        msg.Data,
		RequestID:   subRequestId,
	}
	resp, err := singleGuardian.Request(req)
	if err != nil || resp == nil {
		logs.Error("requestId : %s single audit failed: %v", requestId, err)
		msg.Client.cacheResponse(pb.Errs_INVAL_ALG, &msg, seq, curPos)
		return
	}
	if resp.Code != 0 {
		logs.Error("requestId : %s guardian response code is not ok: %d", requestId, resp.Code)
		msg.Client.cacheResponse(pb.Errs_INVAL_ALG, &msg, seq, curPos)
		return
	}
	if !resp.Data.IsNormal {
		logs.Error("requestId : %s response is not normal data", requestId)
		msg.Client.cacheResponse(pb.Errs_Sensitive_DATA, &msg, seq, curPos)
		return
	}
	msg.Client.cacheResponse(pb.Errs_OK, &msg, seq, curPos)
}
func main() {
server := tal.NewServer("/", "text")
if err := server.Init(); err != nil {
logs.Error("server init failed: %v", err)
return
	}
	logs.Info(server.Config())
	hub = NewHub(server.Config())
	go hub.Run()
	defer hub.Stop()
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGTSTP)
	go func() {
		<-sigChan
		hub.Stop() // shut the hub down when a termination signal arrives
		os.Exit(1)
	}()
	go monitorGoroutines()
	server.Insert("/", &WSController{})
	server.Run()
}
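To illustrate the wire protocol implied by readPump and sendResponses, here is a minimal client sketch; the address ws://localhost:8080/ is an assumption, and the actual host, port, and path come from the framework configuration.

package main

import (
	"fmt"
	"log"

	"github.com/gorilla/websocket"
)

func main() {
	// Hypothetical address; the real one depends on how the gateway is deployed.
	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:8080/", nil)
	if err != nil {
		log.Fatal("dial:", err)
	}
	defer conn.Close()

	// Stream two chunks of text; "end" marks the final chunk.
	chunks := []map[string]interface{}{
		{"data": "first chunk of streaming text", "end": false},
		{"data": "last chunk", "end": true},
	}
	for _, c := range chunks {
		if err := conn.WriteJSON(c); err != nil {
			log.Fatal("write:", err)
		}
	}

	// Responses carry code, msg, requestId, data, end, and curPos,
	// and arrive in the same order the chunks were sent.
	for {
		var resp map[string]interface{}
		if err := conn.ReadJSON(&resp); err != nil {
			log.Fatal("read:", err)
		}
		fmt.Printf("code=%v end=%v curPos=%v\n", resp["code"], resp["end"], resp["curPos"])
		if end, _ := resp["end"].(bool); end {
			return
		}
	}
}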