万字解析gorilla/websocket源码

项目整体结构

gorilla/websocket整体项目概览

bash 复制代码
gorilla/websocket/
├── conn.go               # 核心连接实现(读写消息、状态管理、控制帧处理等)
├── client.go             # 客户端握手和连接逻辑
├── server.go             # 服务端握手和升级逻辑
├── compression.go        # 支持 WebSocket 的数据压缩(permessage-deflate)
├── json.go               # 基于 JSON 的读写辅助方法
├── mask.go               # 客户端 masking 实现
├── mask_safe.go          # 可选的 masking 安全实现(避免数据泄露)
├── prepared.go           # PreparedMessage 缓存机制优化性能
├── proxy.go              # WebSocket 代理支持
├── join.go               # 内部数据拼接优化
├── util.go               # 一些工具函数(如校验等)
├── doc.go                # 包级文档说明(GoDoc 用)
├── README.md             # 项目说明

2. 关于websocket帧的简单说明

lua 复制代码
  0               1               2               3
  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 +-+-+-+-+-------+-+-------------+-------------------------------+
 |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
 |I|S|S|S|  (4)  |A|     (7)     |         (16/64)               |
 |N|V|V|V|       |S|             | (if payload len==126 or 127)  |
 | |1|2|3|       |K|             |                               |
 +-+-+-+-+-------+-+-------------+-------------------------------+
 |    Masking-key (if MASK set) |       Payload Data            |
 +------------------------------+-------------------------------+

FIN: 1个比特

如果是1,表示这是消息(message)的最后一个分片(fragment),如果是0,表示不是是消息(message)的最后一个分片(fragment)。

RSV1, RSV2, RSV3: 各占1个比特

一般情况下全为0。当客户端、服务端协商采用WebSocket扩展时,这三个标志位可以非0,且值的含义由扩展进行定义。如果出现非零的值,且并没有采用WebSocket扩展,连接出错。

Opcode: 4个比特

操作代码,Opcode的值决定了应该如何解析后续的数据载荷(data payload)。如果操作代码是不认识的,那么接收端应该断开连接(fail the connection)。可选的操作代码如下:

%x0:表示一个延续帧。当Opcode为0时,表示本次数据传输采用了数据分片,当前收到的数据帧为其中一个数据分片;

%x1:表示这是一个文本帧(frame);

%x2:表示这是一个二进制帧(frame);

%x3-7:保留的操作代码,用于后续定义的非控制帧;

%x8:表示连接断开;

%x8:表示这是一个ping操作;

%xA:表示这是一个pong操作;

%xB-F:保留的操作代码,用于后续定义的控制帧。

Mask: 1个比特

表示是否要对数据载荷进行掩码操作。从客户端向服务端发送数据时,需要对数据进行掩码操作;从服务端向客户端发送数据时,不需要对数据进行掩码操作。

如果服务端接收到的数据没有进行过掩码操作,服务端需要断开连接。

如果Mask是1,那么在Masking-key中会定义一个掩码键(masking key),并用这个掩码键来对数据载荷进行反掩码。所有客户端发送到服务端的数据帧,Mask都是1。

掩码的算法、用途在下一小节讲解。

Payload length: 数据载荷的长度,单位是字节。为7位,或7+16位,或1+64位

假设数Payload length === x,如果:

x为0~126:数据的长度为x字节;

x为126:后续2个字节代表一个16位的无符号整数,该无符号整数的值为数据的长度;

x为127:后续8个字节代表一个64位的无符号整数(最高位为0),该无符号整数的值为数据的长度。

此外,如果payload length占用了多个字节的话,payload length的二进制表达采用网络序(big endian,重要的位在前)。

Masking-key: 0或4字节(32位)

所有从客户端传送到服务端的数据帧,数据载荷都进行了掩码操作,Mask为1,且携带了4字节的Masking-key。如果Mask为0,则没有Masking-key。

备注:载荷数据的长度,不包括mask key的长度。

Payload data: (x+y) 字节

载荷数据:

包括了扩展数据、应用数据。其中,扩展数据x字节,应用数据y字节;

扩展数据:

如果没有协商使用扩展的话,扩展数据数据为0字节。所有的扩展都必须声明扩展数据的长度,或者可以如何计算出扩展数据的长度。此外,扩展如何使用必须在握手阶段就协商好。如果扩展数据存在,那么载荷数据长度必须将扩展数据的长度包含在内;

应用数据:

任意的应用数据,在扩展数据之后(如果存在扩展数据),占据了数据帧剩余的位置。载荷数据长度 减去 扩展数据长度,就得到应用数据的长度。

关于http升级为websocket一些请求头中的参数说明

客户端请求头示例

makefile 复制代码
GET /ws HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
Origin: https://example.com
Sec-WebSocket-Protocol: chat
  • Upgrade: websocket:告诉服务器端需要将http连接升级为websocket,必须
  • Connection: Upgrade:表明此次连接是为了升级协议而建立的,必须
  • Sec-WebSocket-Key:客户端生成的一个Base64编码的随机字符串(16字节),必须
  • Sec-WebSocket-Version:指定websocket的协议版本,必须
  • Origin:表明请求源,可选
  • Sec-WebSocket-Protocol:作为子协议,客户端期望使用的 WebSocket 协议之上的协议,可选
  • Sec-WebSocket-Extensions:用于协商 WebSocket 的扩展机制,比如说启动压缩扩展,可选

服务器端响应头示例

makefile 复制代码
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
  • HTTP/1.1 101 Switching Protocols:表明服务器端同意升级,必须
  • Upgrade: websocket:通知客户端,服务器已同意将协议升级为 websocket,必须
  • Connection: Upgrade:示当前连接将被升级,必须
  • Sec-WebSocket-Accept: xxxxxxx:与请求头中的Sec-WebSocket-Key相对应,表示通过校验,可以升级协议,必须
  • Sec-WebSocket-Protocol:同上
  • Sec-WebSocket-Extensions:同上

Server

核心数据结构

作为一个websocket库中server端最重要则是处理升级流程,并返回conn实例,下面来看一下gorilla/websocket这个库中server端的核心类中的属性以及方法

类图

