手写一个C++ TCP服务器实现自定义协议(顺便解决粘包问题)

在之前的博客中,我们了解了关于UDP和TCP的网络编程,直观的感受了一下网络套接字是如何使用的,并且成功的完成了客户端与服务端的网络通信,但是其中还有一个小细节我们可能会忽略,就是UDP是基于数据报进行传输的,一下子就将所有我们要发送的信息传送给对方,但是我们的TCP可是基于字节流进行传输的,我们如何保证读取上来的数据,是一个完整的报文呢?

我们在进行TCP网络通信的时候,通过调用connec函数调用,使客户端可以和服务端保持链接之后,客户端将自己想要发送的数据通过write系统调用写进对应的socket函数调用给我们返回的文件描述符所对应的文件中。

现在有一个问题就是我们向文件中写入的时候,直接将其放入即可,但是想要往出拿的时候就有点困难了,想要往出拿的人如果不知道放的人是如何放的,就会造成一系列的错误,这就好比放数据时先放了一个整形,又放了一个浮点数,还放了一个字符串,然而拿的人按照字符串,整形,浮点数这样的方式进行获取,这就会导致数据不一致的现象,所以一旦我们要发送一些带有结构化的数据时,就必须再次制定------协议,这样才能满足我们想要返送一些结构化数据的需求。

TCP是传输控制协议,它主要负责的内容有:

  • 什么时候发送数据
  • 一次发送多少数据
  • 发送过程中出错了该怎么办

我们平时使用的read和write这些系统调用接口是从用户空间拷贝到内核空间,也就是我们在应用层将需要发送的数据拷贝到TCP的发送缓冲区中,这就结束了,至于到底什么时候发送这些数据给对方,全权由TCP进行控制,这样进行数据的传输。

但是正因如此,这就会倒是在接收端就会在接收的时候十分的困难,比如发送方发送了一个长报文,但是接收方只接收一点点数据就给用户返回了,这就会倒是接收方接收的数据时,会有不确定的情况。

这就好比假如现在和女朋友吵架了,这个时候你通过微信进行道歉,原本你想说的是我不后悔我爱你,结果由于接收方接收数据的问题,变成了我不后悔,这就会导致你女朋友意味你不后悔和她吵架,结果你就无辜的恢复了单身汉。

所以同理为了避免TCP在接收数据时的差异,我们就必须做好应用层的协议,这样才能保证发送的数据全部接收到并返回给用户,保证了接收方可以完整的接收到全部的数据。

那么应用层协议应该如何定制呢?我们通过一个列子进行理解。

现在假如我们在一个聊天群里进行聊天,大家彼此之间畅谈甚欢,现在假如我发送了一句哈哈哈,大家在群里看到的肯定不止只有哈哈哈这三个字,还会有我们的昵称,我们的头像,我什么时候发送的等等信息,所以看似我只是发送了一个哈哈哈三个字,但其实还有很多的附加数据同时进行了发送。

而我们的数据是通过字节流的方式进行传输的,所以我们必须将这些数据(昵称-头像-时间-信息)都转换为一个字符串,然后一起传送给对方,当对方接收到这个数据时,再将这一个字符串信息的内容进行分别拆解,最后显示到我们的显示器上。

所以这种方式就是序列化和反序列化。

所以现在我们实现一个简单的网络版本的计算器进行理解序列化和反序列化

TCP服务器整体结构

Client

│ request

TcpServer

│ callback

CalculatorServer

│ result

response

主要模块:

模块 作用
Socket 封装 socket API
TcpServer TCP服务器框架
Protocol 协议封装
CalculatorServer 计算逻辑

自定义应用层协议设计

由于 TCP 是 字节流协议 ,一次 read() 可能:

  • 读到半个数据

  • 读到多个数据

就比如:

客户端发送两个请求

10 + 20

5 * 6

服务器可能读到:

10 + 20\n5 * 6\n

或者:

10 +

