skynet-socket.lua源码分析
- 源码
- 模块初始化和核心数据结构
- 协程管理和挂起机制
- Socket消息类型处理
  - 类型1: 数据到达 (SKYNET_SOCKET_TYPE_DATA)
  - 类型2: 连接建立 (SKYNET_SOCKET_TYPE_CONNECT)
  - 类型3: 连接关闭 (SKYNET_SOCKET_TYPE_CLOSE)
  - 类型4: 接受连接 (SKYNET_SOCKET_TYPE_ACCEPT)
- 其他类型处理
- 协议注册和消息分发
- 核心API方法
- 业务调用链路分析
源码
lua
local driver = require "skynet.socketdriver"
local skynet = require "skynet"
local skynet_core = require "skynet.core"
local assert = assert
-- When more than this many bytes are buffered for a socket, input from it is
-- paused (see pause_socket).
local BUFFER_LIMIT = 128 * 1024
-- Public API table; returned at the end of the module.
local socket = {}
-- Every socket object managed by this service, keyed by socket id.
-- Its __gc metamethod closes (in the driver) every id still present.
local socket_pool = setmetatable({}, {
	__gc = function(pool)
		for fd in pairs(pool) do
			driver.close(fd)
			pool[fd] = nil
		end
	end,
})
-- Callbacks registered with socket.onclose, keyed by socket id.
local socket_onclose = {}
-- Incoming socket message handlers, indexed by message type number.
local socket_message = {}
-- Resume the coroutine parked on socket object `s`, if there is one.
-- s.co is cleared before resuming so the coroutine cannot be woken twice.
local function wakeup(s)
	local waiting = s.co
	if not waiting then
		return
	end
	s.co = nil
	skynet.wakeup(waiting)
end
-- Ask the driver to stop delivering input for socket `s`, then yield.
-- `size` (optional) is only used in the log line.
-- Does nothing when s.pause is already set: true means already paused, and
-- false is set by socket.close to mean "never resume this fd".
local function pause_socket(s, size)
	if s.pause == nil then
		local msg
		if size then
			msg = string.format("Pause socket (%d) size : %d" , s.id, size)
		else
			msg = string.format("Pause socket (%d)" , s.id)
		end
		skynet.error(msg)
		driver.pause(s.id)
		s.pause = true
		-- there are subsequent socket messages in mqueue, maybe.
		skynet.yield()
	end
end
-- Park the running coroutine on socket object `s` until wakeup(s) resumes it.
-- Only one coroutine may wait on a socket at a time (asserted).
-- If input was paused (s.pause == true), ask the driver to resume it before
-- waiting, and clear the flag once woken. socket.close sets s.pause = false,
-- which is falsy, so a closing socket is not resumed here.
local function suspend(s)
	assert(not s.co)
	s.co = coroutine.running()
	if s.pause then
		skynet.error(string.format("Resume socket (%d)", s.id))
		driver.start(s.id)
		skynet.wait(s.co)
		s.pause = nil
	else
		skynet.wait(s.co)
	end
	-- Wake the closing coroutine every time a suspend returns,
	-- because socket.close() waits for the last socket buffer operation before the buffer is cleared.
	if s.closing then
		skynet.wakeup(s.closing)
	end
end
-- read skynet_socket.h for these macro
-- SKYNET_SOCKET_TYPE_DATA = 1
-- Incoming TCP data. Append it to the socket's buffer, then check whether the
-- pending read recorded in s.read_required can now complete:
--   number -> wake once at least that many bytes are buffered
--   string -> wake once a line ending in that separator is buffered
--   other  -> (true from readall, false/nil when idle) never woken here;
--             only the size limits below apply
-- Input is paused once more than BUFFER_LIMIT bytes are buffered. A
-- per-socket s.buffer_limit (socket.limit) instead closes the socket, but it is
-- only checked when no fixed-size read is pending.
socket_message[1] = function(id, size, data)
	local s = socket_pool[id]
	if s == nil then
		skynet.error("socket: drop package from " .. id)
		driver.drop(data, size)
		return
	end

	-- sz: buffered byte count reported by the driver (compared with the
	-- requested size and the limits below)
	local sz = driver.push(s.buffer, s.pool, data, size)
	local rr = s.read_required
	local rrt = type(rr)
	if rrt == "number" then
		-- read size
		if sz >= rr then
			s.read_required = nil
			if sz > BUFFER_LIMIT then
				pause_socket(s, sz)
			end
			wakeup(s)
		end
	else
		if s.buffer_limit and sz > s.buffer_limit then
			skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id , sz))
			driver.close(id)
			return
		end
		if rrt == "string" then
			-- read line
			-- NOTE(review): called with a nil pool, presumably only to test for
			-- the separator without consuming data — confirm against the driver.
			if driver.readline(s.buffer,nil,rr) then
				s.read_required = nil
				if sz > BUFFER_LIMIT then
					pause_socket(s, sz)
				end
				wakeup(s)
			end
		elseif sz > BUFFER_LIMIT and not s.pause then
			pause_socket(s, sz)
		end
	end
end
-- SKYNET_SOCKET_TYPE_CONNECT = 2
-- The driver confirmed a connect/start/listen. For a listening socket, `addr`
-- and `ud` are stored as the address and port that socket.listen returns.
socket_message[2] = function(id, ud, addr)
	local s = socket_pool[id]
	-- unknown id, or already connected (resume may also post connect message)
	if s == nil or s.connected then
		return
	end
	if s.listen then
		s.addr, s.port = addr, ud
	end
	s.connected = true
	wakeup(s)
end
-- SKYNET_SOCKET_TYPE_CLOSE = 3
-- A socket closed. A managed socket is marked disconnected and its waiting
-- coroutine (if any) is woken; an unmanaged id is closed in the driver.
-- Any callback registered with socket.onclose then runs once.
socket_message[3] = function(id)
	local s = socket_pool[id]
	if not s then
		driver.close(id)
	else
		s.connected = false
		wakeup(s)
	end
	local on_close = socket_onclose[id]
	if on_close then
		on_close(id)
		socket_onclose[id] = nil
	end
end
-- SKYNET_SOCKET_TYPE_ACCEPT = 4
-- Listening socket `id` accepted connection `newid` from `addr`. The new id is
-- handed to the listener's callback; if the listener is gone it is closed.
socket_message[4] = function(id, newid, addr)
	local listener = socket_pool[id]
	if listener then
		listener.callback(newid, addr)
	else
		driver.close(newid)
	end
end
-- SKYNET_SOCKET_TYPE_ERROR = 5
-- Socket error `err` on `id`.
--  * unknown id: shut it down and log
--  * socket with a callback (accept handler, or a UDP object): log only
--  * otherwise: log (if it was connected) or record `err` as the connect()
--    failure reason, mark disconnected, shut down, wake the waiter
socket_message[5] = function(id, _, err)
	local s = socket_pool[id]
	if s == nil then
		driver.shutdown(id)
		skynet.error("socket: error on unknown", id, err)
	elseif s.callback then
		skynet.error("socket: accept error:", err)
	else
		if s.connected then
			skynet.error("socket: error on", id, err)
		elseif s.connecting then
			s.connecting = err
		end
		s.connected = false
		driver.shutdown(id)
		wakeup(s)
	end
end
-- SKYNET_SOCKET_TYPE_UDP = 6
-- UDP datagram. Copy it into a Lua string, release the driver's data, and pass
-- (payload, address) to the socket's callback. Datagrams for unknown sockets,
-- or sockets without a callback, are dropped.
socket_message[6] = function(id, size, data, address)
	local s = socket_pool[id]
	local handler = s and s.callback
	if handler == nil then
		skynet.error("socket: drop udp package from " .. id)
		driver.drop(data, size)
		return
	end
	local payload = skynet.tostring(data, size)
	skynet_core.trash(data, size)
	handler(payload, address)
end
-- Default send-backlog warning: log the pending size (reported in K, per the
-- message format) for a socket still present in the pool.
local function default_warning(id, size)
	if socket_pool[id] then
		skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d)", size, id))
	end
end

-- SKYNET_SOCKET_TYPE_WARNING
-- Dispatch to the handler installed with socket.warning, else the default.
socket_message[7] = function(id, size)
	local s = socket_pool[id]
	if not s then
		return
	end
	local handler = s.on_warning or default_warning
	handler(id, size)
end
-- Route PTYPE_SOCKET messages to this module: driver.unpack decodes each
-- message, and the dispatcher calls socket_message[t] with the values that
-- follow the type `t`.
-- NOTE(review): an unknown `t` would call nil and raise; presumably the driver
-- only produces types 1..7.
skynet.register_protocol {
	name = "socket",
	id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
	unpack = driver.unpack,
	dispatch = function (_, _, t, ...)
		socket_message[t](...)
	end
}
-- Register a TCP socket object for `id` and block until the driver reports
-- the outcome (CONNECT wakes it on success; ERROR/CLOSE on failure).
-- @param id   socket id obtained from the driver
-- @param func optional callback stored as s.callback (e.g. the accept handler
--             of a listener); when given, no read buffer is created
-- @return id on success; otherwise nil plus the failure reason, which is the
--         error string recorded by the ERROR handler, or `true` if none was
--         recorded
local function connect(id, func)
	local newbuffer
	if func == nil then
		newbuffer = driver.buffer()
	end
	local s = {
		id = id,
		buffer = newbuffer,
		pool = newbuffer and {},
		connected = false,
		connecting = true,
		read_required = false,
		co = false,
		callback = func,
		protocol = "TCP",
	}
	assert(not socket_onclose[id], "socket has onclose callback")
	local s2 = socket_pool[id]
	-- only a listener object (from socket.listen) may be replaced in place,
	-- e.g. by socket.start on the listening id
	if s2 and not s2.listen then
		error("socket is not closed")
	end
	socket_pool[id] = s
	suspend(s)
	local err = s.connecting
	s.connecting = nil
	if s.connected then
		return id
	else
		socket_pool[id] = nil
		return nil, err
	end
end
-- Start a TCP connection to addr:port and wait for the result.
-- Returns the socket id, or nil plus the failure reason (see connect).
function socket.open(addr, port)
	local fd = driver.connect(addr, port)
	return connect(fd)
end

-- Adopt an existing OS file descriptor as a skynet socket and wait for it.
function socket.bind(os_fd)
	local fd = driver.bind(os_fd)
	return connect(fd)
end

-- Wrap standard input (OS fd 0) as a skynet socket.
function socket.stdin()
	return socket.bind(0)
end
-- Start receiving on `id` (e.g. an accepted connection, a listener, or an id
-- forwarded by socket.abandon) and wait for the driver's confirmation.
-- When `func` is given it becomes the socket's callback (for a listener: the
-- accept handler, called as func(newid, addr)).
function socket.start(id, func)
	driver.start(id)
	return connect(id, func)
end

-- Pause input on `id`; the next blocking read (via suspend) resumes it.
function socket.pause(id)
	local s = socket_pool[id]
	if s ~= nil then
		pause_socket(s)
	end
end
-- Ask the driver to shut down a managed socket. The pool entry is kept:
-- the framework would send SKYNET_SOCKET_TYPE_CLOSE, need close(id) later
function socket.shutdown(id)
	if socket_pool[id] then
		driver.shutdown(id)
	end
end

-- Close an id that this module does NOT manage; managed sockets must use
-- socket.close.
function socket.close_fd(id)
	assert(socket_pool[id] == nil, "Use socket.close instead")
	driver.close(id)
end
-- Close socket `id` and remove it from the pool.
-- For a connected socket this blocks the caller first:
--  * if another coroutine is reading (s.co set), wait until that reader
--    returns from suspend(), which wakes s.closing, so it can still consume
--    the buffer;
--  * otherwise park on the socket until woken (e.g. by the CLOSE message).
function socket.close(id)
	local s = socket_pool[id]
	if s == nil then
		return
	end
	driver.close(id)
	if s.connected then
		s.pause = false -- Do not resume this fd if it paused.
		if s.co then
			-- Another coroutine is reading this socket, so don't shut down (clear the buffer) immediately;
			-- wait for the reading coroutine to read the buffer.
			assert(not s.closing)
			s.closing = coroutine.running()
			skynet.wait(s.closing)
		else
			suspend(s)
		end
		s.connected = false
	end
	socket_pool[id] = nil
end
-- Read from TCP socket `id`, blocking the calling coroutine if needed.
-- sz == nil: return whatever is buffered; if the buffer is empty, wait for the
--            next data. Returns false, "" if nothing arrives before the
--            socket disconnects.
-- sz given:  return exactly sz bytes. If the socket is closing/closed before
--            that many bytes are buffered, return false plus whatever is left.
-- Only one read may be pending per socket (asserted via s.read_required).
function socket.read(id, sz)
	local s = socket_pool[id]
	assert(s)
	if sz == nil then
		-- read some bytes
		local ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then
			return ret
		end

		if not s.connected then
			return false, ret
		end
		assert(not s.read_required)
		-- 0 bytes required: the DATA handler wakes us on the next packet
		s.read_required = 0
		suspend(s)
		ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then
			return ret
		else
			return false, ret
		end
	end

	local ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		return ret
	end
	if s.closing or not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end

	assert(not s.read_required)
	-- the DATA handler wakes us once sz bytes are buffered
	s.read_required = sz
	suspend(s)
	ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		return ret
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end
-- Read everything until the socket disconnects.
-- Already disconnected: return the buffered data, or false when it is empty.
-- Otherwise wait with read_required = true. The DATA handler never satisfies
-- that, so only a CLOSE/ERROR message or socket.abandon ends the wait. Then
-- return all buffered data.
function socket.readall(id)
	local s = socket_pool[id]
	assert(s)
	if not s.connected then
		local r = driver.readall(s.buffer, s.pool)
		return r ~= "" and r
	end
	assert(not s.read_required)
	s.read_required = true
	suspend(s)
	assert(s.connected == false)
	return driver.readall(s.buffer, s.pool)
end
-- Read one line terminated by `sep` (default "\n") via driver.readline,
-- waiting for more data if no complete line is buffered.
-- If the socket is or becomes disconnected first, returns false plus the
-- remaining buffered data.
function socket.readline(id, sep)
	sep = sep or "\n"
	local s = socket_pool[id]
	assert(s)
	local ret = driver.readline(s.buffer, s.pool, sep)
	if ret then
		return ret
	end
	if not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end
	assert(not s.read_required)
	-- a string tells the DATA handler to wake us once `sep` is buffered
	s.read_required = sep
	suspend(s)
	if s.connected then
		return driver.readline(s.buffer, s.pool, sep)
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end
-- Wait until more data arrives on `id` (or it disconnects) without consuming
-- anything. Returns false at once for an unknown or disconnected socket,
-- otherwise the socket's `connected` flag after waking.
function socket.block(id)
	local s = socket_pool[id]
	if s and s.connected then
		assert(not s.read_required)
		s.read_required = 0
		suspend(s)
		return s.connected
	end
	return false
end
-- Sending is delegated directly to the driver; this module does no buffering.
socket.write = assert(driver.send)
-- low-priority variant of socket.write
socket.lwrite = assert(driver.lsend)
socket.header = assert(driver.header)

-- True when `id` is not (or no longer) managed by this module.
function socket.invalid(id)
	return socket_pool[id] == nil
end

-- For a managed socket: true when it is neither connected nor connecting.
-- Returns nothing for ids that are not in the pool.
function socket.disconnected(id)
	local s = socket_pool[id]
	if s == nil then
		return
	end
	return not (s.connected or s.connecting)
end
-- Split a combined "host:port" string at its LAST colon, so IPv6 literals such
-- as "::1:8888" keep their whole host part. The bracketed "[::1]:8888" form is
-- also accepted, with the brackets stripped. Returns host and the numeric
-- port, or nil values when the text does not have that shape.
local function split_listen_address(addr)
	local host, port = string.match(addr, "^%[(.+)%]:(%d+)$")
	if host == nil then
		host, port = string.match(addr, "^(.+):([^:]+)$")
	end
	return host, tonumber(port)
end

-- Create a TCP listening socket and wait until the driver confirms it.
-- Call as listen(host, port [, backlog]) or listen("host:port", nil [, backlog]).
-- Returns the socket id plus the address and port reported by the driver's
-- CONNECT message for this listener.
function socket.listen(host, port, backlog)
	if port == nil then
		local addr = host
		host, port = split_listen_address(addr)
		assert(host and port, "socket.listen: invalid address " .. tostring(addr))
	end
	local id = driver.listen(host, port, backlog)
	local s = {
		id = id,
		connected = false,
		listen = true,
	}
	assert(socket_pool[id] == nil)
	socket_pool[id] = s
	suspend(s)
	return id, s.addr, s.port
end
-- Stop managing `id` in this service so it can be forwarded to another
-- service, which must call socket.start(id) later.
-- Any coroutine waiting on it is woken (seeing connected == false), and its
-- pool entry and onclose callback are dropped; the id itself is NOT closed.
function socket.abandon(id)
	local s = socket_pool[id]
	if not s then
		return
	end
	s.connected = false
	wakeup(s)
	socket_onclose[id] = nil
	socket_pool[id] = nil
end

-- Set a per-socket cap on buffered bytes. Exceeding it closes the socket; the
-- DATA handler checks it only when no fixed-size read is pending.
function socket.limit(id, limit)
	local s = assert(socket_pool[id])
	s.buffer_limit = limit
end
---------------------- UDP

-- Register a UDP socket object for `id`, which must not already be in the pool.
-- UDP objects are always marked connected, have no read buffer, and deliver
-- each datagram to `cb`.
local function create_udp_object(id, cb)
	assert(not socket_pool[id], "socket is not closed")
	local obj = {
		id = id,
		connected = true,
		protocol = "UDP",
		callback = cb,
	}
	socket_pool[id] = obj
end
-- Create a UDP socket (host/port passed to the driver) whose datagrams go to
-- `callback`. Returns the socket id.
function socket.udp(callback, host, port)
	local fd = driver.udp(host, port)
	create_udp_object(fd, callback)
	return fd
end

-- Set the peer address of UDP socket `id`. If this service does not know the
-- id yet it is registered with `callback`; otherwise the existing object must
-- be UDP, and its callback is replaced only when a new one is given.
function socket.udp_connect(id, addr, port, callback)
	local obj = socket_pool[id]
	if not obj then
		create_udp_object(id, callback)
	else
		assert(obj.protocol == "UDP")
		if callback then
			obj.callback = callback
		end
	end
	driver.udp_connect(id, addr, port)
end

-- Create a UDP socket via driver.udp_listen and register it with `callback`.
function socket.udp_listen(addr, port, callback)
	local fd = driver.udp_listen(addr, port)
	create_udp_object(fd, callback)
	return fd
end

-- Create a UDP socket via driver.udp_dial and register it with `callback`.
function socket.udp_dial(addr, port, callback)
	local fd = driver.udp_dial(addr, port)
	create_udp_object(fd, callback)
	return fd
end
socket.sendto = assert(driver.udp_send)
socket.udp_address = assert(driver.udp_address)
socket.netstat = assert(driver.info)
socket.resolve = assert(driver.resolve)

-- Install a send-backlog warning handler for managed socket `id`;
-- it is called as callback(id, size).
function socket.warning(id, callback)
	assert(socket_pool[id]).on_warning = callback
end

-- Register a callback run once, as callback(id), when the CLOSE message for
-- `id` arrives.
function socket.onclose(id, callback)
	socket_onclose[id] = callback
end

return socket
模块初始化和核心数据结构
引入依赖和常量
lua
local driver = require "skynet.socketdriver" -- low-level C driver
local skynet = require "skynet"
local skynet_core = require "skynet.core"
local BUFFER_LIMIT = 128 * 1024 -- 128 KB: above this many buffered bytes, input is paused
核心数据结构
lua
local socket = {} -- public API table
local socket_pool = setmetatable({}, { -- every socket object, keyed by socket id
	__gc = function(p) -- when the pool is collected, close every remaining socket
		for id,v in pairs(p) do
			driver.close(id)
			p[id] = nil
		end
	end
})
local socket_onclose = {} -- per-socket close callbacks
local socket_message = {} -- handlers indexed by socket message type
协程管理和挂起机制
协程控制函数
lua
local function wakeup(s)
	local co = s.co
	if co then
		s.co = nil
		skynet.wakeup(co) -- resume the coroutine parked on this socket
	end
end
-- Excerpt: the full version also logs the pause; `size` is only used for that log.
local function pause_socket(s, size)
	if s.pause ~= nil then return end
	driver.pause(s.id) -- stop the driver from delivering more input
	s.pause = true
	skynet.yield() -- yield so socket messages already queued can be handled
end
local function suspend(s)
	assert(not s.co)
	s.co = coroutine.running() -- remember the waiting coroutine
	if s.pause then
		driver.start(s.id) -- resume input before waiting
		skynet.wait(s.co) -- wait to be woken
		s.pause = nil
	else
		skynet.wait(s.co) -- just wait
	end
	-- if socket.close() is waiting on this socket, wake it
	if s.closing then
		skynet.wakeup(s.closing)
	end
end
Socket消息类型处理
Skynet定义了7种socket消息类型
类型1: 数据到达 (SKYNET_SOCKET_TYPE_DATA)
lua
socket_message[1] = function(id, size, data)
	local s = socket_pool[id]
	if s == nil then
		skynet.error("socket: drop package from " .. id)
		driver.drop(data, size) -- unknown socket: discard the data
		return
	end
	local sz = driver.push(s.buffer, s.pool, data, size) -- append to the buffer; sz is the buffered size
	-- wake the reader if its pending read can now complete
	local rr = s.read_required
	if type(rr) == "number" then -- waiting for a byte count
		if sz >= rr then
			s.read_required = nil
			if sz > BUFFER_LIMIT then
				pause_socket(s, sz) -- buffer too large: pause input
			end
			wakeup(s) -- wake the reading coroutine
		end
	else
		-- other read modes (readline, buffer_limit, pausing) omitted in this excerpt
	end
end
类型2: 连接建立 (SKYNET_SOCKET_TYPE_CONNECT)
lua
socket_message[2] = function(id, ud, addr)
	local s = socket_pool[id]
	if s == nil then return end
	if not s.connected then
		if s.listen then -- listening socket: record the reported address and port
			s.addr = addr
			s.port = ud
		end
		s.connected = true -- mark as connected
		wakeup(s) -- wake the coroutine waiting for the connection
	end
end
类型3: 连接关闭 (SKYNET_SOCKET_TYPE_CLOSE)
lua
socket_message[3] = function(id)
	local s = socket_pool[id]
	if s then
		s.connected = false
		wakeup(s) -- wake the (single) coroutine waiting on this socket, if any
	else
		driver.close(id) -- not managed here: close it directly
	end
	-- run the onclose callback once
	local cb = socket_onclose[id]
	if cb then
		cb(id)
		socket_onclose[id] = nil
	end
end
类型4: 接受连接 (SKYNET_SOCKET_TYPE_ACCEPT)
lua
socket_message[4] = function(id, newid, addr)
	local s = socket_pool[id]
	if s == nil then
		driver.close(newid) -- listener is gone: reject the new connection
		return
	end
	s.callback(newid, addr) -- hand the new connection to the accept callback
end
其他类型处理
- 类型5: 错误处理
- 类型6: UDP数据包
- 类型7: 发送缓冲区警告
协议注册和消息分发
lua
skynet.register_protocol {
	name = "socket",
	id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
	unpack = driver.unpack, -- decode with the driver
	dispatch = function (_, _, t, ...)
		socket_message[t](...) -- dispatch on the message type
	end
}
核心API方法
连接建立相关
lua
function socket.open(addr, port)
	local id = driver.connect(addr,port) -- start the connect in the driver
	return connect(id) -- wait for it to complete
end
-- Excerpt: the full version also accepts "host:port" when port is nil and
-- asserts that the id is not already in the pool.
function socket.listen(host, port, backlog)
	local id = driver.listen(host, port, backlog) -- create the listening socket
	local s = {
		id = id,
		connected = false,
		listen = true,
	}
	socket_pool[id] = s
	suspend(s) -- wait until the listen is confirmed
	return id, s.addr, s.port -- address and port reported for the listener
end
function socket.start(id, func)
	driver.start(id) -- start receiving
	return connect(id, func) -- for a listening socket, func is the accept callback
end
数据读取相关
lua
-- Excerpt of socket.read.
-- sz == nil: return whatever is buffered, waiting for the next data if the
--            buffer is empty.
-- sz given:  return exactly sz bytes.
-- On failure (socket closed first) returns false plus the leftover data.
function socket.read(id, sz)
	local s = socket_pool[id]
	assert(s)
	if sz == nil then
		-- read whatever is available
		local ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then return ret end
		if not s.connected then return false, ret end
		assert(not s.read_required) -- only one pending read per socket
		s.read_required = 0 -- wake on any new data
		suspend(s) -- wait for data
		ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then
			return ret
		end
		return false, ret
	end
	-- read exactly sz bytes
	local ret = driver.pop(s.buffer, s.pool, sz)
	if ret then return ret end
	if s.closing or not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end
	assert(not s.read_required)
	s.read_required = sz -- wake once sz bytes are buffered
	suspend(s) -- wait for enough data
	ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		-- success: do not call readall here, it would drain the rest of the buffer
		return ret
	end
	return false, driver.readall(s.buffer, s.pool)
end
-- Excerpt of socket.readline: read one line ending in `sep` (default "\n"),
-- waiting if no complete line is buffered. On disconnect, returns false plus
-- the remaining buffered data. (The full version also asserts that no other
-- read is pending.)
function socket.readline(id, sep)
	sep = sep or "\n"
	local s = socket_pool[id]
	assert(s)
	local ret = driver.readline(s.buffer, s.pool, sep)
	if ret then return ret end
	if not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end
	s.read_required = sep -- line mode: wake once sep is buffered
	suspend(s)
	if s.connected then
		return driver.readline(s.buffer, s.pool, sep)
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end
数据写入和连接管理
lua
socket.write = assert(driver.send) -- send via the driver (the caller does not wait)
socket.lwrite = assert(driver.lsend) -- low-priority send
function socket.close(id)
	local s = socket_pool[id]
	if s == nil then return end
	driver.close(id) -- close in the driver
	if s.connected then
		s.pause = false -- do not resume input if it was paused
		if s.co then
			-- another coroutine is reading: wait until it has consumed the buffer
			assert(not s.closing)
			s.closing = coroutine.running()
			skynet.wait(s.closing)
		else
			suspend(s) -- wait until woken (e.g. by the CLOSE message)
		end
		s.connected = false
	end
	socket_pool[id] = nil
end
UDP相关功能
lua
function socket.udp(callback, host, port)
	local id = driver.udp(host, port)
	create_udp_object(id, callback) -- register the UDP socket object
	return id
end
socket.sendto = assert(driver.udp_send) -- UDP send
业务调用链路分析
服务启动监听流程
lua
-- Call site in gateserver.lua (simplified excerpt). gateserver talks to
-- socketdriver directly; `socket` here is presumably gateserver's own
-- listen-fd variable declared elsewhere in that file, not the socket.lua API.
function CMD.open(source, conf)
	local address = conf.address or "0.0.0.0"
	local port = assert(conf.port)
	-- optional listen backlog; nil is passed through when not configured
	local backlog = conf.backlog
	socket = socketdriver.listen(address, port, backlog) -- create the listening socket
	socketdriver.start(socket) -- start accepting connections
end
客户端连接处理流程
lua
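1. 客户端连接 → 底层驱动 → socket_message[4]
→ 调用监听socket的accept回调 s.callback(newid, addr)(即 socket.start(listen_id, func) 传入的 func)
(注:上面的 gateserver 直接使用 socketdriver 并注册了自己的 socket 协议处理,新连接经其内部消息表 → handler.connect() → gateserver.openclient(),不经过本文件的 socket_message)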
2. 数据到达 → socket_message[1]
→ 数据压入缓冲区 → 检查读取需求 → wakeup()唤醒读取协程
3. 业务读取 → socket.read()
→ 缓冲区有数据立即返回 / 无数据则suspend()等待
数据发送流程
lua
-- 业务代码调用
socket.write(fd, data)
→ driver.send(fd, data) -- 异步发送到底层
可以把socket.lua想象成一个高效的快递分拣中心:
- socket_pool = 快递货架(每个socket对象占一个格位)
- buffer = 临时存放区(数据缓冲区)
- 协程机制 = 智能调度系统(有人取件时才通知)
- driver = 装卸工人(底层实际操作)
工作流程:
1. 收货(数据到达):
- 快递车(网络数据)到达 → 分拣员(socket_message)处理
- 根据标签(消息类型)放到对应货架
- 如果有人预订(read_required),就打电话通知(wakeup)
2. 取货(数据读取):
- 客户(业务代码)来取件 → 看货架有没有
- 有货直接拿走 → 没货就登记需求(read_required)然后等待(suspend)
3. 发货(数据发送):
- 客户要寄件 → 直接交给装卸工(driver.send)处理
- 装卸工负责打包发送,不阻塞分拣中心
4. 特殊服务:
- 暂停服务:货架太满时暂停收货(pause_socket)
- 积压提醒:待发送数据堆积过多时发出警告(socket_message[7])
- 连接管理:新客户登记、老客户离开的接待流程