重要的属性

  • HandshakeTimeout:指定升级完成的持续时间
  • ReadBufferSize:指定读缓冲区大小
  • WriteBufferSize:指定写缓冲区大小
  • WriteBufferPool:指定写缓冲池
  • Subprotocols:与客户端约定的子协议列表
  • Error:自定义升级过程中出现错误之后的http错误返回
  • CheckOrigin:自定义检查跨域请求
  • EnableCompression:指定服务器端是否协商开启消息压缩

重要的方法

  • Upgrade:处理升级的http升级为websocket的方法

服务端核心升级流程

整个服务端的升级过程大概可以分为几个阶段:

  1. 检查http请求头是否需要升级
  2. 获取一些客户端请求头中携带的信息并做一些基础校验
  3. 劫持http中tcp的网络连接实例
  4. 构造websocket的连接实例(conn)
  5. 构造http升级为websocket的响应消息并返回

核心源码解读

go 复制代码
// http升级为websocket的方法
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {

    // 1.下面的判断逻辑都是在校验该http请求是否需要升级为websocket,检查协议层面预定的内容
    const badHandshake = "websocket: the client is not using the websocket protocol: "
    if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
       return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
    }
    if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
       w.Header().Set("Upgrade", "websocket")
       return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header")
    }
    if r.Method != http.MethodGet {
       return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
    }
    if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
       return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
    }

    // 2.这里校验服务器端响应给客户端的http头部分中不能够包含协议扩展字段,防止协议出错
    if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
       return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
    }

    // 3.校验跨域请求
    checkOrigin := u.CheckOrigin
    // 3.1如果用户没有自定义处理跨域的函数,这里就默认调用库中实现的函数来判断是否跨域
    if checkOrigin == nil {
       checkOrigin = checkSameOrigin// checkSameOrigin函数会判断如果客户端没有设置跨域或者客户端ip与服务器端ip相同则返回true,否则返回false
    }
    if !checkOrigin(r) {
       return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
    }

    // 4.校验http请求头中Sec-WebSocket-Key字段是否符合规范     
    challengeKey := r.Header.Get("Sec-Websocket-Key")
    if !isValidChallengeKey(challengeKey) {
       return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
    }

    // 5.这里是获取websocket子协议的内容
    // 如果Upgrader这个类中设置了Subprotocols这个属性,则会去匹配http请求头中客户端是否设置了Sec-Websocket-Protocol这个值以及是否有匹配的内容,并且只会取第一个匹配到的子协议
    // 如果Upgrader没有设置这个属性则取responseHeader中设置的值
    // 如果都没有则是空字符串
    subprotocol := u.selectSubprotocol(r, responseHeader)

    // 6.如果服务器端设置开启了传输过程中的数据压缩,则会判断客户端是否也支持开启,只有客户端和服务器端同时设置了才能生效
    var compress bool
    if u.EnableCompression {
       for _, ext := range parseExtensions(r.Header) {
          if ext[""] != "permessage-deflate" {
             continue
          }
          compress = true
          break
       }
    }

    // 7.从http层面劫持出底层的tcp网络连接实例和读写缓冲IO(bufio.ReadWriter)
    netConn, brw, err := http.NewResponseController(w).Hijack()
    if err != nil {
       return u.returnError(w, r, http.StatusInternalServerError,
          "websocket: hijack: "+err.Error())
    }

    // 8.这里是判断当整个方法执行结束之后如果存在错误则会关闭此次tcp连接
    defer func() {
       if netConn != nil {
          _ = netConn.Close()
       }
    }()

    // 9.如果服务器端没有设置读缓冲区大小并且bufio.ReadWriter的大小大于256字节,则会复用这个tcp连接本来的bufio.ReadWriter
    var br *bufio.Reader
    if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
       br = brw.Reader
    } else if brw.Reader.Buffered() > 0 {
       // 9.1如果说在升级过程中读缓冲区中的数据不为空(这种应该算是很少见或者异常的场景),则会去实现bufio.Reader接口,这样的处理是为了在升级之后会先去读取这部分缓存的消息,防止被丢失
       netConn = &brNetConn{br: brw.Reader, Conn: netConn}
    }

    // 10.获取现有tcp连接中可用的写缓冲区大小
    buf := brw.Writer.AvailableBuffer()
    // 10.1如果服务器端没有设置写缓冲池和写缓冲区大小以及tcp连接中可用的缓冲区大小足够大,则会复用这个tcp的写缓冲区
    var writeBuf []byte
    if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
       writeBuf = buf
    }

    // 11.构建conn连接对象(conn部分会详细讲解)
    c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
    c.subprotocol = subprotocol
    // 11.1判断是设置启数据压缩
    if compress {
       c.newCompressionWriter = compressNoContextTakeover
       c.newDecompressionReader = decompressNoContextTakeover
    }

    // 12.这里的buf取决于上面tcp的写缓冲区大小和conn的写缓冲区大小,取最大值
    p := buf
    if len(c.writeBuf) > len(p) {
       p = c.writeBuf
    }
    p = p[:0]

    // 13.构建响应头
    p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
    // 13.1服务器端会根据客户端传入的Sec-WebSocket-Key来构造对应的Sec-WebSocket-Accept响应给客户端
    p = append(p, computeAcceptKey(challengeKey)...)
    p = append(p, "\r\n"...)
    //13.2如果设置了子协议也会在响应头中构造对应的信息
    if c.subprotocol != "" {
       p = append(p, "Sec-WebSocket-Protocol: "...)
       p = append(p, c.subprotocol...)
       p = append(p, "\r\n"...)
    }
    //13.3如果开启了数据压缩也会在响应头中构造对应的信息
    if compress {
       p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
    }
    //13.4构造服务器自定义的一些响应头
    for k, vs := range responseHeader {
       if k == "Sec-Websocket-Protocol" {
          continue
       }
       for _, v := range vs {
          p = append(p, k...)
          p = append(p, ": "...)
          for i := 0; i < len(v); i++ {
             b := v[i]
             if b <= 31 {
                // prevent response splitting.
                b = ' '
             }
             p = append(p, b)
          }
          p = append(p, "\r\n"...)
       }
    }
    p = append(p, "\r\n"...)

    // 14.如果自定义了握手的超时时间,则设置自定义的握手超时时间
    if u.HandshakeTimeout > 0 {
       if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
          return nil, err
       }
    } else {
       // 14.1如果没有自定义握手超时时间,则清除http默认设置的超时时间
       if err := netConn.SetDeadline(time.Time{}); err != nil {
          return nil, err
       }
    }

    // 15.将消息写入网络中
    if _, err = netConn.Write(p); err != nil {
       return nil, err
    }
   
    if u.HandshakeTimeout > 0 {
       if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
          return nil, err
       }
    }
    // 16.返回conn连接实例
    netConn = nil
    return c, nil
}