正是因为如此,所以我们要自定义协议,接下来就是我们网络版本的计算器的协议设计:

协议格式

协议的格式设计如下:

len\n

content\n

举个例子就是:

6

10 + 5

编码表示就是:

6\n10 + 5\n

总之就是如下的格式:

|len| \n |content| \n

协议封装实现

Encode(封装报文)

复制代码
std::string Encode(std::string &content)
{
    std::string s;
    size_t len = content.size();

    s += std::to_string(len);
    s += "\n";
    s += content;
    s += "\n";

    return s;
}

例如:

10 + 20

就会变为:

7\n10 + 20\n

Decode(解析报文)

复制代码
bool Decode(std::string &s, std::string *content)
{
    size_t left_pos = s.find("\n");

    if (left_pos == std::string::npos)
        return false;

    std::string content_len = s.substr(0, left_pos);
    int len = std::stoi(content_len);

    if (s.size() < content_len.size() + len + 2)
        return false;

    *content = s.substr(left_pos + 1, len);

    s.erase(0, content_len.size() + len + 2);

    return true;
}

Decode 做了三件事:

  1. 判断是否有完整头部(长度)

  2. 判断数据是否完整

  3. 解析 + 从缓冲区移除

请求与响应设计

请求 request

客户端发送:

10 + 5

结构:

复制代码
class request
{
public:
    int x_;
    int y_;
    char op_;
};

序列化:

"10 + 5"

反序列化:

string -> request

响应 response

服务器返回:

"15 0"

结构:

复制代码
class response
{
public:
    int result_;
    int code_;
};

code的含义:

code 含义
0 成功
1 除0
2 取模0
3 非法操作

核心难点:解决 TCP 粘包问题

❓ 什么是粘包?

TCP 是面向字节流的协议:

👉 发送:

复制代码
请求1 + 请求2 + 请求3

👉 接收:

复制代码
可能变成:
请求1请求2 | 请求3

解码函数 Decode(重点)

复制代码
bool Decode(std::string &s, std::string *content)
{
    size_t pos = s.find("\n");
    if (pos == std::string::npos)
        return false;

    int len = std::stoi(s.substr(0, pos));

    if (s.size() < pos + 1 + len + 1)
        return false;

    *content = s.substr(pos + 1, len);
    s.erase(0, pos + 1 + len + 1);

    return true;
}

服务器的处理方法Calculator(重点)

复制代码
    std::string Calculator(std::string& s)
    {
        std::string content;
        if (Decode(s, &content) == false)
        {
            return "";
        }
        request req;
        bool r = req.Deserialization(content);

        if (!r)
        {
            return "";
        }
        response res = CalculatorHandler(req);

        std::string ret = res.serialization();
        ret = Encode(ret);
        return ret;
    }

服务器核心代码:(重点)

复制代码
while (true)
{
    int sockfd = listenfd_.Accept(&client_port, &client_ip);

    if (fork() == 0)
    {
        while (1)
        {
            char buffer[1280];

            ssize_t s = read(sockfd, buffer, sizeof buffer - 1);

            if (s > 0)
            {
                inbuffer_stream += buffer;

                while (true)
                {
                    std::string info = callback_(inbuffer_stream);

                    if (info.empty())
                        break;

                    write(sockfd, info.c_str(), info.size());
                }
            }
        }
    }
}

可以看到我们的服务器在收到一个报文之后,首先会调用服务器的处理方法Calculator,在处理方法中会进行解码,如果收到的报文不能够分解为类似(6\n10 + 5\n)这样的格式,我们就会返回一个空字符串,而一旦返回的是一个空字符串,我们的服务器就知道这个报文不完整,就会继续接收新的报文,直到接收到一个完整的报文且可以通过Decode解码成功之后,我们的程序才会继续进行,这样就保证了我们每次服务器处理的肯定是一个正确格式的报文。

完整代码

自定义协议

复制代码
#include <string>

#define blank_sep " "
#define protocol_sep "\n"

