skynet-socket.lua 源码分析

源码

lua
local driver = require "skynet.socketdriver"
local skynet = require "skynet"
local skynet_core = require "skynet.core"
local assert = assert

-- When more than this many bytes are buffered for a socket, the data handler
-- pauses it (see pause_socket and socket_message[1]).
local BUFFER_LIMIT = 128 * 1024
-- Public API table returned at the end of this module.
local socket = {}
-- Every socket object tracked by this service, keyed by socket id. When the
-- table itself is collected, each id still registered is closed in the driver.
local socket_pool = setmetatable({}, {
    __gc = function(pool)
        for id in pairs(pool) do
            driver.close(id)
            pool[id] = nil
        end
    end,
})

-- Per-id callbacks registered through socket.onclose.
local socket_onclose = {}
-- Incoming socket message handlers, indexed by SKYNET_SOCKET_TYPE_* value.
local socket_message = {}

-- Resume the coroutine parked on socket object `s`, if there is one.
-- s.co is cleared first, so further wakeup(s) calls are no-ops until the
-- next suspend(s) parks a coroutine again.
local function wakeup(s)
    local waiting = s.co
    if not waiting then
        return
    end
    s.co = nil
    skynet.wakeup(waiting)
end

-- Ask the driver to stop reading socket object `s`. Does nothing when s.pause
-- is already set (true while paused, false once socket.close marked it).
-- `size`, if given, is included in the log line.
-- Yields afterwards: socket messages may already be queued for this service.
local function pause_socket(s, size)
    if s.pause == nil then
        local msg
        if size then
            msg = string.format("Pause socket (%d) size : %d" , s.id, size)
        else
            msg = string.format("Pause socket (%d)" , s.id)
        end
        skynet.error(msg)
        driver.pause(s.id)
        s.pause = true
        skynet.yield()
    end
end

-- Park the running coroutine on socket object `s` until wakeup(s) resumes it.
-- Only one coroutine may wait on a socket at a time (asserted on entry).
-- If the socket was paused by pause_socket, it is restarted in the driver
-- before waiting, and the pause flag is cleared after waking.
local function suspend(s)
    assert(not s.co)
    s.co = coroutine.running()
    if s.pause then
        skynet.error(string.format("Resume socket (%d)", s.id))
        driver.start(s.id)
        skynet.wait(s.co)
        s.pause = nil
    else
        skynet.wait(s.co)
    end
    -- Wake the coroutine blocked in socket.close() (if any) every time a suspend returns,
    -- because socket.close() waits for the last socket buffer operation before clearing the buffer.
    if s.closing then
        skynet.wakeup(s.closing)
    end
end

-- read skynet_socket.h for these macros
-- SKYNET_SOCKET_TYPE_DATA = 1
-- Incoming TCP data for socket `id`: `data` holds `size` bytes. The bytes are
-- appended to the socket's buffer. The waiting reader is woken once its
-- pending request (s.read_required) can be satisfied. The socket is paused
-- while the buffer holds more than BUFFER_LIMIT bytes.
socket_message[1] = function(id, size, data)
    local s = socket_pool[id]
    if s == nil then
        -- Not tracked by this service: log and release the data.
        skynet.error("socket: drop package from " .. id)
        driver.drop(data, size)
        return
    end

    -- sz: presumably the total number of bytes now buffered (compared against
    -- the requested size and BUFFER_LIMIT below) -- confirm in lua-socket.c
    local sz = driver.push(s.buffer, s.pool, data, size)
    local rr = s.read_required
    local rrt = type(rr)
    if rrt == "number" then
        -- read size (socket.read with sz, or 0 for "any data")
        if sz >= rr then
            s.read_required = nil
            if sz > BUFFER_LIMIT then
                pause_socket(s, sz)
            end
            wakeup(s)
        end
    else
        -- Not a byte-count read: enforce the optional limit set by socket.limit.
        if s.buffer_limit and sz > s.buffer_limit then
            skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id , sz))
            driver.close(id)
            return
        end
        if rrt == "string" then
            -- read line: rr is the separator.
            -- NOTE(review): the nil pool presumably makes readline only check for
            -- the separator without consuming data -- confirm in lua-socket.c
            if driver.readline(s.buffer,nil,rr) then
                s.read_required = nil
                if sz > BUFFER_LIMIT then
                    pause_socket(s, sz)
                end
                wakeup(s)
            end
        elseif sz > BUFFER_LIMIT and not s.pause then
            -- No size/line request (e.g. readall's `true`, or no reader):
            -- only cap buffer growth; readall waiters are woken on close.
            pause_socket(s, sz)
        end
    end
end

-- SKYNET_SOCKET_TYPE_CONNECT = 2
-- The driver reports socket `id` as established (a connect or a listen).
-- For listening sockets, `addr` and `ud` (the port) are stored so that
-- socket.listen can return them.
socket_message[2] = function(id, ud, addr)
    local s = socket_pool[id]
    -- Ignore untracked ids and repeats (resume may also post connect message).
    if s == nil or s.connected then
        return
    end
    if s.listen then
        s.addr, s.port = addr, ud
    end
    s.connected = true
    wakeup(s)
end

-- SKYNET_SOCKET_TYPE_CLOSE = 3
-- Socket `id` was closed. A tracked socket object is marked disconnected and
-- its waiter is woken; an untracked id is closed in the driver. Finally the
-- callback registered with socket.onclose (if any) is unregistered and called.
socket_message[3] = function(id)
    local s = socket_pool[id]
    if s then
        s.connected = false
        wakeup(s)
    else
        driver.close(id)
    end
    local cb = socket_onclose[id]
    if cb then
        -- Unregister before calling. If the callback raises, a stale entry
        -- would otherwise trip connect()'s "socket has onclose callback"
        -- assert when the id is reused. A callback that re-registers itself
        -- is also no longer erased afterwards.
        socket_onclose[id] = nil
        cb(id)
    end
end

-- SKYNET_SOCKET_TYPE_ACCEPT = 4
-- Listening socket `id` accepted connection `newid` from `addr`. The new
-- connection is handed to the accept callback given to socket.start. If the
-- listener is no longer tracked here, the new connection is closed.
socket_message[4] = function(id, newid, addr)
    local listener = socket_pool[id]
    if listener then
        listener.callback(newid, addr)
    else
        driver.close(newid)
    end
end

-- SKYNET_SOCKET_TYPE_ERROR = 5
-- Error `err` reported for socket `id`:
--  * untracked id: shut it down and log;
--  * object with a callback (accept callback or UDP): log only;
--  * otherwise: log (if connected), or record `err` for a pending connect()
--    (which returns it), then mark disconnected, shut down and wake the waiter.
socket_message[5] = function(id, _, err)
    local s = socket_pool[id]
    if s == nil then
        driver.shutdown(id)
        skynet.error("socket: error on unknown", id, err)
    elseif s.callback then
        skynet.error("socket: accept error:", err)
    else
        if s.connected then
            skynet.error("socket: error on", id, err)
        elseif s.connecting then
            s.connecting = err
        end
        s.connected = false
        driver.shutdown(id)
        wakeup(s)
    end
end

-- SKYNET_SOCKET_TYPE_UDP = 6
-- UDP datagram of `size` bytes from `address`. It is copied into a Lua string,
-- the driver buffer is released, and both are passed to the socket's callback.
-- Datagrams for untracked sockets, or sockets without a callback, are dropped.
socket_message[6] = function(id, size, data, address)
    local s = socket_pool[id]
    if s ~= nil and s.callback ~= nil then
        local payload = skynet.tostring(data, size)
        skynet_core.trash(data, size)
        s.callback(payload, address)
    else
        skynet.error("socket: drop udp package from " .. id)
        driver.drop(data, size)
    end
end

-- Default handler for SKYNET_SOCKET_TYPE_WARNING. Logs that `size` K bytes
-- are still waiting to be sent on socket `id`, for tracked sockets only.
local function default_warning(id, size)
    if not socket_pool[id] then
        return
    end
    skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d)", size, id))
end

-- SKYNET_SOCKET_TYPE_WARNING
-- Outgoing data is piling up on socket `id`. Calls the handler installed with
-- socket.warning, falling back to default_warning.
socket_message[7] = function(id, size)
    local s = socket_pool[id]
    if not s then
        return
    end
    local handler = s.on_warning or default_warning
    handler(id, size)
end

-- Route every PTYPE_SOCKET message through driver.unpack. Its first value
-- selects the socket_message handler; the remaining values are passed on.
skynet.register_protocol {
    name = "socket",
    id = skynet.PTYPE_SOCKET,   -- PTYPE_SOCKET = 6
    unpack = driver.unpack,
    dispatch = function (_, _, msgtype, ...)
        local handler = socket_message[msgtype]
        handler(...)
    end
}

-- Create a TCP socket object for `id`, register it in socket_pool, and block
-- until a connect (socket_message[2]), error (socket_message[5]) or close
-- (socket_message[3]) message wakes it.
-- func: optional callback. When given, no read buffer is allocated
--       (socket.start passes the accept callback for listening sockets).
-- Returns id when connected. Otherwise returns nil plus s.connecting, which
-- is the driver's error string if one was reported, or `true` if not.
local function connect(id, func)
    local newbuffer
    if func == nil then
        newbuffer = driver.buffer()
    end
    local s = {
        id = id,
        buffer = newbuffer,
        pool = newbuffer and {},
        connected = false,
        connecting = true,
        read_required = false,
        co = false,
        callback = func,
        protocol = "TCP",
    }
    assert(not socket_onclose[id], "socket has onclose callback")
    -- An existing object for this id is only allowed if it is the listen
    -- object created by socket.listen (socket.start on a listening id).
    local s2 = socket_pool[id]
    if s2 and not s2.listen then
        error("socket is not closed")
    end
    socket_pool[id] = s
    suspend(s)
    local err = s.connecting
    s.connecting = nil
    if s.connected then
        return id
    else
        socket_pool[id] = nil
        return nil, err
    end
end

-- Open a TCP connection to addr:port and wait for the outcome.
-- Returns the socket id, or nil plus an error value.
function socket.open(addr, port)
    local fd = driver.connect(addr, port)
    return connect(fd)
end

-- Wrap an existing OS file descriptor as a socket (socket.stdin uses fd 0).
function socket.bind(os_fd)
    local fd = driver.bind(os_fd)
    return connect(fd)
end

-- Treat stdin (fd 0) as a socket.
function socket.stdin()
    return socket.bind(0)
end

-- Start receiving on an existing id (e.g. an accepted connection, or a
-- listening socket). For a listening socket, pass `func`; it is called as
-- func(newid, addr) for each accepted connection, and no read buffer is made.
function socket.start(id, func)
    driver.start(id)
    return connect(id, func)
end

-- Stop reading from socket `id`. No-op for untracked ids (and, via
-- pause_socket, for sockets whose pause flag is already set).
function socket.pause(id)
    local s = socket_pool[id]
    if s ~= nil then
        pause_socket(s)
    end
end

-- Ask the driver to shut down tracked socket `id`. The driver later delivers
-- SKYNET_SOCKET_TYPE_CLOSE; socket.close(id) is still needed afterwards.
function socket.shutdown(id)
    if socket_pool[id] then
        driver.shutdown(id)
    end
end

-- Close an id that has no socket object in this service. Tracked sockets
-- must go through socket.close so their object and waiters are cleaned up.
function socket.close_fd(id)
    local tracked = socket_pool[id]
    assert(tracked == nil, "Use socket.close instead")
    driver.close(id)
end

-- Close socket `id` and drop its object from socket_pool (no-op if untracked).
-- While still connected, this coroutine waits before removing the object:
--  * if another coroutine is reading (s.co), wait until that reader's
--    suspend() returns and wakes s.closing;
--  * otherwise suspend on the socket until a wakeup arrives
--    (e.g. socket_message[3] for the close).
function socket.close(id)
    local s = socket_pool[id]
    if s == nil then
        return
    end
    driver.close(id)
    if s.connected then
        -- false (not nil): pause_socket returns early and suspend won't restart it.
        s.pause = false -- Do not resume this fd if it paused.
        if s.co then
            -- reading this socket on another coroutine, so don't shutdown (clear the buffer) immediately
            -- wait for the reading coroutine to read the buffer.
            assert(not s.closing)
            s.closing = coroutine.running()
            skynet.wait(s.closing)
        else
            suspend(s)
        end
        s.connected = false
    end
    socket_pool[id] = nil
end

-- Read from tracked socket `id` (asserted).
-- sz == nil: return whatever is buffered; if the buffer is empty, wait for any
--            new data. Returns false, "" when no data is available and the
--            socket is disconnected.
-- sz given:  return exactly `sz` bytes, waiting until they are buffered. If the
--            socket is closing/disconnected first, returns false plus all
--            remaining buffered bytes.
function socket.read(id, sz)
    local s = socket_pool[id]
    assert(s)
    if sz == nil then
        -- read some bytes
        local ret = driver.readall(s.buffer, s.pool)
        if ret ~= "" then
            return ret
        end

        if not s.connected then
            return false, ret
        end
        -- Only one pending read per socket.
        assert(not s.read_required)
        -- 0: socket_message[1] wakes us on any incoming data.
        s.read_required = 0
        suspend(s)
        ret = driver.readall(s.buffer, s.pool)
        if ret ~= "" then
            return ret
        else
            return false, ret
        end
    end

    -- Fast path: enough bytes are already buffered.
    local ret = driver.pop(s.buffer, s.pool, sz)
    if ret then
        return ret
    end
    if s.closing or not s.connected then
        return false, driver.readall(s.buffer, s.pool)
    end

    assert(not s.read_required)
    -- socket_message[1] wakes us once at least `sz` bytes are buffered;
    -- close/error messages also wake us.
    s.read_required = sz
    suspend(s)
    ret = driver.pop(s.buffer, s.pool, sz)
    if ret then
        return ret
    else
        return false, driver.readall(s.buffer, s.pool)
    end
end

-- Read everything until socket `id` is closed.
-- While connected: wait for the close (read_required = true is never satisfied
-- by incoming data), then return the whole buffer.
-- Already disconnected: return the remaining bytes, or false if there are none.
function socket.readall(id)
    local s = assert(socket_pool[id])
    if s.connected then
        assert(not s.read_required)
        s.read_required = true
        suspend(s)
        assert(s.connected == false)
        return driver.readall(s.buffer, s.pool)
    end
    local rest = driver.readall(s.buffer, s.pool)
    return rest ~= "" and rest
end

-- Read one line terminated by `sep` (default "\n") from tracked socket `id`.
-- Waits until a full line is buffered. If the socket is (or becomes)
-- disconnected first, returns false plus the remaining buffered bytes.
function socket.readline(id, sep)
    local s = assert(socket_pool[id])
    sep = sep or "\n"
    local line = driver.readline(s.buffer, s.pool, sep)
    if line then
        return line
    end
    if s.connected then
        assert(not s.read_required)
        s.read_required = sep
        suspend(s)
        if s.connected then
            return driver.readline(s.buffer, s.pool, sep)
        end
    end
    return false, driver.readall(s.buffer, s.pool)
end

-- Wait until any data arrives on socket `id` (or it is woken by close/error),
-- then report whether it is still connected. Returns false immediately for
-- untracked or disconnected sockets.
function socket.block(id)
    local s = socket_pool[id]
    if s and s.connected then
        assert(not s.read_required)
        s.read_required = 0
        suspend(s)
        return s.connected
    end
    return false
end

-- Driver functions exported as-is; assert fails at load time if one is missing.
socket.write = assert(driver.send)
socket.lwrite = assert(driver.lsend)
socket.header = assert(driver.header)

-- True when this service holds no socket object for `id`.
function socket.invalid(id)
    return socket_pool[id] == nil
end

-- For a tracked socket: true when it is neither connected nor connecting.
-- Returns nothing for an untracked id.
function socket.disconnected(id)
    local s = socket_pool[id]
    if not s then
        return
    end
    return not (s.connected or s.connecting)
end

-- Create a listening socket and wait until the driver confirms it.
-- `host` may be a single "host:port" string when `port` is nil. Bracketed
-- IPv6 literals such as "[::]:8888" are accepted in that form; the plain
-- pattern would split them at the first ':'.
-- Returns the listening id plus the address and port reported by the driver
-- (stored by socket_message[2]).
function socket.listen(host, port, backlog)
    if port == nil then
        local spec = host
        host, port = string.match(spec, "^%[(.-)%]:(%d+)$")
        if host == nil then
            host, port = string.match(spec, "([^:]+):(.+)$")
        end
        port = tonumber(port)
    end
    local id = driver.listen(host, port, backlog)
    local s = {
        id = id,
        connected = false,
        listen = true,
    }
    assert(socket_pool[id] == nil)
    socket_pool[id] = s
    suspend(s)
    return id, s.addr, s.port
end

-- Stop tracking socket `id` in this service without closing it, so the id can
-- be forwarded to another service, which must call socket.start(id) later.
-- Any waiting coroutine is woken and sees the socket as disconnected; a
-- registered onclose callback is dropped.
function socket.abandon(id)
    local s = socket_pool[id]
    if not s then
        return
    end
    s.connected = false
    wakeup(s)
    socket_onclose[id] = nil
    socket_pool[id] = nil
end

-- Set the buffer limit for tracked socket `id` (asserted). When no byte-count
-- read is pending and more than `limit` bytes are buffered, the data handler
-- closes the socket.
function socket.limit(id, limit)
    assert(socket_pool[id]).buffer_limit = limit
end

---------------------- UDP

-- Register a UDP socket object for `id`. socket_message[6] passes each
-- datagram to `cb` as cb(data, address). Errors if `id` is already tracked.
local function create_udp_object(id, cb)
    assert(not socket_pool[id], "socket is not closed")
    local obj = {
        id = id,
        connected = true,
        protocol = "UDP",
        callback = cb,
    }
    socket_pool[id] = obj
end

-- Create a UDP socket via driver.udp(host, port); datagrams go to `callback`.
function socket.udp(callback, host, port)
    local udp_id = driver.udp(host, port)
    create_udp_object(udp_id, callback)
    return udp_id
end

-- Set the peer of UDP socket `id` to addr:port. If this service does not
-- track `id` yet, a UDP object is created. Otherwise the existing object must
-- be UDP, and its callback is replaced when `callback` is given.
function socket.udp_connect(id, addr, port, callback)
    local obj = socket_pool[id]
    if not obj then
        create_udp_object(id, callback)
    else
        assert(obj.protocol == "UDP")
        if callback then
            obj.callback = callback
        end
    end
    driver.udp_connect(id, addr, port)
end

-- Create a UDP socket via driver.udp_listen(addr, port); datagrams go to `callback`.
function socket.udp_listen(addr, port, callback)
    local udp_id = driver.udp_listen(addr, port)
    create_udp_object(udp_id, callback)
    return udp_id
end

-- Create a UDP socket via driver.udp_dial(addr, port); datagrams go to `callback`.
function socket.udp_dial(addr, port, callback)
    local udp_id = driver.udp_dial(addr, port)
    create_udp_object(udp_id, callback)
    return udp_id
end

-- More driver functions exported as-is (assert fails at load time if missing).
socket.sendto = assert(driver.udp_send)
socket.udp_address = assert(driver.udp_address)
socket.netstat = assert(driver.info)
socket.resolve = assert(driver.resolve)

-- Install callback(id, size) as the send-backlog warning handler of tracked
-- socket `id` (asserted), replacing default_warning for it.
function socket.warning(id, callback)
    local obj = assert(socket_pool[id])
    obj.on_warning = callback
end

-- Register callback(id) to run when socket `id` is reported closed
-- (socket_message[3] calls it and then forgets it). nil unregisters.
function socket.onclose(id, callback)
    socket_onclose[id] = callback
end

return socket

模块初始化和核心数据结构

引入依赖和常量

lua
local driver = require "skynet.socketdriver"  -- 底层C驱动
local skynet = require "skynet"
local skynet_core = require "skynet.core"

local BUFFER_LIMIT = 128 * 1024  -- buffer limit: 128 KiB; more buffered bytes than this pause the socket

核心数据结构

lua
local socket = {}  -- public API table
-- Every tracked socket object, keyed by id. Any ids still registered are
-- closed in the driver when this table is garbage-collected.
local socket_pool = setmetatable({}, {
    __gc = function(pool)
        for id in pairs(pool) do
            driver.close(id)
            pool[id] = nil
        end
    end,
})

local socket_onclose = {}    -- per-id close callbacks
local socket_message = {}    -- handlers indexed by socket message type

协程管理和挂起机制

协程控制函数

lua
-- Resume the coroutine parked on socket object `s`, if any. s.co is cleared
-- first, so a second wakeup(s) does nothing until suspend(s) runs again.
local function wakeup(s)
    local waiting = s.co
    if not waiting then
        return
    end
    s.co = nil
    skynet.wakeup(waiting)
end

-- Ask the driver to stop receiving on socket object `s`. Skipped if s.pause
-- is already set; suspend() resumes the socket and clears the flag.
-- `size` (buffered byte count) is only used in the log line, as in the full source.
-- Yield afterwards so socket messages already queued for this service run.
local function pause_socket(s, size)
    if s.pause ~= nil then return end
    if size then
        skynet.error(string.format("Pause socket (%d) size : %d", s.id, size))
    else
        skynet.error(string.format("Pause socket (%d)", s.id))
    end
    driver.pause(s.id)
    s.pause = true
    skynet.yield()
end

-- Park the running coroutine on socket object `s` until wakeup(s) resumes it.
-- (Simplified: the full source also logs "Resume socket" before driver.start.)
local function suspend(s)
    assert(not s.co)
    s.co = coroutine.running()  -- remember the waiting coroutine
    if s.pause then
        driver.start(s.id)  -- the socket was paused: resume receiving first
        skynet.wait(s.co)   -- wait to be woken
        s.pause = nil
    else
        skynet.wait(s.co)   -- wait to be woken
    end
    -- If socket.close() is waiting on this socket, wake it
    if s.closing then
        skynet.wakeup(s.closing)
    end
end

Socket消息类型处理

Skynet定义了7种socket消息类型

类型1: 数据到达 (SKYNET_SOCKET_TYPE_DATA)

lua
-- SKYNET_SOCKET_TYPE_DATA: append incoming bytes to the socket buffer and wake
-- the reader once its pending request (s.read_required) can be satisfied.
-- The non-size branch below is required: without it readline waiters are
-- never woken, and socket.limit / BUFFER_LIMIT are never enforced.
socket_message[1] = function(id, size, data)
    local s = socket_pool[id]
    if s == nil then
        skynet.error("socket: drop package from " .. id)
        driver.drop(data, size)  -- release data for an untracked socket
        return
    end

    local sz = driver.push(s.buffer, s.pool, data, size)  -- push into the buffer

    local rr = s.read_required
    local rrt = type(rr)
    if rrt == "number" then  -- waiting for a fixed number of bytes
        if sz >= rr then
            s.read_required = nil
            if sz > BUFFER_LIMIT then
                pause_socket(s, sz)  -- buffer too large: stop receiving
            end
            wakeup(s)  -- wake the reading coroutine
        end
    else
        -- Enforce the optional per-socket limit set by socket.limit.
        if s.buffer_limit and sz > s.buffer_limit then
            skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id, sz))
            driver.close(id)
            return
        end
        if rrt == "string" then  -- waiting for a line terminated by `rr`
            if driver.readline(s.buffer, nil, rr) then
                s.read_required = nil
                if sz > BUFFER_LIMIT then
                    pause_socket(s, sz)
                end
                wakeup(s)
            end
        elseif sz > BUFFER_LIMIT and not s.pause then
            -- No size/line request pending: only cap buffer growth.
            pause_socket(s, sz)
        end
    end
end

类型2: 连接建立 (SKYNET_SOCKET_TYPE_CONNECT)

lua
-- SKYNET_SOCKET_TYPE_CONNECT: mark the socket connected and wake its waiter.
-- For listening sockets, the reported address and port (`ud`) are kept.
socket_message[2] = function(id, ud, addr)
    local s = socket_pool[id]
    if s == nil or s.connected then
        return
    end
    if s.listen then
        s.addr, s.port = addr, ud
    end
    s.connected = true
    wakeup(s)
end

类型3: 连接关闭 (SKYNET_SOCKET_TYPE_CLOSE)

lua
-- SKYNET_SOCKET_TYPE_CLOSE: mark a tracked socket disconnected and wake its
-- waiter, or close an untracked id in the driver; then run the onclose callback.
socket_message[3] = function(id)
    local s = socket_pool[id]
    if s then
        s.connected = false
        wakeup(s)
    else
        driver.close(id)
    end

    local cb = socket_onclose[id]
    if cb then
        -- Unregister before calling, so a callback that raises does not leave
        -- a stale entry behind for a reused id.
        socket_onclose[id] = nil
        cb(id)
    end
end

类型4: 接受连接 (SKYNET_SOCKET_TYPE_ACCEPT)

lua
-- SKYNET_SOCKET_TYPE_ACCEPT: pass the new connection to the listener's accept
-- callback, or close it if the listening socket is no longer tracked.
socket_message[4] = function(id, newid, addr)
    local listener = socket_pool[id]
    if listener then
        listener.callback(newid, addr)
    else
        driver.close(newid)
    end
end

其他类型处理

  • 类型5: 错误处理
  • 类型6: UDP数据包
  • 类型7: 发送缓冲区警告

协议注册和消息分发

lua
-- Decode PTYPE_SOCKET messages with driver.unpack; the first decoded value
-- picks the socket_message handler, and the rest are its arguments.
skynet.register_protocol {
    name = "socket",
    id = skynet.PTYPE_SOCKET,  -- PTYPE_SOCKET = 6
    unpack = driver.unpack,
    dispatch = function (_, _, msgtype, ...)
        local handler = socket_message[msgtype]
        handler(...)
    end
}

核心API方法

连接建立相关

lua
-- Start a TCP connect to addr:port and wait until it succeeds or fails.
function socket.open(addr, port)
    local fd = driver.connect(addr, port)
    return connect(fd)
end

-- Create a listening socket and wait until the driver confirms it.
-- `host` may be "host:port" when `port` is nil (IPv6 as "[addr]:port").
-- Returns the id plus the actual address and port reported by the driver.
function socket.listen(host, port, backlog)
    if port == nil then
        local spec = host
        host, port = string.match(spec, "^%[(.-)%]:(%d+)$")
        if host == nil then
            host, port = string.match(spec, "([^:]+):(.+)$")
        end
        port = tonumber(port)
    end
    local id = driver.listen(host, port, backlog)  -- create the listening socket
    local s = {
        id = id,
        connected = false,
        listen = true,
    }
    -- Never silently replace an object this service already tracks.
    assert(socket_pool[id] == nil)
    socket_pool[id] = s
    suspend(s)  -- woken by socket_message[2] once listening
    return id, s.addr, s.port
end

-- Begin receiving on `id`, then register it and wait for the connect message.
-- For a listening socket, `func` is the accept callback: func(newid, addr).
function socket.start(id, func)
    driver.start(id)
    return connect(id, func)
end

数据读取相关

lua
-- Read from tracked socket `id`.
--  * sz == nil: return whatever is buffered, waiting for data if the buffer is
--    empty; returns false, "" if the socket is disconnected with nothing buffered.
--  * sz given: return exactly `sz` bytes, waiting until they arrive; if the
--    socket closes first, returns false plus all remaining buffered bytes.
function socket.read(id, sz)
    local s = socket_pool[id]
    assert(s)

    if sz == nil then
        local ret = driver.readall(s.buffer, s.pool)
        if ret ~= "" then return ret end

        if not s.connected then return false, ret end

        assert(not s.read_required)  -- only one pending read per socket
        s.read_required = 0  -- wake on any new data
        suspend(s)
        ret = driver.readall(s.buffer, s.pool)
        if ret ~= "" then
            return ret
        end
        return false, ret
    else
        local ret = driver.pop(s.buffer, s.pool, sz)
        if ret then return ret end

        if s.closing or not s.connected then
            return false, driver.readall(s.buffer, s.pool)
        end

        assert(not s.read_required)
        s.read_required = sz  -- wake once `sz` bytes are buffered
        suspend(s)
        ret = driver.pop(s.buffer, s.pool, sz)
        if ret then
            -- Return only the popped bytes: calling readall here as well would
            -- drain, and lose, anything buffered beyond `sz`.
            return ret
        end
        return false, driver.readall(s.buffer, s.pool)
    end
end

-- Read one line terminated by `sep` (default "\n") from tracked socket `id`,
-- waiting for more data if needed. Returns false plus the remaining buffered
-- bytes if the socket is (or becomes) disconnected.
function socket.readline(id, sep)
    sep = sep or "\n"
    local s = socket_pool[id]
    assert(s)

    local ret = driver.readline(s.buffer, s.pool, sep)
    if ret then return ret end

    if not s.connected then
        return false, driver.readall(s.buffer, s.pool)
    end

    -- Only one pending read per socket: never overwrite another reader's request.
    assert(not s.read_required)
    s.read_required = sep  -- line mode: woken once `sep` is buffered
    suspend(s)
    if s.connected then
        return driver.readline(s.buffer, s.pool, sep)
    else
        return false, driver.readall(s.buffer, s.pool)
    end
end

数据写入和连接管理

lua
socket.write = assert(driver.send)    -- send data through the driver (described in this article as asynchronous)
socket.lwrite = assert(driver.lsend)  -- NOTE(review): lsend is presumably skynet's low-priority send, not a "low-level" send; confirm in lua-socket.c

-- Close socket `id` and forget its object (no-op if untracked).
-- While still connected, wait before removing it: either for the reading
-- coroutine to finish (its suspend() wakes s.closing), or for a wakeup of our
-- own suspend (e.g. from the close message).
function socket.close(id)
    local s = socket_pool[id]
    if s == nil then return end
    
    driver.close(id)  -- close in the driver
    
    if s.connected then
        -- false (not nil): pause_socket skips it and suspend won't restart it
        s.pause = false
        if s.co then
            -- another coroutine is reading: wait for that read to finish
            assert(not s.closing)
            s.closing = coroutine.running()
            skynet.wait(s.closing)
        else
            suspend(s)  -- wait to be woken (e.g. by the close message)
        end
        s.connected = false
    end
    socket_pool[id] = nil
end

UDP相关功能

lua
-- Create a UDP socket via driver.udp(host, port); datagrams go to `callback`.
function socket.udp(callback, host, port)
    local udp_id = driver.udp(host, port)
    create_udp_object(udp_id, callback)
    return udp_id
end

socket.sendto = assert(driver.udp_send)  -- UDP send

业务调用链路分析

服务启动监听流程

lua
-- Simplified excerpt of the call in gateserver.lua
function CMD.open(source, conf)
    local address = conf.address or "0.0.0.0"
    local port = assert(conf.port)
    -- `backlog` was an undefined global here; read it from conf instead.
    -- nil presumably lets the driver use its default backlog -- confirm.
    local backlog = conf.backlog
    socket = socketdriver.listen(address, port, backlog)  -- create the listening socket
    socketdriver.start(socket)  -- start accepting connections
end

客户端连接处理流程

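text
1. 客户端连接 → 底层驱动 → socket_message[4]（SKYNET_SOCKET_TYPE_ACCEPT）
   → 调用 socket.start(listen_id, func) 时注册的 accept 回调 func(newid, addr)
   （注意：gateserver.lua 并不经过 socket.lua，它自己注册 "socket" 协议并用 netpack 分包，
    在 MSG.open 中调用 handler.connect()，业务再调用 gateserver.openclient() 开始接收数据）

2. 数据到达 → socket_message[1]
   → 数据压入缓冲区 → 检查读取需求(read_required) → 满足时 wakeup() 唤醒读取协程

3. 业务读取 → socket.read()
   → 缓冲区数据足够则立即返回 / 不足则设置 read_required 并 suspend() 等待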

数据发送流程

lua
-- 业务代码调用
socket.write(fd, data)  -- socket.write 就是 driver.send（socket.write = assert(driver.send)），数据交给底层异步发送

可以把socket.lua想象成一个高效的快递分拣中心

  • socket_pool = 快递货架(存放所有包裹)
  • buffer = 临时存放区(数据缓冲区)
  • 协程机制 = 智能调度系统(有人取件时才通知)
  • driver = 装卸工人(底层实际操作)

工作流程:
1. 收货(数据到达):

  • 快递车(网络数据)到达 → 分拣员(socket_message)处理
  • 根据标签(消息类型)放到对应货架
  • 如果有人预订(read_required),就打电话通知(wakeup)

2. 取货(数据读取):

  • 客户(业务代码)来取件 → 看货架有没有
  • 有货直接拿走 → 没货就登记需求(read_required)然后等待(suspend)

3. 发货(数据发送):

  • 客户要寄件 → 直接交给装卸工(driver.send)处理
  • 装卸工负责打包发送,不阻塞分拣中心

4. 特殊服务:

  • 暂停服务:货架太满时暂停收货(pause_socket)
  • 积压提醒:待发送数据堆积时发出警告(socket_message[7],默认由 default_warning 打印日志)
  • 连接管理:新客户登记、老客户离开的接待流程
相关推荐
低代码布道师2 小时前
学习低代码,需要什么基础?
学习·低代码
西猫雷婶3 小时前
random.shuffle()函数随机打乱数据
开发语言·pytorch·python·学习·算法·线性回归·numpy
随机惯性粒子群3 小时前
STM32控制开发学习笔记【基于STM32 HAL库】
笔记·stm32·嵌入式硬件·学习
来生硬件工程师3 小时前
CH582 GPIO
c语言·开发语言·单片机
一條狗3 小时前
学习日报 20250930|多优惠券叠加核销及场景互斥逻辑
学习·核销
fly-phantomWing4 小时前
在命令提示符页面中用pip命令行安装Python第三方库的详细步骤
开发语言·python·pip
Nan_Shu_6144 小时前
学习:uniapp全栈微信小程序vue3后台-额外/精彩报错篇
前端·学习·微信小程序·小程序·uni-app·notepad++
VBA63374 小时前
VBA数据库解决方案第二十三讲:向一个已有数据表中添加数据记录
开发语言
杜子不疼.4 小时前
【C++】玩转模板:进阶之路
java·开发语言·c++