2.流程图

Client

核心数据结构

作为一个gorilla/websocket库中client端构造的结构体,它的作用是 "定义客户端如何建立一个 WebSocket 连接",包括底层 TCP 连接、代理处理、TLS 加密、握手超时、缓冲区配置、子协议选择、压缩协商等所有必要参数。

类图

重要的属性(在server端重复的属性这里就不叙述了)

  • NetDial:自定义TCP连接的 Dial 方法(不带 context)
  • NetDialContext:自定义TCP连接的 Dial 方法(带 context)
  • NetDialTLSContext:自定义TLS+TCP连接的 Dial 方法(带 context)
  • Proxy:返回请求使用的代理 URL
  • TLSClientConfig:TLS 加密相关配置
  • Jar:Cookie

...

重要的方法

  • DialContext:通过http协议与服务端建立websocket握手,并将http升级为websocket,最后返回websocket连接对象

客户端核心升级流程

整个客户端从http升级为websocket过程大概可以分为几个阶段

  1. 解析 URL、生成握手 Key
  2. 构建 HTTP Upgrade 请求
  3. 检查是否走代理
  4. 构建 TCP 连接 (Dial)
  5. (如有) 代理 CONNECT / TLS Handshake
  6. 写 HTTP Upgrade 请求
  7. 等待并解析服务器 Upgrade 响应
  8. 成功后返回 WebSocket 连接对象

核心源码解读

go 复制代码
//todo 省略掉一些无关主流程的代码,包括httptrace对象和一些钩子函数的设置
func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
   
    // 1.生成Sec-WebSocket-Key
    challengeKey, err := generateChallengeKey()
    if err != nil {
       return nil, nil, err
    }

    // 2.解析传入的url,生成封装好的URL对象
    u, err := url.Parse(urlStr)
    if err != nil {
       return nil, nil, err
    }
    switch u.Scheme {
    case "ws":
       u.Scheme = "http"
    case "wss":
       u.Scheme = "https"
    default:
       return nil, nil, errMalformedURL
    }

    // 3.构建http请求基本参数
    req := &http.Request{
       Method:     http.MethodGet,
       URL:        u,
       Proto:      "HTTP/1.1",
       ProtoMajor: 1,
       ProtoMinor: 1,
       Header:     make(http.Header),
       Host:       u.Host,
    }
    req = req.WithContext(ctx)
    if d.Jar != nil {
       for _, cookie := range d.Jar.Cookies(u) {
          req.AddCookie(cookie)
       }
    }

    // 4.构建http升级为websocket的请求头部分信息
    req.Header["Upgrade"] = []string{"websocket"}
    req.Header["Connection"] = []string{"Upgrade"}
    req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
    req.Header["Sec-WebSocket-Version"] = []string{"13"}
    // 4.1如果subprotocl不为空则设置扩展协议
    if len(d.Subprotocols) > 0 {
       req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
    }
    for k, vs := range requestHeader {
       switch {
       case k == "Host":
          if len(vs) > 0 {
             req.Host = vs[0]
          }
       case k == "Upgrade" ||
          k == "Connection" ||
          k == "Sec-Websocket-Key" ||
          k == "Sec-Websocket-Version" ||
          k == "Sec-Websocket-Extensions" ||
          (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
          return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
       case k == "Sec-Websocket-Protocol":
          req.Header["Sec-WebSocket-Protocol"] = vs
       default:
          req.Header[k] = vs
       }
    }
    // 4.2如果开启数据压缩则设置Sec-WebSocket-Extensions的请求头信息
    if d.EnableCompression {
       req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
    }
    // 4.3构建握手超时上下文
    if d.HandshakeTimeout != 0 {
       var cancel func()
       ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
       defer cancel()
    }

    // 5.判断如果设置了代理,则生成代理的URL对象
    var proxyURL *url.URL
    if d.Proxy != nil {
       proxyURL, err = d.Proxy(req)
       if err != nil {
          return nil, nil, err
       }
    }

    // 6.构建拨号器函数
    netDial, err := d.netDialFn(ctx, proxyURL, u)
    if err != nil {
       return nil, nil, err
    }

    // 7.获取有端口号的host和无无端口号的host
    hostPort, hostNoPort := hostPortNoPort(u)

    // 8.建立TCP连接
    netConn, err := netDial(ctx, "tcp", hostPort)
    if err != nil {
       return nil, nil, err
    }


    // 9.提前设置错误后tcp连接实例Close
    defer func() {
       if netConn != nil {
          _ = netConn.Close()
       }
    }()

    // 10.如果设置了代理,则需要使用代理的连接,否则直接使用netConn
    if proxyURL != nil && u.Scheme == "https" {

       cfg := cloneTLSConfig(d.TLSClientConfig)
       if cfg.ServerName == "" {
          cfg.ServerName = hostNoPort
       }
       tlsConn := tls.Client(netConn, cfg)
       netConn = tlsConn

       if trace != nil && trace.TLSHandshakeStart != nil {
          trace.TLSHandshakeStart()
       }
       err := doHandshake(ctx, tlsConn, cfg)
       if err != nil {
          return nil, nil, err
       }
    }

    // 11.构建websocket连接对象
    conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)

    // 12.WebSocket Upgrade HTTP 请求
    if err := req.Write(netConn); err != nil {
       return nil, nil, err
    }

    // 13.读取http请求返回,并检查是否成功升级为websocket连接
    resp, err := http.ReadResponse(conn.br, req)
    // 13.1这里做了详细错误判断,特殊检查TLS协议协商时NextProtos配置是否缺少http/1.1,否则WebSocket无法升级
    if err != nil {
       if d.TLSClientConfig != nil {
          for _, proto := range d.TLSClientConfig.NextProtos {
             if proto != "http/1.1" {
                return nil, nil, fmt.Errorf(
                   "websocket: protocol %q was given but is not supported;"+
                      "sharing tls.Config with net/http Transport can cause this error: %w",
                   proto, err,
                )
             }
          }
       }
       return nil, nil, err
    }
    // 13.2将服务器响应头中的Cookies存入Dialer.Jar
    if d.Jar != nil {
       if rc := resp.Cookies(); len(rc) > 0 {
          d.Jar.SetCookies(u, rc)
       }
    }
    // 13.3校验是否成功完成WebSocket升级
    if resp.StatusCode != 101 ||
       !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
       !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
       resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
       buf := make([]byte, 1024)
       n, _ := io.ReadFull(resp.Body, buf)
       resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
       return nil, resp, ErrBadHandshake
    }
    // 13.4解析响应头中的WebSocket扩展
    for _, ext := range parseExtensions(resp.Header) {
       if ext[""] != "permessage-deflate" {
          continue
       }
       _, snct := ext["server_no_context_takeover"]
       _, cnct := ext["client_no_context_takeover"]
       if !snct || !cnct {
          return nil, resp, errInvalidCompression
       }
       conn.newCompressionWriter = compressNoContextTakeover
       conn.newDecompressionReader = decompressNoContextTakeover
       break
    }
    
    // 14.将ResponseBody重置为空(避免后续对Body的误读)
    // 因为WebSocket协议握手后,TCP连接上就开始直接发送WebSocket帧了,这个HTTPResponseBody就没有意义了
    resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
    
    // 15.读取服务器返回的Sec-WebSocket-Protocol,记录协商成功的子协议(如果有)
    conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")

    // 16.清除连接超时
    if err := netConn.SetDeadline(time.Time{}); err != nil {
       return nil, resp, err
    }

    // 17.清空临时的tcp连接对象并返回websocket连接对象
    netConn = nil
    return conn, resp, nil
}