class request
{
public:
    request(int x, int y, char op)
        : x_(x), y_(y), op_(op)
    {
    }

    request()
    {
    }

    ~request()
    {
    }

    std::string serialization()
    {
        std::string str;
        str += std::to_string(x_);
        str += blank_sep;
        str += op_;
        str += blank_sep;
        str += std::to_string(y_);
        return str;
    }

    bool Deserialization(std::string &in)
    {
        size_t leftpos = in.find(blank_sep);
        if (leftpos == std::string::npos)
        {
            return false;
        }
        std::string str_x = in.substr(0, leftpos);
        x_ = std::stoi(str_x);

        op_ = in[leftpos + 1];

        size_t rightpos = in.rfind(blank_sep);
        if (rightpos == std::string::npos)
        {
            return false;
        }
        std::string str_y = in.substr(rightpos + 1);
        y_ = std::stoi(str_y);
        return true;
    }
    void DebugPrint()
    {
        std::cout << "新请求构建完成:  " << x_ << op_ << y_ << "=?" << std::endl;
    }

public:
    int x_;
    int y_;
    char op_;
};

class response
{
public:
    response()
    {
    }
    response(int result, int code)
        : result_(result), code_(code)
    {
    }

    ~response()
    {
    }

    std::string serialization()
    {
        std::string str;
        str += std::to_string(result_);
        str += blank_sep;
        str += std::to_string(code_);
        return str;
    }

    bool Deserialization(std::string &in)
    {
        size_t pos = in.find(blank_sep);

        if (pos == std::string::npos)
        {
            return false;
        }

        std::string str_result = in.substr(0, pos);
        result_ = std::stoi(str_result);

        std::string str_code = in.substr(pos + 1);
        code_ = std::stoi(str_code);
        return true;
    }
    void DebugPrint()
    {
        std::cout << "结果响应完成, result: " << result_ << ", code: " << code_ << std::endl;
    }

public:
    int result_;
    int code_;
};

std::string Encode(std::string &content)
{
    std::string s;
    size_t len = content.size();
    s += std::to_string(len);
    s += protocol_sep;
    s += content;
    s += protocol_sep;
    return s;
}

bool Decode(std::string &s, std::string *content)
{
    size_t left_pos = s.find(protocol_sep);
    if (left_pos == std::string::npos)
    {
        return false;
    }
    std::string content_len = s.substr(0, left_pos);
    int len = std::stoi(content_len);
    if (s.size() < content_len.size() + len + 2)
    {
        return false;
    }
    *content = s.substr(left_pos + 1, len);
    s.erase(0, content_len.size() + len + 2);
    return true;
}

服务端

复制代码
#include <iostream>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <string>

enum error
{
    SocketErr = 2,
    BindErr,
    ListenErr,
    ConnectErr,
};

class Socket
{
public:
    Socket()
    {
        sockfd_ = socket(AF_INET, SOCK_STREAM, 0);
        if (sockfd_ < 0)
        {
            std::cout << "socket fail" << std::endl;
            exit(SocketErr);
        }
    }

    void Bind(uint16_t &port, std::string &ip)
    {
        struct sockaddr_in server;
        server.sin_family = AF_INET;
        server.sin_port = htons(port);
        inet_pton(AF_INET, ip.c_str(), &server.sin_addr);
        if (bind(sockfd_, (struct sockaddr *)&server, sizeof(server)) < 0)
        {
            std::cout << "server bind fail!" << std::endl;
            exit(BindErr);
        }
        std::cout << "server bind successful" << std::endl;
    }

    void Listen()
    {
        if (listen(sockfd_, 10) < 0)
        {
            std::cout << "server listen fail!" << std::endl;
            exit(ListenErr);
        }
        std::cout << "server listen successful" << std::endl;
    }

