背景描述
有一个需求,大概可以描述为:有多个websocket连接,因此消息会并发地发送过来,这些消息中有一个标志可以表明是哪个连接发来的消息,但只有收到消息后才能建立channel或写入已有channel,在收消息前无法预先创建channel
解决过程(可直接阅读最终版)
初版:直接写入
因为对数据量错误预估(以为数据量不大),一开始我是用的mysql直接写入,每次收到ws消息立即处理,可测试中发现因数据量过多且都会操作同一行数据,出现了资源竞争,导致死锁。
第二版:增加锁
在发现出现数据竞争后,我第一反应是增加读写锁。读写锁的代码类似以下示例:
go
package main
import (
"database/sql"
"fmt"
"sync"
_ "github.com/go-sql-driver/mysql"
)
var (
db *sql.DB
mu sync.RWMutex
)
func init() {
var err error
db, err = sql.Open("mysql", "username:password@tcp(localhost:3306)/dbname")
if err != nil {
panic(err)
}
}
func main() {
defer db.Close()
// 读取数据
go readData()
// 写入数据
go writeData()
// 保持主线程运行
select {}
}
func readData() {
for {
mu.RLock()
rows, err := db.Query("SELECT * FROM table_name")
mu.RUnlock()
if err != nil {
fmt.Println("Error reading data:", err)
continue
}
defer rows.Close()
// 处理查询结果
// ...
// 睡眠一段时间,模拟读操作的持续性
// 请注意,这是一个简单示例,实际应用中可能需要更复杂的逻辑
// 或使用定时器进行控制
}
}
func writeData() {
for {
mu.Lock()
_, err := db.Exec("INSERT INTO table_name (column1, column2) VALUES (?, ?)", value1, value2)
mu.Unlock()
if err != nil {
fmt.Println("Error writing data:", err)
continue
}
// 睡眠一段时间,模拟写操作的持续性
// 请注意,这是一个简单示例,实际应用中可能需要更复杂的逻辑
// 或使用定时器进行控制
}
}
但是代码里对数据库的操作非常频繁且混乱,加了读写锁后经常出现请求很慢的情况,考虑其他方案
第三版 使用事务
使用事务代码忽略,最终发现,因为事务过长,导致出现了重复写的问题,考虑其他方案
第四版 map
通过一个二维的map来存储数据,每当数据存满10条就处理,当然毫不意外的,出现了map的竞争。map也是可以用锁的,但是这里是二维的map,加上两层锁之后使得效率极低,而且依旧有概率出现map竞争导致报错
此外,还可以考虑使用redis设置锁,直接set就行了,但是因为环境不支持redis,此方案弃用
最终版 动态channel
出现以上问题的根本原因是消费太快,其实完全可以把每个ws连接的数据都写到各自的channel里,同时设置每个channel都累积10条再消费,当然还需要一个处理机制,如果超过10s也消费一次。
启动"生产者"、"消费者"
在当前环境中,生产者就是每次从ws中读到数据往动态channel中写入,消费者就是不断获取有哪些channel,以及从channel中读数据,在ws写入时的处理逻辑大概可以简化为如下demo:
go
package test
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"net/http"
"sync"
)
// RequestTemplate 请求模板
type RequestTemplate struct {
Op string `json:"op"` // 操作
Id int `json:"id"` // 唯一id标识
Time string `json:"time"` // 时间,用秒级时间戳,字符串包裹
Data *RequestTemplateData `json:"data"` // 请求数据
Code int `json:"code"` // 状态码
}
// RequestTemplateData 请求中data包含的部分,实际这里是很复杂的结构,之前超时/死锁也是因为这里处理逻辑比较复杂,但是这篇博客的演示重点不是这个,因此简略为id和请求ip
type RequestTemplateData struct {
ConnIp string `json:"conn_ip"` // 请求ip
Id int `json:"id"` // 唯一id标识
}
// ConnInfo 具体的连接信息
type ConnInfo struct {
Conn *websocket.Conn `json:"conn"` // websocket连接
Ctx context.Context `json:"ctx"` // 连接上下文
CtxCancel context.CancelFunc `json:"cancel"` // 连接上下文cancel function
Ip string `json:"ip"` // 连接的手机端ip
Id int `json:"id"` // 唯一id标识
}
var AllConns = make(map[string]*ConnInfo) //创建字典集合存储连接信息
// Start 启动
func Start() {
//处理ws的连接
http.HandleFunc("/ws", HandleMsg)
// //监听7001端口号,作为websocket连接的服务
log.Info("Server started on :7001")
log.Fatal(http.ListenAndServe(":7001", nil))
}
// ChannelStorage channel数据
type ChannelStorage struct {
sync.RWMutex
channels map[string]chan *RequestTemplateData
}
var ConnRequestData map[int]*RequestTemplateData
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// HandleMsg 处理ws连接,每来一个新客户端请求就建立一个新连接
func HandleMsg(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil) // 协议升级,这里也可以直连
if err != nil {
log.Error(err)
return
}
//获取连接ip,这里是为了区分每个连接
connIp := conn.RemoteAddr().String()
// 这里是为了后续关闭channel
rootCtx := context.Background()
ctx, cancel := context.WithCancel(rootCtx)
//加入连接
AllConns[connIp] = &ConnInfo{
Conn: conn, // 客户端ws链接对象
Ctx: ctx, // 连接上下文
CtxCancel: cancel, // 取消连接上下文
}
defer func() {
// 如果断开连接,删除数据
if AllConns[connIp] != nil {
AllConns[connIp].CtxCancel()
delete(ConnRequestData, AllConns[connIp].Id)
go SetDoneData(AllConns[connIp].Id, conn) // 这里对结束做处理
}
delete(AllConns, conn.RemoteAddr().String())
err = conn.Close()
if err != nil {
return
}
log.Error("HandleMsg异常,开始defer处理:", err)
if err := recover(); err != nil {
log.Error("websocket连接异常,已断开:", err)
}
}()
log.WithFields(log.Fields{
"connIp": connIp,
}).Info("沙箱已连接")
reqCh := &ChannelStorage{}
go reqCh.ResultConsumer(ctx) // 这里是消费者
//循环读取ws客户端的消息
for {
// 读取消息
_, msg, err := conn.ReadMessage()
if err != nil {
log.WithFields(log.Fields{
"connIp": connIp,
}).WithError(err).Error("读取websocket的消息失败")
if AllConns[connIp] != nil {
delete(ConnRequestData, AllConns[connIp].Id)
go SetDoneData(AllConns[connIp].Id, conn) // 连接断开设置状态为结束
}
// 断开ws连接
conn.Close()
delete(AllConns, conn.RemoteAddr().String())
return
}
//msg []byte转string
msgStr := string(msg)
log.Info("收到消息为:", msgStr)
//反序列化消息为结构体
requestData := RequestTemplate{}
if err := json.Unmarshal(msg, &requestData); err != nil {
conn.WriteJSON(gin.H{"id": "未知", "op": "未知", "error": "cmd通信的请求参数有误,无法json decode"})
log.Error("json_decode cmd命令的请求参数时出错:", err)
continue
}
dataInfo := requestData.Data
// 这里实际上有很多操作,简写为两种
if requestData.Op != "" {
switch requestData.Op {
// 收到报告
case "report":
go reqCh.Produce(dataInfo) // "生产者",发送一条消息
// 已完成
case "done":
go CheckDone(dataInfo, conn) // 做完成的处理
default:
log.Error("未识别的命令:", msgStr)
}
}
}
}
有一个for循环在持续监听ws消息,消费者只启动一次,这里重点就是生产和消费如何实现
"生产者"
"生产者"要做的事就是:
1 每当收到ws消息后,解析,拿到唯一id(这个唯一是指这个连接下的所有上报消息的id都是相同的)
2 判断这个"唯一id"是否已经创建了channel,若创建了则不需要创建,直接写入channel,若未创建则新建channel
以下是生产者的demo:
go
// GetChannel 获取通道
func (cs *ChannelStorage) GetChannel(key string) chan *RequestTemplateData {
cs.RLock()
defer cs.RUnlock()
return cs.channels[key]
}
// CreateChannel 创建通道并存储到 map 中
func (cs *ChannelStorage) CreateChannel(key string) chan *RequestTemplateData {
cs.Lock()
defer cs.Unlock()
if cs.channels == nil {
cs.channels = make(map[string]chan *RequestTemplateData, 800)
}
ch := make(chan *RequestTemplateData, 10)
cs.channels[key] = ch
return ch
}
// Produce 往上报channel中写数据
func (cs *ChannelStorage) Produce(requestData *RequestTemplateData) {
defer func() {
if err := recover(); err != nil {
log.Info("_____________recover CaseResultAdd error________: ", err)
}
}()
// 创建存储通道的结构体实例
chanelKey := strconv.Itoa(requestData.Id)
channel := cs.GetChannel(chanelKey)
if channel == nil {
channel = cs.CreateChannel(chanelKey)
}
// 直接往channel里面塞
if channel != nil {
channel <- requestData
}
}
消费者
消费者由于只启动一次,但后续可能会有新的channel,因此需要增加一个获取所有连接的方法:
消费者demo:
go
func (cs *ChannelStorage) ResultConsumer(ctx context.Context) {
defer func() {
if err := recover(); err != nil {
log.Info("_____________recover CaseResultConsumer error________: ", err)
}
}()
for {
select {
case <-ctx.Done():
log.Info("websocket断开连接,消费者协程退出...")
return
default:
cs.processAllChannels(ctx) // 传入 context.Context
time.Sleep(2 * time.Second) // 控制处理频率
}
}
}
// processAllChannels 获取所有channel
func (cs *ChannelStorage) processAllChannels(ctx context.Context) {
cs.RLock()
defer cs.RUnlock()
var wg sync.WaitGroup // 用于等待所有通道处理完毕
for chName, channel := range cs.channels {
wg.Add(1)
go func(chName string, channel chan *RequestTemplateData) {
defer wg.Done()
cs.processChannel(chName, channel, ctx)
}(chName, channel)
}
wg.Wait() // 等待所有通道处理完毕
}
func (cs *ChannelStorage) processChannel(chName string, channel chan *RequestTemplateData, ctx context.Context) {
const batchSize = 10 // 每次处理的数据量
var messages []*RequestTemplateData
targetMsgOverTime := 10 * time.Second // 超时时间
for {
select {
case caseMsg := <-channel:
messages = append(messages, caseMsg) // 将接收到的消息放入 messages 切片中
if len(messages) == batchSize {
tmpMessages := messages
messages = nil
processMessages(tmpMessages)
}
case <-time.After(targetMsgOverTime):
log.Info("Timeout reached. Processing...")
if len(messages) > 0 {
tmpMessages := messages
messages = nil
log.Info("Processing remaining messages for channel:", chName)
processMessages(tmpMessages)
}
case <-ctx.Done(): // 如果收到上下文取消信号,退出函数
log.Info("______________________error__________cancel______")
return
}
}
}
func processMessages(messages []*RequestTemplateData) {
// 在这里处理消息就是批量的了
}