2.流程图

额外流程补充

关于如何构建netDialFn这里有必要说明一下(源码在github.com/gorilla/websocket/client.go里面)

netDialFn的代码解读
go 复制代码
func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) {
    var netDial netDialerFunc

    // 1.获取基本的dialFn
    if proxyURL != nil {
       netDial = d.netDialFromURL(proxyURL)
    } else {
       netDial = d.netDialFromURL(backendURL)
    }

    // 2.如果ctx带了deadline,就在拨号函数外层包一层,
    if deadline, ok := ctx.Deadline(); ok {
       netDial = netDialWithDeadline(netDial, deadline)
    }

    // 3.如果我们要用代理,还要再包一层CONNECT隧道逻辑
    if proxyURL != nil {
       return proxyFromURL(proxyURL, netDial)
    }
    return netDial, nil
}
2.流程说明

Conn

核心数据结构

GorillaWebSocket的核心连接对象,负责维护、读写状态、缓冲区管理、协议状态机等,下面将介绍一下关键的一些成员变量以及方法

首先可以将Conn类中的成员变量分为三类,基础信息、write相关、read相关

类图

基础属性信息

  • conn:底层的TCP连接对象(也有可能是TSL Conn)
  • isServer:标识当前是客户端还是服务器端
  • subprotocol:websocket子协议

write相关属性

  • writeBuf:WebSocket帧头+负载的缓冲区,所有待发送数据都会先写到这个缓冲区
  • writePool:缓冲池,用于复用 writeBuf,减少内存分配开销
  • writeBufSize:writeBuf 的总大小
  • writeDeadline:当前 write 操作的截止时间,写数据时会设置到TCPConn上
  • writer:当前返回给上层应用的Writer(由NextWriter返回),可以是gorillaWebSocket库的实现,可以是用户自己的实现(最终都会实现Write这个接口方法)
  • isWriting:标志当前是否正在写数据,用于检测并发写入的panic保护

...

read相关属性

  • reader:当前活跃的Reader
  • br:读取用的缓冲区封装,负责管理底层TCP的recv缓存
  • readRemaining:当前Frame剩余未读的字节数
  • readFinal:当前Frame是否为该Message的最后一帧(FIN Bit)
  • readLength:当前Message的总长度
  • readMaskPos:当前解掩码的偏移位置
  • readMaskKey:当前Frame的MaskingKey
  • handlePong:Pong控制帧的回调函数
  • handlePing:Ping控制帧的回调函数
  • handleClose:Close控制帧的回调函数

...

再来看一下Conn中一些重要的方法,同样的这些重要的方法可以分为write和read两部分来说明

write相关方法

  • WriteMessage:一次性写入一整条消息(文本、二进制或控制帧)
  • NextWriter:返回一个io.WriteCloser,用于逐步写入一条消息数据(支持分片写入)
  • beginMessage:初始化并准备一个新的消息写入器
  • write:底层写方法,将缓冲数据写入网络连接
  • writeBufs:将多个字节切片拼接写入底层网络连接
  • flushFrame(这个虽然不是Conn的方法但是很重要所以在这里也列举了出来):构造一个完整的 WebSocket 帧,将缓冲区的数据和额外数据写出

...

read相关方法

  • ReadMessage:一次性读取一条完整的 WebSocket 消息
  • NextReader:返回当前消息的 io.Reader,支持分段读取
  • advanceFrame:解析下一个 WebSocket 帧的完整结构,是该库最底层协议解析方法

写消息

下面将详细讲解一下写消息的整个流程,我将从WriteMessage这个方法出发,讲解一下整个write流程,首先我会先将所有写相关的方法的调用链路进行说明,然后再对每一个链路上面的方法做单独详细的讲解,自顶向下的为大家讲解write的流程细节。

write调用链路说明

可以从上述的流程中看到通过WriteMessage这个方法进入整个写出流程的方法调用链,可以看到无论是直接写数据还是现将数据先写进缓存区再写出,最后真正写出数据的时候都会调用到Conn.conn.write这个方法,当然用户也可以自己实现write。现在我们已经知道了大致的write写出链路,但是还不知道为什么会有这么多分支流程以及什么情况下会走什么样的流程,所以接下来将会分开对这个调用链路中重要方法中的代码进行解读。

WriteMessage源码解读

这个方法是在使用这个库中最常用的write方法,同时也是最顶层的API,所以先从这个方法开始解读整个write流程