    int Accept(uint16_t *client_port, std::string *client_ip)
    {
        struct sockaddr_in client;
        socklen_t len = sizeof(client);
        int sockfd = accept(sockfd_, (struct sockaddr *)&client, &len);
        if (sockfd < 0)
        {
            std::cout << "accept fail!" << std::endl;
            return -1;
        }
        std::cout << "accept successful" << std::endl;
        *client_port = ntohs(client.sin_port);
        char ip[64];
        inet_ntop(AF_INET, &client.sin_addr, ip, sizeof ip);
        *client_ip = ip;
        return sockfd;
    }

    void Connect(uint16_t &server_port, std::string &server_ip)
    {
        struct sockaddr_in server;
        server.sin_family = AF_INET;
        server.sin_port = htons(server_port);
        inet_pton(AF_INET, server_ip.c_str(), &server.sin_addr);
        if (connect(sockfd_, (struct sockaddr *)&server, sizeof server) < 0)
        {
            std::cout << "connect fail!" << std::endl;
            exit(ConnectErr);
        }
        std::cout << "connect successful!" << std::endl;
    }
    void Close()
    {
        close(sockfd_);
    }

    int fd()
    {
        return sockfd_;
    }

    ~Socket()
    {
        close(sockfd_);
    }

private:
    int sockfd_;
};
class CalculatorServer
{
public:
    CalculatorServer()
    {
    }

    response CalculatorHandler(const request &req)
    {
        response res(0, 0);
        switch (req.op_)
        {
        case '+':
            res.result_ = req.x_ + req.y_;
            break;
        case '-':
            res.result_ = req.x_ - req.y_;
            break;
        case '*':
            res.result_ = req.x_ * req.y_;
            break;
        case '/':
            if (req.y_ == 0)
            {
                res.code_ = 1;
                break;
            }
            res.result_ = req.x_ / req.y_;
            break;
        case '%':
            if (req.y_ == 0)
            {
                res.code_ = 2;
                break;
            }
            res.result_ = req.x_ % req.y_;
            break;
        default:
            res.code_ = 3;
            break;
        }
        return res;
    }

    std::string Calculator(std::string& s)
    {
        std::string content;
        if (Decode(s, &content) == false)
        {
            return "";
        }
        request req;
        bool r = req.Deserialization(content);

        if (!r)
        {
            return "";
        }
        response res = CalculatorHandler(req);

        std::string ret = res.serialization();
        ret = Encode(ret);
        return ret;
    }

    ~CalculatorServer()
    {
    }
};
using func_t = std::function<std::string(std::string &)>;
class TcpServer
{
public:
    TcpServer(uint16_t port, std::string ip, func_t callback)
        : port_(port), ip_(ip), callback_(callback)
    {
    }

    void InitServer()
    {
        listenfd_.Bind(port_, ip_);
        listenfd_.Listen();
        std::cout << "init server successful!" << std::endl;
    }

    void start()
    {
        signal(SIGCHLD, SIG_IGN);
        signal(SIGPIPE, SIG_IGN);
        while (true)
        {
            uint16_t client_port;
            std::string client_ip;
            int sockfd = listenfd_.Accept(&client_port, &client_ip);

            if (sockfd < 0)
            {
                continue;
            }

            if (fork() == 0)
            {
                // 子进程
                listenfd_.Close();
                std::string inbuffer_stream;

                while (1)
                {
                    char buffer[1280];
                    ssize_t s = read(sockfd, buffer, sizeof buffer - 1);
                    if (s > 0)
                    {
                        buffer[s] = 0;
                        // std::cout << buffer << std::endl;
                        inbuffer_stream += buffer;
                        while (true)
                        {
                            std::string info = callback_(inbuffer_stream);
                            std::cout << info << std::endl;
                            if (info.empty())
                            {
                                break;
                            }

                            write(sockfd, info.c_str(), info.size());
                        }
                    }
                    else if (s == 0)
                    {
                        break;
                    }
                    else
                    {
                        break;
                    }
                }
                close(sockfd);
                exit(0);
            }
            close(sockfd);
        }
    }

