这里默认读者了解websocket协议,若是还不了解可以看下这篇文章wesocket协议。
websocket主要有三个步骤,1通过HTTP进行握手连接,2进行双向通信,3.协商断开连接
第一步的握手连接需要HTTP,所以还需要使用到上一节讲解的HTTP模块中的部分内容HttpContext类和HttpRequest类。
建立握手连接后,就不再需要使用HTTP了。之后就是通过帧的形式就行数据传输。
那可以给数据帧或者说是数据包封装成一个类WebsocketPacket。
1.WebsocketPacket类
该类一定是有帧头的一些信息,如fin,opcode等等。
cpp
enum WSOpcodeType : uint8_t
{
WSOpcode_Continue = 0x0,
WSOpcode_Text = 0x1,
WSOpcode_Binary = 0x2,
WSOpcode_Close = 0x8,
WSOpcode_Ping = 0x9,
WSOpcode_Pong = 0xA,
};
class WebsocketPacket
{
public:
WebsocketPacket()
:fin_(1) //1表示是消息的最后一个分片,表示不分包
, rsv1_(0)
, rsv2_(0)
, rsv3_(0)
, opcode_(1) //默认是发送文本帧
, mask_(0)
, payload_length_(0)
{
memset(masking_key_, 0, sizeof(masking_key_));
}
~WebsocketPacket(){ }
void reset()
{
fin_ = 1; //默认是1
rsv1_ = 0;
rsv2_ = 0;
rsv3_ = 0;
opcode_ = 1;//默认是发送文本帧
mask_ = 0;
memset(masking_key_, 0, sizeof(masking_key_));
payload_length_ = 0;
}
void decodeFrame(Buffer* frameBuf, Buffer* output);
void encodeFrame(Buffer* output, Buffer* data)const;
public:
uint8_t fin() const { return fin_; }
uint8_t rsv1() const { return rsv1_; }
uint8_t rsv2()const { return rsv2_; }
//省略部分成员的的获取函数如rsv3()等等,这里就没有显示出来,可查看完整代码
//...................
void set_fin(uint8_t fin) { fin_ = fin; }
void set_rsv1(uint8_t rsv1) { rsv1_ = rsv1; }
void set_rsv2(uint8_t rsv2) { rsv2_ = rsv2; }
//省略部分成员的设置的函数如set_rsv3(uint8_t rsv3)等等,这里就没有显示出来
private:
uint8_t fin_;
uint8_t rsv1_;
uint8_t rsv2_;
uint8_t rsv3_;
uint8_t opcode_;
uint8_t mask_;
uint8_t masking_key_[4];
uint64_t payload_length_;
};
这里重点就是两个函数decodeFrame和encodeFrame。从名字就可以看出来,一个是解帧,即是解析客户端发送过来的帧;另一个是封装成帧,发送给客户端。
1.1decodeFrame函数
按照websocket协议的数据帧进行解析即可。
这里要注意的是,若payloadlength是多字节的话,需要进行转序。
有掩码的操作就是这样,可以不用做过多了解,但想了解多点也可以。
cpp
void WebsocketPacket::decodeFrame(Buffer* frameBuf,Buffer* output)
{
const char* msg = frameBuf->peek();
int pos = 0;
//获取fin_
fin_=((unsigned char)msg[pos] >> 7);
//获取opcode_
opcode_ = msg[pos] & 0x0f;
pos++;
//获取mask_
mask_ = (unsigned char)msg[pos] >> 7;
//获取payload_length_
payload_length_ = msg[pos] & 0x7f;
pos++;
if (payload_length_ == 126) {
uint16_t length = 0;
memcpy(&length, msg + pos, 2);
pos += 2;
payload_length_ = ntohs(length);
}
else if (payload_length_ == 127) {
uint64_t length = 0;
memcpy(&length, msg + pos, 8);
pos += 8;
payload_length_ = ntohl(length);
}
//获取masking_key_
if (mask_ == 1) {
for (int i = 0; i < 4; i++)
masking_key_[i] = msg[pos + i];
pos += 4;
}
if (mask_ != 1) {
output->append(msg + pos, payload_length_);
}
else {
for (uint64 i = 0; i < payload_length_; i++) {
output->append(msg[pos + i] ^ masking_key_[i % 4], payload_length_);
}
}
}
1.2encodeFrame
也是按照websocket协议的数据帧进行封装帧即可。
注意是若payloadlength是多字节的话,需要进行转序。
还有服务器端发送的是没有掩码的。
cpp
void WebsocketPacket::encodeFrame(Buffer* output,Buffer* data)const
{
uint8_t onebyte = 0;
onebyte |= (fin_ << 7);
onebyte |= (rsv1_ << 6);
onebyte |= (rsv2_ << 5);
onebyte |= (rsv3_ << 4);
onebyte |= (opcode_ & 0x0F);
output->append((char*)&onebyte, 1);
onebyte = 0;
//set mask flag
onebyte = onebyte | (mask_ << 7);
int length = data->readableBytes();
if (length < 126){
onebyte |= length;
output->append((char*)&onebyte, 1);
}
else if (length == 126){
onebyte |= length;
output->append((char*)&onebyte, 1);
auto len = htons(length);
output->append((char*)&len, 2);
}
else if (length == 127){
onebyte |= length;
output->append((char*)&onebyte, 1);
// also can use htonll if you have it
onebyte = (payload_length_ >> 56) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = (payload_length_ >> 48) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = (payload_length_ >> 40) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = (payload_length_ >> 32) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = (payload_length_ >> 24) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = (payload_length_ >> 16) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = (payload_length_ >> 8) & 0xFF;
output->append((char*)&onebyte, 1);
onebyte = payload_length_ & 0XFF;
output->append((char*)&onebyte, 1);
}
if (mask_ == 1) //服务器发送给客户端的,是不带mask_key的,所以这个是没有用到的
{
output->append((char*)masking_key_, 4); // save masking key
char value = 0;
for (uint64_t i = 0; i < payload_length_; ++i) {
value = *(char*)(data->peek());
data->retrieve(1);
value = value ^ masking_key_[i % 4];
output->append(&value, 1);
}
}
else {
output->append(data->peek(), data->readableBytes());
}
}
数据帧解析和封装说完了,那就到握手连接和双向通信的了。可以封装个类WebsocketContext。
2.WebsocketContext类
该类有点类似上一节的HttpContext类,解包和封包的操作已有WebsocketPacket去处理。那这个类需要处理握手连接等问题。
WebsocketContext会拥有WebsocketPacket类型的请求包requestPacket_。其中函数parseData就是调用requestPacket_的decodeFrame。
websocketStatus_表示是否已握手连接的,构造函数是默认kUnconnect的。
cpp
class WebsocketContext {
public:
enum class WebsocketSTATUS { kUnconnect, kHandsharked };
WebsocketContext();
~WebsocketContext();
void handleShared(Buffer* buf, const std::string& server_key);
void parseData(Buffer* buf, Buffer* output);
void reset() { requestPacket_.reset(); }
void setwebsocketHandshared() { websocketStatus_ = WebsocketSTATUS::kHandsharked; }
WebsocketSTATUS getWebsocketSTATUS()const { return websocketStatus_; }
uint8_t getRequestOpcode()const { return requestPacket_.opcode(); }
private:
WebsocketSTATUS websocketStatus_;
WebsocketPacket requestPacket_;
};
那么接下来看看如何握手连接的
2.1handleShared
源代码里会有base64和sha1的代码。
这里就主要是按照给定的服务器端回复的握手格式进行回复。
cpp
#include "base64.h"
#include "sha1.h"
#define MAGIC_KEY "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
void WebsocketContext::handleShared(Buffer* buf, const std::string& serverKey)
{
buf->append("HTTP/1.1 101 Switching Protocols\r\n");
buf->append("Connection: upgrade\r\n");
buf->append("Sec-WebSocket-Accept: ");
std::string server_key = serverKey;
server_key += MAGIC_KEY;
SHA1 sha;
unsigned int message_digest[5];
sha.Reset();
sha << server_key.c_str();
sha.Result(message_digest);
for (int i = 0; i < 5; i++) {
message_digest[i] = htonl(message_digest[i]);
}
server_key = base64_encode(reinterpret_cast<const unsigned char*>(message_digest), 20);
server_key += "\r\n";
buf->append(server_key);
buf->append("Upgrade: websocket\r\n\r\n");
}
3.websocketServer
接着封装一个websocketServer类方便使用。这个类和HttpServer是很相似的,流程和HttpServer也是差不多的。
cpp
class websocketServer
{
public:
using WebsocketCallback = std::function<void(const Buffer*, Buffer*, WebsocketPacket& respondPacket)>;
websocketServer(EventLoop* loop, const InetAddr& listenAddr);
void setHttpCallback(const WebsocketCallback& cb) { websocketCallback_ = cb; }
void start(int numThreads);
private:
void onConnetion(const ConnectionPtr& conn); //连接到来的回调函数
void onMessage(const ConnectionPtr& conn, Buffer* buf); //消息到来的回调函数
void handleData(const ConnectionPtr& conn, WebsocketContext* websocket, Buffer* buf);
Server server_;
WebsocketCallback websocketCallback_;
};
setHttpCallback是设置用户的业务函数。
3.1onConnetion函数
连接到来时候会执行该函数
cpp
void websocketServer::onConnetion(const ConnectionPtr& conn)
{
if (conn->connected()) {
//conn->setContext(HttpContext()); //这是之前HttpServer的
conn->setContext(WebsocketContext());
//测试使用,用来测试绑定不符合的类型
//int a = 10; conn->setContext(a);
}
}
3.2onMessage
消息到来的时候会执行该函数。
该函数就先获取该conn的getMutableContext,得到该WebsocketContext类对象。
之后就两种情况,一种是还没进行握手的,一种是已进行握手的,进行通信的。
需要握手的 ,先通过解析http请求,获取请求头中的特定字段 ,发送特殊的HTTP响应头进行握手确认。
cpp
void websocketServer::onMessage(const ConnectionPtr& conn, Buffer* buf)
{
auto context = std::any_cast<WebsocketContext>(conn->getMutableContext()); //c++117
if (!context) {
printf("context kong...\n");
LOG_ERROR << "context is bad\n";
return;
}
if (context->getWebsocketSTATUS() == WebsocketContext::WebsocketSTATUS::kUnconnect) {
HttpContext http;
if (!http.parseRequest(buf)) {
conn->send("HTTP/1.1 400 Bad Request\r\n\r\n");
conn->shutdown();
}
if (http.gotAll()) {
auto httpRequese = http.request();
if (httpRequese.getHeader("Upgrade") != "websocket" ||
httpRequese.getHeader("Connection") != "Upgrade" ||
httpRequese.getHeader("Sec-WebSocket-Version") != "13" ||
httpRequese.getHeader("Sec-WebSocket-Key") == "") {
conn->send("HTTP/1.1 400 Bad Request\r\n\r\n");
conn->shutdown();
return; //表明不是websocket连接
}
Buffer handsharedbuf;
context->handleShared(&handsharedbuf, http.request().getHeader("Sec-WebSocket-Key"));
conn->send(&handsharedbuf);
context->setwebsocketHandshared();//设置建立握手
}
}
else {
handleData(conn, context, buf);
}
}
另一种情况,可以进行通信的,调用函数handleData。
主要流程:
- 先调用websocketContext的解析帧的函数parseData。之后得到fin,opcode等信息并把传输过来的数据写入到DataBuf中去。
- 之后再根据情况进行设置opcode。
- 之后再调用用户设置的回调函数来进行用户的业务处理。
- 再进行封装帧操作,发送给客户端。
这里需要注意的是:当收到客户主动发送过来的opcode是0x8(即是关闭),需要服务器端也返回ox8给客户端。因为websocket关闭是双方协商的。之后客户端收到0x8后就会关闭连接了。
cpp
void websocketServer::handleData(const ConnectionPtr& conn, WebsocketContext* websocket, Buffer* buf)
{
Buffer DataBuf;
websocket->parseData(buf, &DataBuf);
WebsocketPacket respondPacket;
int opcode = websocket->getRequestOpcode();
switch (opcode)
{
case WSOpcodeType::WSOpcode_Continue:
respondPacket.set_opcode(WSOpcodeType::WSOpcode_Continue);
break;
case WSOpcodeType::WSOpcode_Text:
respondPacket.set_opcode(WSOpcodeType::WSOpcode_Text);
break;
case WSOpcodeType::WSOpcode_Binary:
respondPacket.set_opcode(WSOpcodeType::WSOpcode_Binary);
break;
case WSOpcodeType::WSOpcode_Close:
respondPacket.set_opcode(WSOpcodeType::WSOpcode_Close);
break;
case WSOpcodeType::WSOpcode_Ping:
respondPacket.set_opcode(WSOpcodeType::WSOpcode_Pong); //进行心跳响应
break;
case WSOpcodeType::WSOpcode_Pong: //表示这是一个心跳响应(pong),那就不用回复了
return;
default:
LOG_INFO << "WebSocketEndpoint - recv an unknown opcode.\n";
return;
}
Buffer sendbuf;
if(opcode != WSOpcodeType::WSOpcode_Close && opcode != WSOpcode_Ping && opcode != WSOpcode_Pong)
websocketCallback_(&DataBuf, &sendbuf, respondPacket);
Buffer frameBuf;
respondPacket.encodeFrame(&frameBuf, &sendbuf);
conn->send(&frameBuf);
websocket->reset();
}
4.websocket的使用例子
用户主要就是写自己的业务函数,之后调用setHttpCallback设置自己的业务函数。
cpp
//用户的业务函数
void onRequest(const Buffer* input, Buffer* output){
//进行echo回复
output->append(input->peek(),input->readableBytes());
}
int main(int argc, char* argv[])
{
int numThreads = 0;
if (argc > 1) {
Logger::setLogLevel(Logger::LogLevel::WARN);
numThreads = atoi(argv[1]);
}
EventLoop loop;
websocketServer server(&loop, InetAddr(9999));
server.setHttpCallback(onRequest); //设置自己的业务函数
server.start(numThreads);
loop.loop();
return 0;
}
websocket的服务器基本就是这样了。
5.修复问题,Connection::handleRead()中的问题
这是在测试websocket的时候发现的问题。
在有新消息到来的时刻,是会调用Connection::handleRead()函数
那么在该函数中需要添加**inputBuffer_.retrieve(inputBuffer_.readableBytes());**这句代码。
cpp
void Connection::handleRead()
{
int savedErrno = 0;
auto n = inputBuffer_.readFd(fd(), &savedErrno);
if (n > 0) {
//这个是用户设置好的函数
messageCallback_(shared_from_this(), &inputBuffer_);
//新添加的,没有这句代码的话,那readindex可能就没有变化,那读取的数据就会包含上一次的
inputBuffer_.retrieve(inputBuffer_.readableBytes());//messageCallback_中处理好读取的数据后,更新readerIndex位置
}
else if (n == 0) {
//表示客户端关闭了连接
handleClose();
}
//....省略了对错误的处理
}
不然每次inputBuffer_的readerIndex就不会改变,那么每次input中获取到的数据都会包含上一次的数据。
也可以不添加,让用户在写业务函数的时候手动添加去更新readerIndex,但这样就不方便了,用户不应该去处理这些问题的。
在server_v10代码中,加不加这句代码是没有影响的,是因为用户的业务函数使用了Buffer::retrieveAllAsString()函数,该函数是会更新buf的readerIndex的,所以才会没有问题的。
cpp
//在代码server_v10中用户的业务函数
void onMessage(const ConnectionPtr& conn, Buffer* buf) {
std::string msg(buf->retrieveAllAsString());
printf("onMessage() %ld bytes reveived:%s\n", msg.size(), msg.c_str());
conn->send(msg);
}
int main(){
//..............
}
但不是每个用户编写自己的业务函数时候都一定使用这个函数的。所以需要在这添加这句代码inputBuffer_.retrieve(inputBuffer_.readableBytes());。
可以试试不添加这句代码和添加了这句代码的websocket服务器的效果。
完整源代码:https://github.com/liwook/CPPServer/tree/main/code/server_v21