go 复制代码
func (c *Conn) WriteMessage(messageType int, data []byte) error {

    // 1.如果当前是服务端并且没有采用数据压缩的话就直接将数据写出去
    if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
       var mw messageWriter
       // 1.1初始化messageWriter写入器,设置消息类型与写缓冲区,为新消息做准备
       if err := c.beginMessage(&mw, messageType); err != nil {
          return err
       }
       // 1.2将数据尽可能写入缓冲区中,如果有剩余部分作为额外数据帧传入
       n := copy(c.writeBuf[mw.pos:], data)
       mw.pos += n
       data = data[n:]
       // 1.3flushFrame将缓冲区数据与extra数据一起打包成WebSocket帧并发送
       // 这里设置为true表示是最后一帧
       return mw.flushFrame(true, data)
    }

    // 2.如果启用了写压缩,或当前是客户端(需要进行掩码处理),则通过NextWriter进行逐步写入
    w, err := c.NextWriter(messageType)
    if err != nil {
       return err
    }
    // 2.1将数据以流的方式写入
    if _, err = w.Write(data); err != nil {
       return err
    }
    // 2.2关闭此次连接流,在关闭之前会将最后的数据写出
    return w.Close()
}

简单总结一下这个顶层API在里面其实只做了一个判断,如果当前是服务端并且没有开启数据压缩则直接将数据写出,反之则是以流的方式将数据写出

beginMessage源码解读

接下来看一下beginMessage这个方法做了什么,先简单说一下这个方法的作用,为当前的消息构建一个处理它的对象

go 复制代码
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
    // 1.关闭上一个未关闭的writer,清理写入状态,防止并发或资源泄露
    if c.writer != nil {
       c.writer.Close()
       c.writer = nil
    }

    // 2.校验消息类型的合法性,只能是控制帧或者消息帧对应的是websocket帧上面的opencode码
    if !isControl(messageType) && !isData(messageType) {
       return errBadWriteOpCode
    }

    // 3.检查是否已处于错误状态,如果之前写入失败,则不能继续写
    c.writeErrMu.Lock()
    err := c.writeErr
    c.writeErrMu.Unlock()
    if err != nil {
       return err
    }

    // 4.初始化messageWriter写入器,设置帧类型,设置写入偏移(留出帧头部预留空间)
    mw.c = c
    mw.frameType = messageType
    // 4.1留出最大帧头长度(14字节)的位置,写数据时跳过
    mw.pos = maxFrameHeaderSize

// 5.分配写缓冲区,如果当前writeBuf为 nil,从池子里取一个缓冲区;如果没有池子,自己new一个
    if c.writeBuf == nil {
       wpd, ok := c.writePool.Get().(writePoolData)
       if ok {
          c.writeBuf = wpd.buf
       } else {
          c.writeBuf = make([]byte, c.writeBufSize)
       }
    }
    return nil
}

简单总结beginMessage这个方法作用是初始化写入器对象,确保写入器状态正确、缓冲区分配完成、连接处于健康状态,做好写消息前的预备工作

flushFrame源码解读

flushFrame这个方式是构建写消息的中间层API,它会构建出完整的帧头和负载并将消息写出

go 复制代码
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
    c := w.c
    // 这里的length是获取此次消息完整的长度
    length := w.pos - maxFrameHeaderSize + len(extra)
    
    // 1.如果当前是控制协议的话并且不是最后一帧或者length超过最大控制帧大小,则返回错误
    if isControl(w.frameType) &&
       (!final || length > maxControlFramePayloadSize) {
       return w.endMessage(errInvalidControlFrame)
    }

    // 2.构建websocket帧头部的第一个字节b0
    b0 := byte(w.frameType)
    // 2.1构建最后一帧标识 
    if final {
       b0 |= finalBit
}
    // 2.2构建数据压缩标识
    if w.compress {
       b0 |= rsv1Bit
}
    w.compress = false

    //2.3构建帧头部的第二个字节b1,如果是客户端则构建mask位
    b1 := byte(0)
    if !c.isServer {
       b1 |= maskBit
}

    //3.初始化针对于此次消息的帧偏移量
    framePos := 0
    // 3.1客户端:Mask Key 占用4字节,所以framePos=0
    // 服务端:不需要Mask,但为了内存布局对齐,预留4字节占位
    if c.isServer {
       framePos = 4
    }

    // 4.构建PayloadLength区域(Payload 长度字段)
    // WebSocket 协议对Payload长度的编码规则:
    // <126 直接填入b1低7位
    // 126~65535 使用2字节扩展长度
    // >=65536 使用8字节扩展长度 
    switch {
    case length >= 65536:
       c.writeBuf[framePos] = b0
       c.writeBuf[framePos+1] = b1 | 127
       binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))  
    case length > 125:
       framePos += 6
       c.writeBuf[framePos] = b0
       c.writeBuf[framePos+1] = b1 | 126
       binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
    default:
       framePos += 8
       c.writeBuf[framePos] = b0
       c.writeBuf[framePos+1] = b1 | byte(length)
    }

    // 5.如果是客户端的话,则需要mask数据
    if !c.isServer {
       key := newMaskKey()
       // 5.1将Mask Key放入帧头部
       copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
       // 5.2对当前写入缓冲区中的Payload进行掩码处理
       maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
       // 5.3如果客户端还有剩余数据,则直接报错,因为剩余部分的数据没有被mask,可以理解为上游代码有问题
       if len(extra) > 0 {
          return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
       }
    }


    // 6.保证写入过程中没有其他协程正在写入
    if c.isWriting {
       panic("concurrent write to websocket connection")
    }
    c.isWriting = true

    // 7.将Frame写入底层TCP连接
    // frameType:当前帧类型
    // writeDeadline:写入超时时间
    // writeBuf[framePos:w.pos]:本次帧缓冲区数据
    // extra:额外追加数据 (服务端可以用extra追加大量数据)
    err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)

    // 8.写入结束后,解除写锁
    if !c.isWriting {
       panic("concurrent write to websocket connection")
    }
    c.isWriting = false

    if err != nil {
       return w.endMessage(err)
    }

    // 9.如果当前帧是最后一帧(final=true),则标记写入器已关闭 (w.endMessage)
    if final {
       _ = w.endMessage(errWriteClosed)
       return nil
    }

    // 10.如果当前帧是中间帧(续帧),则复位writer,准备下一帧写入
    w.pos = maxFrameHeaderSize
w.frameType = continuationFrame
return nil
}

总结:将当前缓冲区中的数据封装成符合WebSocket协议格式的帧并安全写入底层连接,支持分片发送与压缩标识处理。

write源码解读

接下来就进入到websocket底层的方法解读,这个方法最终会将已经封装好的帧提交到tcp层去,也是当前库最底层的API