    ~TcpServer()
    {
    }

private:
    Socket listenfd_;
    uint16_t port_;
    std::string ip_;
    func_t callback_;
};
int main(int argc,char* argv[])
{
    if(argc != 3)
    {
        exit(0);
    }

    uint16_t server_port = std::atoi(argv[2]);
    std::string server_ip = argv[1];
    CalculatorServer cal;
    TcpServer* ser = new TcpServer(server_port,server_ip,std::bind(&CalculatorServer::Calculator, &cal, std::placeholders::_1));
    ser->InitServer();
    ser->start();
    return 0;
}

客户端

复制代码
#include <iostream>
#include <cassert>
#include <unistd.h>
#include "Protocol.hpp"
#include "Socket.hpp"

static void Usage(const std::string &proc)
{
    std::cout << "\nUsage: " << proc << " serverip serverport\n"
              << std::endl;
}

// ./clientcal ip port
int main(int argc, char *argv[])
{
    if (argc != 3)
    {
        Usage(argv[0]);
        exit(0);
    }
    std::string serverip = argv[1];
    uint16_t serverport = std::stoi(argv[2]);

    Socket sockfd;
    sockfd.Connect(serverport, serverip);

    srand(time(nullptr) ^ getpid());
    int cnt = 1;
    const std::string opers = "+-*/%=-=&^";

    std::string inbuffer_stream;
    while (cnt <= 10)
    {
        std::cout << "===============第" << cnt << "次测试....., " << "===============" << std::endl;
        int x = rand() % 100 + 1;
        usleep(1234);
        int y = rand() % 100;
        usleep(4321);
        char oper = opers[rand() % opers.size()];
        request req(x, y, oper);
        req.DebugPrint();

        std::string package;
        package = req.serialization();

        package = Encode(package);
        std::cout << package << std::endl;
        write(sockfd.fd(), package.c_str(), package.size());


        char buffer[128];
        ssize_t n = read(sockfd.fd(), buffer, sizeof(buffer) - 1); 
        if (n > 0)
        {
            buffer[n] = 0;
            inbuffer_stream += buffer; // "len"\n"result code"\n
            std::cout << inbuffer_stream << std::endl;
            std::string content;
            bool r = Decode(inbuffer_stream, &content); // "result code"
            assert(r);

            response resp;
            r = resp.Deserialization(content);
            assert(r);

            resp.DebugPrint();
        }

        std::cout << "=================================================" << std::endl;
        sleep(1);

        cnt++;
    }

    sockfd.Close();
    return 0;
}

到这里,我们已经完整实现了一个基于 C++ 的 TCP 计算器服务器,到现在,我们应该建立这样一个认知:

TCP编程的本质不是收发数据,而是如何正确解析数据的边界

总而言之,核心思想就一句话,就是TCP没有消息边界,所以我们必须设计应用层协议。

相关推荐
lucia_zl2 小时前
linux收集进程性能数据
linux·运维·chrome
无限码力2 小时前
华为OD机试真题2026双机位C卷 C++实现【日志解析】
c++·华为od·华为od机试真题·华为od机考真题·华为od机试真题-日志解析
道亦无名3 小时前
Linux下是STM32的编译修改配置文件tensorflow
linux·运维
炸膛坦客10 小时前
Linux - Ubuntu - PC端:(三)切换中英文,Fcitx5
linux·ubuntu
7yewh10 小时前
jetson_yolo_deployment 01_linux_dev_env
linux·嵌入式硬件·yolo·机器人·嵌入式
cyber_两只龙宝10 小时前
【Haproxy】Haproxy的算法详解及配置
linux·运维·服务器·云原生·负载均衡·haproxy·调度算法
阿常呓语10 小时前
Linux命令 jq详解
linux·运维·shell·jq
myloveasuka10 小时前
Java与C++多态访问成员变量/方法 对比
java·开发语言·c++
2301_8217005311 小时前
C++编译期多态实现
开发语言·c++·算法