go 复制代码
// write是底层实际执行TCP写入的关键方法,负责将webSocket帧发送到网络连接中
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
    // 1.加锁 
    <-c.mu
    defer func() { c.mu <- struct{}{} }()

    // 2.检查是否在之前的写操作中已经发生错误(writeErr被标记)
    // 如果之前写入已失败,则本次写入直接返回错误,避免后续操作
    c.writeErrMu.Lock()
    err := c.writeErr
    c.writeErrMu.Unlock()
    if err != nil {
       return err
    }

    // 3.设置底层 TCP 连接的写入超时时间(WriteDeadline)
    if err := c.conn.SetWriteDeadline(deadline); err != nil {
       return c.writeFatal(err)
    }

    // 4.判断 buf1 是否为空(没有额外数据需要发送)
    // 如果没有额外数据(单帧数据),直接将buf0写入TCP连接
    if len(buf1) == 0 {
       _, err = c.conn.Write(buf0)
    } else {
       // 4.1如果有buf1额外数据(需要分片发送时),则将buf0与buf1一起写出
       err = c.writeBufs(buf0, buf1)
    }

    if err != nil {
       return c.writeFatal(err)
    }

    // 5.特殊情况:如果本次写入的是CloseMessage(关闭帧)
    // 需要标记 writeErr = ErrCloseSent,确保后续操作知道连接已进入关闭流程
    if frameType == CloseMessage {
       _ = c.writeFatal(ErrCloseSent)
    }

    return nil
}

// writeBufs 是将多个数据缓冲区(buf0, buf1, ...)
func (c *Conn) writeBufs(bufs ...[]byte) error {
    // 1. 将多个buf包装为net.Buffers(零拷贝写出优化)
    b := net.Buffers(bufs)
    // 2. 使用WriteTo接口将所有数据写入TCP连接
    _, err := b.WriteTo(c.conn)
    return err
}

总结:这就是其中的一条完整的调用链路,但是本质上其他调用链路没有区别,不同写数据的顶层API最终都会调用中下层的API

NextWriter源码解读

接下来再来看一下另一条调用链路的代码,最后都会调用中下层的APIflushFrame和write,所以这里我只会讲述上层API的代码

go 复制代码
// NextWriter 用于获取一个writer用于后续消息发送
// 它会返回一个实现io.WriteCloser接口的writer(messageWriter或压缩包装的writer)
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
        var mw messageWriter
        // 1.初始化 messageWriter,绑定连接c,设置消息类型、分配缓冲区
        if err := c.beginMessage(&mw, messageType); err != nil {
                return nil, err
        }
        // 2.将初始化好的messageWriter绑定到Conn的write属性上
        c.writer = &mw

        // 3.判断是否启用了压缩发送(permessage-deflate)
        // 3.1 如果启用了压缩,则用newCompressionWriter包装writer,替换成压缩 writer
        //     同时标记compress = true,后续flushFrame会设置RSV1标志位
        if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
                w := c.newCompressionWriter(c.writer, c.compressionLevel)
                mw.compress = true
                c.writer = w
        }

        // 4. 返回 writer(可能是 messageWriter,也可能是压缩包装的 writer)
        return c.writer, nil
}

// ncopy 负责判断当前writeBuf剩余空间是否足够容纳剩余数据帧大小的数据。
// 如果空间不足则会触发flushFrame刷新缓冲区,再返回剩余可写入的空间大小。
func (w *messageWriter) ncopy(max int) (int, error) {
        // 1.计算当前写缓冲区剩余空间
        n := len(w.c.writeBuf) - w.pos
        // 2.如果空间不足(<=0),则先将已缓冲的数据通过flushFrame发送出去
        if n <= 0 {
                if err := w.flushFrame(false, nil); err != nil {
                        return 0, err
                }
                //2.1flushFrame后 w.pos 被重置到maxFrameHeaderSize,重新计算剩余空间
                n = len(w.c.writeBuf) - w.pos
        }
        // 3.取剩余空间与待写入数据max的较小值作为可写入的字节数
        if n > max {
                n = max
        }
        return n, nil
}

// Write是messageWriter的写入实现,将数据p写入writeBuf中,缓冲区满时会自动flushFrame发送
// 对于大数据(> 2 * writeBuf)的情况,服务端会直接通过flushFrame发送数据,不走缓冲区
func (w *messageWriter) Write(p []byte) (int, error) {
        // 1.如果messageWriter已经处于错误状态(err != nil),直接返回错误
        if w.err != nil {
                return 0, w.err
        }

        // 2.优化路径:如果写入数据非常大(> 2 * writeBuf大小)并且是服务端
        //   则直接调用 flushFrame 发送数据,不走writeBuf缓存,避免不必要的拷贝
        if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
                err := w.flushFrame(false, p)
                if err != nil {
                        return 0, err
                }
                return len(p), nil
        }

        // 3. 一般路径:将数据分段写入 writeBuf,缓冲区满时自动flushFrame
        nn := len(p)  
        for len(p) > 0 {
                // 3.1计算本次能写入writeBuf的最大长度
                n, err := w.ncopy(len(p))
                if err != nil {
                        return 0, err
                }
                // 3.2将数据拷贝到writeBuf缓冲区中
                copy(w.c.writeBuf[w.pos:], p[:n])
                w.pos += n
                // 3.3 更新剩余待写入数据
                p = p[n:]
        }

        // 4.返回写入的总字节数与nil
        return nn, nil
}

// Close 用于关闭当前的 messageWriter,将缓冲区中剩余数据作为最后一帧发送出去
// Close 的调用会触发flushFrame(final = true),确保所有数据都被写入连接
func (w *messageWriter) Close() error {
        // 1.检查当前 messageWriter 是否已经处于错误状态
        //   如果在之前的写入或flush中发生过错误(err != nil),则直接返回该错误
        if w.err != nil {
                return w.err
        }

        // 2.调用 flushFrame,将缓冲区中剩余数据发送出去,并标记这是最后一帧(final = true)
        //   flushFrame 内部会将w的状态置为结束状态 (调用endMessage),并释放写缓冲区
        return w.flushFrame(true, nil)
}

7.总结

最后总结一下WriteMessage和NextWriter的区别

维度 WriteMessage NextWriter
调用方式 一次性传入完整数据 []byte 返回 io.WriteCloser,用户分段写入
数据来源 数据在调用时必须已经准备好(完整 payload 一次性给到) 数据可以动态生产、边生成边写入
服务端大数据场景 直接走"快速路径",自动绕过缓冲区(大数据也能一次性写出) 除非数据是流式来源(如文件、数据库导出),否则意义不大
客户端场景 内部也是流式(因为 Mask 机制),但外部是全量一次性调用 对客户端大数据量写入是唯一合理的方式
应用场景 小~中等数据量 / 已经准备好的大数据块(文件内容已在内存中) 数据流式生成、文件分段传输、数据库分页导出、动态内容流
优劣总结 简单直接、性能极佳(尤其服务端) 灵活控制 Frame 粒度、支持动态数据来源,但代码复杂度高

读消息

最后我也会详细为大家自顶向下讲解完整的读消息的调用链路,每一次API里面的处理细节

read调用链路说明

ReadMessage源码解读

ReadMessage是该库中读取websocket消息最常用的api,所以我从这个api开始进行read部分的代码解析

go 复制代码
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
        var r io.Reader

        // 1.获取下一条消息的类型和读取器(用于读取该消息的内容)
        messageType, r, err = c.NextReader()
        if err != nil {
                return messageType, nil, err
        }

        // 2.读取整个消息内容到内存(p是完整的消息内容)
        p, err = io.ReadAll(r)

        return messageType, p, err
}

这个方法内容比较简单,实际上ReadMessage会从NextReader中获取出io.reader这个读取器,然后将所有消息内容读取出来返回

NextReader源码解读

这个方法是用于读取下一个完整的webSocket消息

go 复制代码
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
    // 1.关闭上一个读取器
    if c.reader != nil {
       // 1.1保证上一个读取器资源被正确释放
       c.reader.Close()
       c.reader = nil
    }

    // 2.重置读取状态
    c.messageReader = nil
    c.readLength = 0

    // 3.开始获取下一个帧消息
    for c.readErr == nil {
       frameType, err := c.advanceFrame()
       if err != nil {
          c.readErr = err
          break
       }

       // 3.1如果读取的帧是正常的数据帧,则返回Reader
       if frameType == TextMessage || frameType == BinaryMessage {
          c.messageReader = &messageReader{c}
          // 3.2设置当前读取器
          c.reader = c.messageReader
          // 3.3如果需要解压(rsv1=1且negotiated compress),包装一层解压Reader
          if c.readDecompress {
             c.reader = c.newDecompressionReader(c.reader)
          }
          return frameType, c.reader, nil
       }
    }

    // 4.如果用户忽略读消息的报错,同时读消息错误连续发生1000次则触发panic
    c.readErrCount++
    if c.readErrCount >= 1000 {
       panic("repeated read on failed websocket connection")
    }

    return noFrame, nil, c.readErr
}

4.advanceFrame源码解读

这个方法它不会真正读取帧的数据负载(payload),而是推进读取进度到下一个帧的位置 ,并解析出控制信息(如 opcode 、payload 长度、是否分片、掩码等) ,为后续 NextReaderReadMessage 等操作准备上下文

go 复制代码
func (c *Conn) advanceFrame() (int, error) {

    // 1.跳过上一帧剩余的部分,如果还有未读完的Payload数据就丢弃
    if c.readRemaining > 0 {
       if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
          return noFrame, err
       }
    }

    // 2.读取并解析帧头前两个字节:第一个字节是FIN/RSV1-3/OPCODE,第二个字节是MASK和payload len
    var errors []string 
    p, err := c.read(2) 
    if err != nil {
       return noFrame, err
    }
    //2.1提取帧类型(OPCODE)-低4位表示帧类型
    frameType := int(p[0] & 0xf)
    //2.2提取控制位
    // 是否为消息最后一帧(FIN)
    // 扩展标志位,通常用于压缩
    // 是否启用了掩码(MASK)
    final := p[0]&finalBit != 0 
    rsv1 := p[0]&rsv1Bit != 0   
    rsv2 := p[0]&rsv2Bit != 0
    rsv3 := p[0]&rsv3Bit != 0
    mask := p[1]&maskBit != 0 
    // 2.3提取 Payload 长度字段的低 7 位(0~125 表示实际长度;126 表示接下来的 2 字节;127 表示接下来的 8 字节)
    _ = c.setReadRemaining(int64(p[1] & 0x7f))
    // 2.4判断是否启用了 WebSocket 压缩扩展(RSV1 被设置)
    c.readDecompress = false
    if rsv1 {
       if c.newDecompressionReader != nil {
          c.readDecompress = true 
       } else {
          errors = append(errors, "RSV1 set") 
       }
    }
    // 不支持 rsv2,报错
    if rsv2 {
       errors = append(errors, "RSV2 set") 
    }
    // 不支持 rsv3,报错
    if rsv3 {
       errors = append(errors, "RSV3 set") 
    }

    // 3.校验 OPCODE 是否合理,控制帧的长度、FIN位,数据帧的顺序等
    switch frameType {
    // 3.1控制帧 payload 最大只能为 125 字节,并且必须是完整帧
    case CloseMessage, PingMessage, PongMessage: 
       if c.readRemaining > maxControlFramePayloadSize {
          errors = append(errors, "len > 125 for control") 
       }
       if !final {
          errors = append(errors, "FIN not set on control") 
       }
    //3.2上一个消息未完成就发新消息,错误
    case TextMessage, BinaryMessage: 
       if !c.readFinal {
          errors = append(errors, "data before FIN") 
       }
       c.readFinal = final 
    // 3.3连续帧,上一个消息已结束,却收到了 continuation,错误
    case continuationFrame: 
       if c.readFinal {
          errors = append(errors, "continuation after FIN") 
       }
       c.readFinal = final
    // 3.4位置的消息帧,错误   
    default:
       errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) 
    }

    // 🔍 校验掩码设置是否正确(客户端发来的数据必须加掩码,服务端不加)
    if mask != c.isServer {
       errors = append(errors, "bad MASK")
    }

    if len(errors) > 0 {
       return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
    }

    // 4.如果Payload长度是126或127,需要额外读取真实的长度字段
    switch c.readRemaining {
    case 126:
       p, err := c.read(2)
       if err != nil {
          return noFrame, err
       }
       // 4.1使用2字节的无符号整数表示Payload长度
       if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
          return noFrame, err
       }
    case 127:
       p, err := c.read(8)
       if err != nil {
          return noFrame, err
       }
       // 4.2使用8字节的无符号整数表示Payload长度
       if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
          return noFrame, err
       }
    }

    // 5.如果启用了MASK,就读取4字节的掩码key
    if mask {
       c.readMaskPos = 0                    
       p, err := c.read(len(c.readMaskKey)) 
       if err != nil {
          return noFrame, err
       }
       copy(c.readMaskKey[:], p) 
    }

    // 6.处理数据帧(包括 continuation)时做限制判断
    if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
       // 6.1累加读取的payload总长度 
       c.readLength += c.readRemaining 

       if c.readLength < 0 {
          return noFrame, ErrReadLimit
       }
       // 6.2超过最大允许读取长度,主动发关闭帧 
       if c.readLimit > 0 && c.readLength > c.readLimit {
          _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
          return noFrame, ErrReadLimit
       }

       return frameType, nil
    }

    // 7.读取控制帧的Payload数据
    var payload []byte
    if c.readRemaining > 0 {
       payload, err = c.read(int(c.readRemaining)) 
       _ = c.setReadRemaining(0)                   
       if err != nil {
          return noFrame, err
       }
       // 7.1如果是服务端收到掩码数据,需要反掩码
       if c.isServer {
          maskBytes(c.readMaskKey, 0, payload)
       }
    }

    // 8.控制帧处理逻辑
    switch frameType {
    case PongMessage:
       // 8.1收到Pong响应,回调处理
       if err := c.handlePong(string(payload)); err != nil {
          return noFrame, err
       }
    case PingMessage:
       // 8.2收到Ping,回调处理
       if err := c.handlePing(string(payload)); err != nil {
          return noFrame, err
       }
    case CloseMessage:
       // 8.3收到Close消息,解析状态码和描述
       closeCode := CloseNoStatusReceived
closeText := ""
       if len(payload) >= 2 {
          // 前2字节是关闭状态码
          closeCode = int(binary.BigEndian.Uint16(payload))
          if !isValidReceivedCloseCode(closeCode) {
             return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
          }
          // 后面是关闭原因文本(必须为合法 UTF-8)
          closeText = string(payload[2:])
          if !utf8.ValidString(closeText) {
             return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
          }
       }
       // 执行关闭处理回调
       if err := c.handleClose(closeCode, closeText); err != nil {
          return noFrame, err
       }
       // 返回 CloseError,通知上层连接已关闭
       return noFrame, &CloseError{Code: closeCode, Text: closeText}
    }

    // 9.返回此帧类型
    return frameType, nil
}

5.Read源码解读

最后再来看一下,该库中真正获取消息的方法,从WebSocket连接中连续读取消息负载数据,自动处理帧切换、掩码和错误,直到完整消息被读取完毕或出错为止

这是 messageReader 实现的 Read 方法,使其符合 io.Reader 接口

go 复制代码
func (r *messageReader) Read(b []byte) (int, error) {
        c := r.c 

        // 1.如果当前正在使用的messageReader不是当前这个实例,说明消息已经读完了或连接被其他操作抢占,直接返回EOF
        if c.messageReader != r {
                return 0, io.EOF
        }

        // 2.进入主循环,直到遇到读取错误
        for c.readErr == nil {
                // 2.1如果当前帧还有未读取的数据,则优先读取当前帧
                if c.readRemaining > 0 {
                        // 2.2如果调用方传入的缓冲区比剩余数据还多,缩小读取范围,避免越界
                        if int64(len(b)) > c.readRemaining {
                                b = b[:c.readRemaining]
                        }
                        // 2.3从底层缓冲区读取数据到这个字节数组中(b)
                        n, err := c.br.Read(b)
                        c.readErr = err 
                        // 2.4如果是服务端,需要对客户端发来的 masked 数据做解码
                        if c.isServer {
                                c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
                        }
                        // 2.5更新当前帧剩余待读取字节数
                        rem := c.readRemaining
                        rem -= int64(n)
                        _ = c.setReadRemaining(rem) 

                        // 2.6如果数据还没读完就遇到 EOF,说明远端意外断开,是异常错误
                        if c.readRemaining > 0 && c.readErr == io.EOF {
                                c.readErr = errUnexpectedEOF
                        }
                        return n, c.readErr
                }

                // 2.7当前帧已经读完,且是最后一帧,当前消息结束
                if c.readFinal {
                        c.messageReader = nil
                        return 0, io.EOF
                }

                // 2.8当前帧读完了,但不是最后一帧,尝试读取下一帧
                frameType, err := c.advanceFrame()
                switch {
                case err != nil:
                        // 出现错误就退出
                        c.readErr = err
                case frameType == TextMessage || frameType == BinaryMessage:
                        // 如果中间帧出现了新的message类型(只允许continuation),说明协议异常
                        c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
                }
        }

        err := c.readErr
        if err == io.EOF && c.messageReader == r {
                err = errUnexpectedEOF
        }
        return 0, err
}

5.总结

希望这篇 gorilla/websocket 库的源码解读对大家有所帮助!虽然内容偏向底层实现,可能读起来有点枯燥,但通过一步步调试、查阅,我也在不断摸索中加深了理解。如果你也在学习这个库,不妨一起动手跑跑源码。欢迎交流指正!

相关推荐
DemonAvenger36 分钟前
微服务通信:Go网络编程实战
网络协议·架构·go
程序员爱钓鱼1 小时前
Go语言实战案例:文件上传服务
后端·go·trae
程序员爱钓鱼1 小时前
Go语言实战案例:表单提交数据解析
后端·go·trae
DemonAvenger5 小时前
Go网络编程中的设计模式:从理论到实践
网络协议·架构·go
岁忧13 小时前
(nice!!!)(LeetCode 每日一题) 3363. 最多可收集的水果数目 (深度优先搜索dfs)
java·c++·算法·leetcode·go·深度优先
程序员爱钓鱼18 小时前
Go语言实战案例:简易JSON数据返回
后端·go·trae
程序员爱钓鱼18 小时前
Go语言实战案例:用net/http构建一个RESTful API
后端·go·trae
岁忧21 小时前
(LeetCode 面试经典 150 题) 82. 删除排序链表中的重复元素 II (链表)
java·c++·leetcode·链表·面试·go
DemonAvenger1 天前
大规模Go网络应用的部署与监控
网络协议·架构·go