仿照muduo库实现一个高并发服务器

项目介绍

Muduo网络库是由陈硕开发的一个基于C++的现代网络编程库,专注于高并发TCP网络应用的开发。‌它采用非阻塞I/O和事件驱动机制,能够高效处理大量并发连接。‌核心设计基于主从Reactor模型,其线程模型被称为"one loop per thread"。在这种模型中,每个线程仅运行一个事件循环(EventLoop),负责监听和响应定时器事件及I/O事件;同时,每个文件描述符(如TCP连接)只能由一个线程进行读写操作,即每个连接必须固定归属于某个EventLoop进行管理。‌这种设计避免了复杂的线程同步问题,提升了多线程环境下的并发性能。

本项目的旨在于实现高并发的服务器组件,可以在短时间内简洁完成一个高性能的服务器搭建。同时我们的组件提供应用层的协议支持,通过该支持可以让我们快速完成服务器搭建。在本项目中为了方便演示,我们采用的应用层协议为HTTP协议。(但根据具体业务需求的不同,编写不同的协议支持即可)

项目地址:https://github.com/XiaoLiang428/Reactor

认识Reactor模式

Reactor(反应器)模式是一种事件驱动的设计模式,用于处理多个并发输入源。它通过事件分发机制将服务请求分发给相应的请求处理器。

在该项目实现中就是哪一个客户端向服务器发送了数据,服务器就会去处理该客户发送的数据;对于没有发送数据的客户端,服务器则不会采取任何行动。这就是事件驱动处理模式。那么服务器是如何知晓哪个客户端触发了对应的事件从而需要服务器采取行动呢?这就要使用一个非常关键的技术------I/O多路复用(上一篇文章提及),统一监听对应的文件描述符,当一个或多个文件描述符的对应事件触发后再将其分给从属线程进行处理。

三种Reactor模式

1.单Reactor单线程:在单个线程中进行监听和业务处理

优点:因为是单线程,使用操作都是串行化的,思想比较简单因此编码较为容易

缺点:所有的任务都是在一个线程中进行处理因此容易造成性能瓶颈

2.单Reactor多线程:一个Reactor线程 + 业务线程池

优点:充分利用CPU多核资源,处理效率可以更高,降低代码耦合度

缺点:唯一的Reactor线程需要充当监听新连接的到来还需要对客户端事件进行监控,以及客户端的一系列I/O操作,不利于高并发场景。

3.多Reactor多线程:多个Reactor线程 + 业务线程池

优点:充分利用CPU多核资源,并且进行合理分配

但是执行流并不是越多越好,不同执行流的切换也会有开销。

本项目的TcpServer实现实现方式为One Thread One Loop主从Reactor模型

1.主从Reactor模型服务器

2.一个线程对应一个循环,循环操作:对该线程中管理的连接进行I/O事件监控+I/O操作+业务处理

根据上面三种Reactor模式的优缺点我们最后选择的是多Reactor的模式,但我们并不适用业务线程池去处理业务,而是在从属Reactor中完成业务的处理,这是为了避免有多个执行流切换导致性能下降。对于每一个连接的所有操作来说,它只会在从属Reactor线程中执行因此我们不需要去考虑其被其他线程争抢,这就是我们one thread one loop思想的模式。

总体模块介绍

我们实现的是一个带有协议支持的Reactor模型高性能服务器,总体可以分为两块:

第一块是Server模块,主要功能是Reactor模型高性能TCP服务器

第二块是协议模块,对当前Reactor模型高性能服务器提供应用层的协议支持

Server模块

本模块需要对所有的连接和从属线程进行管理,让它们各司其职一起经营好这个高并发服务器。本模块主要分为一下几个方面:

1.监听连接管理,负责接收新到来的连接,然后处理新到的连接添加到连接管理列表之中进行有效的管理并初始化这些连接的信息。

2.通信连接管理,负责管理正在与服务器进行通信的连接,方便主线程进行连接的管理。

3.连接超时管理,负责释放长时间不活跃的连接,避免服务器的资源被浪费。

4.线程管理,创建从属线程并管理好从属线程。

下面介绍TcpServer下的各个子模块及其作用

Buffer模块

Buffer模块是一个缓冲区,用于扮演通信中接收缓冲区和发送缓冲区的角色。

Socket模块

Socket模块是对socket套接字操作的一系列封装,方便我们后续的使用。

Channel模块

Channel 模块封装了一个连接对应的文件描述符及其所关心的事件。Channel 并不直接监控这些事件,而是由 Poller 统一管理。当Poller 检测到某个文件描述符上的事件发生时,它会通知对应的 Channel,然后 Channel 会调用预先设置好的回调函数来处理该事件。简而言之,Channel 是文件描述符、事件和事件处理逻辑之间的桥梁。

Connection模块

Connection 模块是对单个 TCP 连接的抽象和封装。每个 Connection 对象代表一个客户端连接,并管理该连接的整个生命周期。

它的核心作用包括:

1.封装资源:它封装了与连接相关的核心资源,包括 socket 文件描述符、输入/输出 Buffer,以及用于事件分发的 Channel。

  1. 生命周期管理:它负责处理连接的建立、数据收发、以及最终的关闭和资源释放。通过状态(ConnectionStatus)来跟踪连接的当前状态(如 CONNECTED, DISCONNECTING)。

  2. 事件处理:它将底层的 Channel 事件(可读、可写、关闭、错误)转化为更高级别的回调函数(_conn_cb, _msg_cb, _close_cb),供用户(TcpServer 的使用者)定义业务逻辑。

  3. 线程安全:它保证了所有操作都在其所属的 EventLoop 线程中执行,从而避免了多线程环境下的竞态条件。外部线程的调用(如 Send, ShutDown)会被安全地转移到 EventLoop 线程中执行。

  4. 数据缓冲:通过内置的 _in_buffer 和 _out_buffer,它管理了数据的收发缓冲,并处理了非阻塞 I/O 中数据可能无法一次性完整读写的情况。

  5. 附加功能:它还提供了如非活跃连接自动释放(通过 TimeWheel)、协议升级(Upgrade)等高级功能。

总而言之,Connection 模块为上层业务逻辑提供了一个清晰、安全、易于使用的 TCP 连接接口,将复杂的底层网络 I/O 和并发细节隐藏了起来。

Acceptor模块

Acceptor 模块是服务器中专门负责接收新 TCP 连接的组件。它运行在主 EventLoop 中(即 TcpServer 里的 _base_loop),只专注于高效地接受新连接:监听端口,等待客户端连接,然后将建立好的连接(以文件描述符的形式)分发出去。

TimeWheel模块(TimeQueue模块)

TimeWheel 模块在您的代码中实现了一个时间轮,这是一种高效管理大量定时任务的数据结构,其核心作用是处理有时效性的事件,例如自动断开不活跃的连接。它通过 timerfd 创建一个每秒触发一次的定时器,每次触发时,时间轮的指针(_tick)就向前移动一格。当指针扫过某个格子时,该格子中所有到期的定时任务(Timer 对象)就会被释放,从而触发其析构函数中绑-定的回调任务。该模块通过将所有操作(如添加、刷新、取消定时器)都放入其所属的 EventLoop 线程中执行,解决了多线程环境下的同步问题。

Poller模块

Poller 包装了 Linux 的 epoll 机制,是整个事件驱动模型的心脏。它将底层的、面向文件描述符的 epoll 调用,抽象成了面向 Channel 的、更高级别的接口。它本身不处理任何具体的 I/O 逻辑,而是专注于"等待并报告哪些 Channel 上有事件发生"。实现了事件检测(Poller 的职责)与事件处理(Channel 和其回调函数的职责)的分离。

EventLoop模块

EventLoop 是连接所有模块的中央调度器。它通过 Poller 监听 I/O 事件,通过 Channel 分发事件,通过任务队列实现跨线程通信,通过 TimeWheel 管理定时任务。它将并发的、异步的事件处理流程,转化为在其自身线程内的一个有序的、串行的执行序列,极大地简化了并发网络编程的复杂性。

为什么要一个EvnetLoop对应一个线程?

一个 Connection 对象及其 Channel 都属于一个 EventLoop 线程。如果另一个线程(比如主线程或其他 I/O 线程)想对这个 Connection 进行操作(如发送数据),直接调用其方法会引发线程安全问题。

解决方案:EventLoop 提供了 RunInLoop 和 QueueInLoop 机制。当一个外部线程需要在一个 EventLoop 中执行某个操作时,它会调用该 loop 的 RunInLoop(task)。RunInLoop 判断当前调用者是否为 EventLoop 自己的线程。如果不是,它会将 task 放入一个由互斥锁 _mutex 保护的 _task 任务队列中。任务入队后,它会调用 WeakUpEventFd(),向一个特殊的 _event_fd (eventfd) 文件描述符写入一个字节。由于这个 _event_fd 也被 Poller 监控着,写入操作会使其变为可读,从而唤醒正在 _poller.Poll() 中阻塞的 EventLoop 线程。EventLoop 被唤醒后,在处理完 I/O 事件后调用的 RunAllTasks() 方法中,就会从任务队列里取出并执行这个 task。通过这种方式,所有对 EventLoop 所管理资源的操作都被安全地转移到了 EventLoop 自己的线程中执行,从而保证了线程安全。

TcpServer模块

TcpServer 模块是整个网络库的顶层封装和最终用户接口,它将所有底层组件(如 Acceptor、LoopThreadPool、Connection)组合在一起,构建了一个完整的多线程事件驱动服务器。它的核心职责是:通过持有的 Acceptor 在主 EventLoop 中监听并接受新连接,然后利用 LoopThreadPool 以轮询(Round-Robin)的方式将新创建的 Connection 对象分发给一个从属的 EventLoop 线程进行管理。TcpServer 还负责维护所有连接的映射表,并向用户提供了设置连接建立、消息到达和连接关闭等核心事件回调的接口,从而将复杂的网络并发细节与上层业务逻辑清晰地分离开来。

通信连接管理模块的关系图

当socket接收到数据后,数据被存放在接收缓冲区buffer中,然后对应的读事件被触发,从而调用TcpServer设置的读事件回调函数并且刷新连接的活跃度,同时还会去执行组件使用者自身设置的读事件触发的回调函数。

监听连接管理的关系图

一旦有新的连接到来就会触发对应的读事件,然后执行对应的回调函数,该回调函数会为新连接初始化Connection对象并设置一系列的回调函数然后将其添加到连接管理和Poller监控中,一旦该连接有事件被触发就可以执行对应的回调函数。

事件监听管理的关系图

这个模块是事件监听,是每个从属线程都具有的。该模块主要是从属线程管理对应主线程分配给从属线程的连接,管理该连接的事件触发、连接的活跃度等等。

Server中模块的具体实现

TimeWheel模块

该模块是用来判断连接是否超时,如果超时了需要需要释放该连接。

timefd的认识与使用

Linux是为用户提供的定时器:

timerfd_create(int clockid,int flags)

Linux下一切接文件,定时器创建返回的是一个文件描述符,对该描述符操作就是对定时器操作。

参数:clockid是指定时钟源。可以选择CLOCK_REALTIME:系统实时时间(可被系统时间修改影响)、CLOCK_MONOTONIC:从系统启动开始计时,不受系统时间修改影响;flags控制文件描述符的行为,可选TFD_NONBLOCK:设置文件描述符为非阻塞模式 、TFD_CLOEXEC:设置 close-on-exec 标志(执行 exec 时自动关闭)。

返回值:返回一个文件描述符

定时器的原理:每隔固定的时间(定时器被设置的时间),系统就会向对应的文件描述符(定时器)写入一个8字节的数据。如果一个定时器的时间设定为3s,那么每个3s就会向描述符写入1。你如果6s只后读取该文件描述符就会读到2,表示该定时器已经超时了2次。

timefd_settime(int fd,int flags,const struct itimerspec *new_value,struct itimerspec *_Nullable old_value)

该函数是用来启动定时器,可以通过new_value来设置首次触发事件和后续间隔时间。

参数:fd是定时器对应的文件描述符,flags可以选0(相对时间模式)或TFD_TIMER_ABSTIME(绝对时间模式),new_value定义定时器的首次触发事件和后续间隔事件,old_value获取之前的定时器设置(用于还原之前定时器的设置,如果没有此需求可以设置为NULL)。

时间轮定时器思想

我们的服务器势必会同时维护大量的连接,如果每隔固定时间都去遍历一遍这些连接判断是否超时肯定会耗费大量CPU资源导致效率低下。因此我们采用小根堆的方式来判断堆顶是否从超时,如果没有超时那么其他的连接肯定也不会超时,如果其超时那么就需要一致检测堆顶连接直到其不再超时。

但是我们介绍一个更好的方案------时间轮

时间轮的算法思想借鉴了时钟的运作机制。想象一下设定一个三点钟的闹钟:当时针指向数字三的那一刻,闹钟便会响起。我们将这一直观的时空映射转化为精妙的算法结构。在实现上,我们采用一个环形队列(通常基于数组实现)来模拟表盘。每个时间刻度对应队列中的一个槽位,代表一个特定的时间点或时间间隔。当需要为某个连接设置超时或定时任务(例如释放资源)时,我们并非将其放入一个需要主动轮询的列表,而是将它精确地"放置"在未来对应时刻的槽位中。系统维护一个指针,如同时钟的指针一样,随着时间推移指针在环形队列中步进。当指针移动到某个槽位时,该槽位中所有累积的任务(释放连接)便会自动触发执行。这一机制将 O(n) 的遍历检查复杂度优化为 O(1) 的指针移动操作,从而实现了高效的大规模定时任务管理。

时间轮算法通过环形队列结构实现了定时器的高效管理,但仍面临一个关键挑战:如何支持连接的动态刷新机制?为此,我们提出了一种融合类析构函数与智能指针的创新设计。

类的析构函数 + 智能指针shared_ptr

1.通过类封装与析构机制将定时任务抽象为独立对象。我们设计一个专门的定时任务类,在其析构函数中封装具体的超时处理逻辑(如释放不活跃连接)。这样,当该对象生命周期结束时,便会自动触发预设的回调操作,实现定时任务的自动执行。

2.借助智能指针与环形队列实现任务的灵活调度。我们将每个定时任务对象包装在shared_ptr中,并置入时间轮的环形队列相应槽位。随着时间指针的循环推进,当指针移动到某个槽位时,只需清空该位置所有智能指针指向的对象。由于shared_ptr的自动引用计数管理,当槽位中的指针被清除后,若该任务对象不再被其他代码引用,其引用计数将归零,从而自动触发析构函数,执行连接释放等清理工作。

cpp 复制代码
using TaskFunc = std::function<void()>;
using ReleaseFunc = std::function<void()>;
class Timer
{
private:
    uint64_t _id; // Timer的_id是由主base线程进行统一分配,防止多线程中有_id重复从而导致哈希表中的信息错乱
    uint32_t _timeout;
    bool _cancel;
    TaskFunc _task_cb;
    ReleaseFunc _release_cb;

public:
    Timer(uint64_t id, uint32_t timeout, const TaskFunc &cb) : _id(id), _timeout(timeout), _task_cb(cb), _cancel(false) {}
    void SetRelease(const ReleaseFunc &cb) { _release_cb = cb; }
    void Cancel() { _cancel = true; }
    ~Timer()
    {
        if (!_cancel) //任务未取消才执行
            _task_cb();
        _release_cb();
    }
    uint32_t GetTimeout() { return _timeout; }
};

class TimeWheel
{
    using PtrTask = std::shared_ptr<Timer>;
    using WeakTask = std::weak_ptr<Timer>;

private:
    int _tick;                                      // 当前的秒针,走到哪个位置就释放哪个位置
    int _capacity;                                  // 表盘的大小
    std::vector<std::vector<PtrTask>> _wheel;       // 存储表盘中的内容
    std::unordered_map<uint64_t, WeakTask> _timers; // 将Timer存入到哈希表中方便我们寻找到对应的事件进行更新操作 此处使用weak_ptr是防止_wheel中的内容引用计数无法减到零从而无法析构的问题
    void RemoveTimer(uint64_t id)                   // 从unordered_map中移除Timer信息
    {
        auto it = _timers.find(id);
        if (it == _timers.end())
        {
            return;
        }
        _timers.erase(it);
    }

public:
    TimeWheel() : _tick(0), _capacity(60), _wheel(_capacity) {};
    void AddTimer(uint64_t id, uint32_t delay, const TaskFunc &cb)
    {
        PtrTask pt(new Timer(id, delay, cb)); // 创建一个Timer对象
        // 将Timer对象添加到表盘中
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(pt);
        // 将Timer对象添加到哈希表中方便管理
        _timers[id] = WeakTask(pt); // 构造一个匿名的WeakTask对象然后再赋值,简化为直接构造
        // 当Timer对象释放的时候需要将其对应的信息从unordered_map中移除,绑定回调函数
        pt->SetRelease(std::bind(&TimeWheel::RemoveTimer, this, id));
    }
    void RefreshTimer(uint64_t id)
    {
        // 通过id寻找到对应的weak_ptr从而形成一个新的shared_ptr对象存储到表盘中,此时之前的shared_ptr对象释放时候Timer对象就不会被析构了,这样就达到了更新生命周期的目的
        auto it = _timers.find(id);
        if (it == _timers.end())
        {
            return;
        }
        // 说明寻找到了相应的weak_ptr
        PtrTask pt = it->second.lock();
        int pos = (_tick + pt->GetTimeout()) % _capacity;
        _wheel[pos].push_back(pt);
    }
    void TimerCancel(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it == _timers.end())
        {
            return;
        }
        PtrTask pt = it->second.lock();
        if(pt) pt->Cancel();//取消该任务
    }
    void RunTimerTask()
    {
        _tick = (_tick + 1) % _capacity;
        _wheel[_tick].clear();
    }
};

Buffer模块

该模块是用来存储从socket套接字接收到的数据和需要发送的数据

思想:使用vector<char>来管理一片内存空间,vector<char>可以更好的处理像图片、二进制文件这种数据(string则会有遇到\0截至的弊端),然后使用两个下标指针分别标记读位置和写位置,进而维护这块缓冲区。

读入数据:直接从读下标所指向的位置开始读取,但是读取内容的大小不能超过可读的数据大小,并且需要在读之后更新可读数据大小。

写入数据:指向哪里就从哪个位置开始写入,但需要判断是否有足够的剩余空间写入新数据,如果读下标之后没有足够的剩余空间但是总体剩余空间足够,我们就需要挪动未读数据到数组起始位置从而方便新数据的写入;如果总体的剩余空间不足那就需要扩容,扩容的时候并不需要挪动原有的可读数据,而是直接扩容直到写下标后的空间大小可以容纳新数据。

cpp 复制代码
#define DefaultBufferSize 1024
class Buffer {
private:
    std::vector<char> _buffer;  // 使用vector来充当缓冲区
    size_t _reader_idx;         // 读下标
    size_t _writer_idx;         // 写下标
public:
    Buffer() {
        _buffer.resize(DefaultBufferSize);
        _reader_idx = 0;
        _writer_idx = 0;
    }
    char *Begin() { return &*_buffer.begin(); }
    // 获取当前读出起始地址
    char *ReadPosition() { return Begin() + _reader_idx; }
    // 获取当前写入起始地址
    char *WritePosition() { return Begin() + _writer_idx; }

    // 获取缓冲区前部空闲空间大小
    size_t HeadIdleSize() { return _reader_idx; }
    // 获取缓冲区后部空闲空间大小
    size_t TailIdleSize() { return _buffer.size() - _writer_idx; }

    // 获取可读数据大小
    size_t ReadAbleSize() { return _writer_idx - _reader_idx; }

    // 将读下标向后移动特定距离
    void MoveReadOffset(size_t len) {
        assert(len <= ReadAbleSize());  // 确保读取数据不越界
        _reader_idx += len;
    }
    // 将写下标向后移动特定距离
    void MoveWriteOffset(size_t len) {
        assert(len <= TailIdleSize());  // 确保能够放得下数据
        _writer_idx += len;
    }
    // 确保空间足够(如果空间足够就将数据前移,不够就直接扩容)
    void EnsureEnoughSpace(size_t len) {
        if (len <= TailIdleSize())  // 缓冲区后部的空闲空间足够
            return;
        if (len <= HeadIdleSize() + TailIdleSize())  // 缓冲区总体空闲空间足够
        {
            // 将数据拷贝到数组的起始位置
            std::copy(ReadPosition(), ReadPosition() + ReadAbleSize(), Begin());
            _writer_idx = ReadAbleSize();
            _reader_idx = 0;  // 注意_writer_idx与_reader_idx顺序!
            return;
        }
        // 说明数组中空闲区域的位置不足,此时直接扩容但不移动数据
        _buffer.resize(_buffer.size() + len);
        return;
    }
    // 写入数据
    void Write(const void *data, size_t len) {
        EnsureEnoughSpace(len);
        char *str = (char *)data;
        std::copy(str, str + len, WritePosition());
    }
    void WriteString(const std::string &str) { Write(&str[0], str.size()); }
    void WriteBuffer(Buffer &buffer) {
        Write(buffer.ReadPosition(), buffer.ReadAbleSize());  //?
    }
    void WriteAndPush(const void *data, size_t len) {
        Write(data, len);
        MoveWriteOffset(len);
    }
    void WriteStringAndPush(const std::string &str) {
        WriteString(str);
        MoveWriteOffset(str.size());
    }
    void WriteBufferAndPush(Buffer &buffer) {
        WriteBuffer(buffer);
        MoveWriteOffset(buffer.ReadAbleSize());
    }
    // 读取数据
    void Read(void *buffer, size_t len) {
        assert(len <= ReadAbleSize());
        std::copy(ReadPosition(), ReadPosition() + len, (char *)buffer);
    }
    void ReadAndPop(void *buf, size_t len) {
        Read(buf, len);
        MoveReadOffset(len);
    }
    std::string ReadAsString(size_t len) {
        assert(len <= ReadAbleSize());
        std::string str;
        str.resize(len);
        Read(&str[0], len);  //?
        return str;
    }
    std::string ReadAsStringAndPop(size_t len) {
        std::string str = ReadAsString(len);
        MoveReadOffset(len);
        return str;
    }

    char *GetCRLF() {
        char *res = (char *)memchr(ReadPosition(), '\n', ReadAbleSize());
        return res;
    }
    std::string GetLine() {
        char *crlf = GetCRLF();
        if (crlf) {
            size_t len = crlf - ReadPosition() + 1;  // 将\n也读出来
            return ReadAsStringAndPop(len);
        }
        return "";
    }
    std::string GetLineAndPop() {
        std::string line = GetLine();
        MoveReadOffset(line.size());
        return line;
    }
};

Socket模块

对socket套接字一系列操作的封装,方便后续使用

cpp 复制代码
class Socket {
private:
    int _socketfd;

public:
    Socket()
            : _socketfd(-1) {}
    Socket(int fd)
            : _socketfd(fd) {}
    ~Socket() { Close(); }
    // 创建socket
    bool Create() {
        _socketfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
        if (_socketfd < 0) {
            LOG_ERROR("SOCKET CREATE ERROR");
            return false;
        }
        return true;
    }
    // 绑定ip和端口
    bool Bind(const std::string &ip, uint16_t port) {
        struct sockaddr_in addr;
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        inet_pton(AF_INET, ip.c_str(), &addr.sin_addr);
        if (bind(_socketfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
            LOG_ERROR("SOCKET BIND ERROR");
            return false;
        }
        return true;
    }
    // 设置为监听模式
    bool Listen(int backlog = DefaultBackLog) {
        int ret = listen(_socketfd, backlog);
        if (ret < 0) {
            LOG_ERROR("SOCKET LISTEN ERROR");
            return false;
        }
        return true;
    }
    // 建立连接
    bool Connect(const std::string &ip, uint16_t port) {
        struct sockaddr_in addr;
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        inet_pton(AF_INET, ip.c_str(), &addr.sin_addr);
        int ret = connect(_socketfd, (struct sockaddr *)&addr, sizeof(addr));
        if (ret < 0) {
            LOG_ERROR("SOCKET CONNECT ERROR");
            return false;
        }
        return true;
    }
    // 获取新连接并返回新的连接描述符
    int Accept() {
        // 不关心客户端的ip和端口信息
        int newfd = accept(_socketfd, nullptr, nullptr);
        if (newfd < 0) {
            if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                LOG_DEBUG("SOCKET ACCEPT ERROR AGAIN");
            } else {
                LOG_ERROR("SOCKET ACCEPT ERROR");
            }
        }
        return newfd;
    }
    ssize_t Recv(void *buffer, size_t len, int flag = 0) {
        ssize_t ret = recv(_socketfd, buffer, len, flag);
        if (ret < 0)  // 出错或者对端关闭连接
        {
            if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                LOG_DEBUG("SOCKET RECV ERROR AGAIN");
                return -1;
            } else {
                LOG_ERROR("SOCKET RECV ERROR errno=%d(%s)", errno, strerror(errno));
                return -2;
            }
        }
        return ret;
    }
    ssize_t NonBlockRecv(void *buffer, size_t len) { return Recv(buffer, len, MSG_DONTWAIT); }
    ssize_t Send(const void *buffer, size_t len, int flag = 0) {
        ssize_t ret = send(_socketfd, buffer, len, flag);
        if (ret <= 0)  // 出错或者对端关闭连接
        {
            if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
                return 0;
            LOG_ERROR("SOCKET SEND ERROR errno=%d(%s)", errno, strerror(errno));
            return -1;
        }
        return ret;
    }
    ssize_t NonBlockSend(void *buffer, size_t len) { return Send(buffer, len, MSG_DONTWAIT); }
    bool CreateServer(uint16_t port, const std::string &ip = "0.0.0.0", bool flag = 0) {
        // 1.创建socket
        if (Create() == false)
            return false;
        // 复用地址
        ReuseAddress();
        // 2.绑定ip和端口
        if (Bind(ip, port) == false)
            return false;
        // 3.设置为监听模式
        if (Listen() == false)
            return false;
        // 设置套接字为非阻塞
        // if (flag)
        // SetNonBlock();
        return true;
    }
    bool CreateClient(uint16_t port, const std::string &ip) {
        // 1.创建socket
        if (Create() == false)
            return false;
        // 无需手动绑定ip和端口号,系统自动分配从而防止端口号冲突
        // 2.连接服务器
        if (Connect(ip, port) == false)
            return false;
        return true;
    }
    bool Close() {
        if (_socketfd != -1) {
            close(_socketfd);
            _socketfd = -1;
        }
        return true;
    }
    void ReuseAddress() {
        int opt = 1;
        setsockopt(_socketfd, SOL_SOCKET, SO_REUSEADDR, (void *)&opt, sizeof(opt));
        opt = 1;
        setsockopt(_socketfd, SOL_SOCKET, SO_REUSEPORT, (void *)&opt, sizeof(opt));
    }
    void SetNonBlock() {
        int flags = fcntl(_socketfd, F_GETFL, 0);
        fcntl(_socketfd, F_SETFL, flags | O_NONBLOCK);
    }
    int Fd() { return _socketfd; }
};

Channel模块

该模块是方便我们对一个连接进行监控操作,并设置对应事件触发后的回调函数。

关于EPOLLHUP、EPOLLDRHUP

EPOLLHUP表示文件描述符上发生了挂起事件,一般会因为对端正常关闭连接但另一端却仍然对该文件描述符进行操作而触发。

EPOLLDRHUP表示对端已关闭写操作,或连接半关闭, 一般会因为操作的对应文件描述符对应的连接发送了 FIN 包 (正常关闭)或关闭了写通道(半关闭状态)而触发。

cpp 复制代码
class Channel {
    using EventCallBack = std::function<void()>;

private:
    int _fd;                  // 该Channel所绑定的文件描述符
    EventLoop *_loop;         // 该Channel所属的事件循环器
    uint32_t _events;         // 监控的事件
    uint32_t _revents;        // 发生的事件
    EventCallBack _read_cb;   // 读事件被触发的回调函数
    EventCallBack _write_cb;  // 写事件被触发的回调函数
    EventCallBack _error_cb;  // 错误事件被触发的回调函数
    EventCallBack _close_cb;  // 断开事件被触发的回调函数
    EventCallBack _event_cb;  // 任意事件被触发的回调函数
public:
    Channel(int fd, EventLoop *loop)
            : _fd(fd)
            , _loop(loop)
            , _events(0)
            , _revents(0) {}
    // 获取文件描述符
    int Fd() { return _fd; }
    // 获取监控的事件
    uint32_t Events() { return _events; }
    // 被触发的事件
    void SetREvents(uint32_t revents) { _revents = revents; }
    // 设置读事件的回调函数
    void SetReadCallback(EventCallBack read_cb) { _read_cb = read_cb; }
    // 设置写事件的回调函数
    void SetWriteCallback(EventCallBack write_cb) { _write_cb = write_cb; }
    // 设置错误事件的回调函数
    void SetErrorCallback(EventCallBack error_cb) { _error_cb = error_cb; }
    // 设置断开事件的回调函数
    void SetCloseCallback(EventCallBack close_cb) { _close_cb = close_cb; }
    // 设置任意事件的回调函数
    void SetEventCallback(EventCallBack event_cb) { _event_cb = event_cb; }
    // 是否可读
    bool ReadAble() { return _events & EPOLLIN; }
    // 是否可写
    bool WriteAble() { return _events & EPOLLOUT; }
    // 开启读事件监控
    void EnableRead() {
        _events |= (EPOLLIN | EPOLLRDHUP);
        Update();
    }
    // 开启写事件监控
    void EnableWrite() {
        _events |= EPOLLOUT;
        Update();
    }
    // 关闭读事件监控
    void DisableRead() {
        _events &= ~EPOLLIN;
        Update();
    }
    // 关闭写事件监控
    void DisableWrite() {
        _events &= ~EPOLLOUT;
        Update();
    }
    // 关闭所有事件监控
    void DisableAll() {
        _events = 0;
        Update();
    }
    // 从EventLoop中移除该Channel
    void Remove();
    void Update();
    void Handler() {
        if ((_revents & EPOLLIN) || (_revents & EPOLLPRI) || (_revents & EPOLLRDHUP)) {
            if (_read_cb)
                _read_cb();
            // 事件处理完毕后刷新活跃度
        }
        // 有可能释放连接的事件,一次只处理一个(保证安全)
        if (_revents & EPOLLOUT) {
            if (_write_cb)
                _write_cb();
            // 事件处理完毕后刷新活跃度
        }
        if (_revents & EPOLLERR) {
            if (_error_cb)
                _error_cb();
        }
        if (_revents & EPOLLHUP) {
            if (_close_cb)
                _close_cb();
        }
        if (_event_cb) {
            _event_cb();
        }
    }
};

Poller模块

每一个连接对应的Channel都会被添加到EventLoop进行事件监控,Poller模块就是为了让EventLoop能够监视该从属线程中的所有连接事件。

Poller模块功能:1.添加/修改描述符的事件监控 2.移除描述符的事件监控

处理逻辑:

1.Channel设置对应连接的描述符需要被监控的事件从而让Poller进行监控

2.当对应的文件描述符就绪时,通过文件描述符和Channel的映射关系来寻找到对应的Channel对象并调用对应的回调函数进行事件处理

cpp 复制代码
#define MAXEVENTS 1024
class Poller {
private:
    int _epfd;
    struct epoll_event _events[MAXEVENTS];
    std::unordered_map<int, Channel *> _channels;

private:
    void Update(Channel *channel, int op) {
        struct epoll_event ev;
        ev.data.fd = channel->Fd();
        ev.events = channel->Events();
        if (epoll_ctl(_epfd, op, channel->Fd(), &ev) < 0) {
            if (op & EPOLL_CTL_ADD)
                LOG_ERROR("EPOLL CTL ADD ERROR");
            LOG_ERROR("EPOLL CTL ERROR");
        }
    }
    bool CheckChannel(Channel *channel) {
        int fd = channel->Fd();
        auto it = _channels.find(fd);
        if (it == _channels.end()) {
            return false;
        }
        return true;
    }

public:
    Poller() {
        _epfd = epoll_create(MAXEVENTS);
        if (_epfd < 0) {
            LOG_ERROR("EPOLL CREATE ERROR");
        }
    }
    ~Poller() { close(_epfd); }
    void UpdateEvent(Channel *channel) {
        int fd = channel->Fd();
        if (fd < 0)
            return;
        bool exists = CheckChannel(channel);
        uint32_t evs = channel->Events();
        if (evs == 0) {
            if (exists) {
                // 删除channel
                Update(channel, EPOLL_CTL_DEL);
                _channels.erase(fd);
            }
            return;
        }
        if (!exists) {
            // 新增channel
            Update(channel, EPOLL_CTL_ADD);
            _channels[fd] = channel;
        } else {
            // 修改channel
            Update(channel, EPOLL_CTL_MOD);
        }
        // if (CheckChannel(channel) == false) {
        //     // 新增channel

        //     Update(channel, EPOLL_CTL_ADD);
        //     _channels[channel->Fd()] = channel;
        // } else {
        //     // 修改channel
        //     Update(channel, EPOLL_CTL_MOD);
        // }
    }
    void RemoveEvent(Channel *channel) {
        if (CheckChannel(channel) == false) {
            return;
        }
        Update(channel, EPOLL_CTL_DEL);
        _channels.erase(channel->Fd());
    }
    void Poll(std::vector<Channel *> *actvie) {
        int nfds = epoll_wait(_epfd, _events, MAXEVENTS, -1);
        if (nfds < 0) {
            if (errno == EINTR) {
                return;
            }
            LOG_ERROR("EPOLL WAIT ERROR,%s\n", strerror(errno));
            abort();
        }
        for (int i = 0; i < nfds; ++i) {
            int fd = _events[i].data.fd;
            auto it = _channels.find(fd);
            assert(it != _channels.end());
            Channel *channel = it->second;
            channel->SetREvents(_events[i].events);
            actvie->push_back(channel);
        }
    }
};

EventLoop模块

我们的思想是one thread one loop,一个从属线程对应一个EventLoop对象,从属线程需要管理分配到该线程中的所有连接及其事件处理。

为了让该连接的所有事件处理均在其对应的EventLoop线程中进行,我们使用将所有对连接的操作进行封装并添加到任务队列的方式来完成,不使用加锁是因为大量连接对应的大量I/O事件如果都加锁,那么将会导致效率下降。

EventLoop处理流程:

1.对所管理的连接进行事件监控

2.如果有描述符对应的事件被触发了就直接进行处理

3.所有的就绪事件完成之后就去查看任务队列中是否有任务,如果有任务就执行。

但是EventLoop线程可能因为连接迟迟没有事件触发而导致阻塞,从而使任务队列中的任务迟迟得不到执行!因此我们将eventfd也添加到监视之中,如果任务队列有任务那么eventfd可读就会被触发。

cpp 复制代码
class EventLoop {
private:
    using Functor = std::function<void()>;
    std::thread::id _thread_id;               // 事件循环所属线程的id
    int _event_fd;                            // 用于唤醒事件循环的文件描述符
    Poller _poller;                           // 事件监控
    std::unique_ptr<Channel> _event_channel;  // 用于监听_event_fd的Channel
    std::vector<Functor> _task;               // 任务队列
    std::mutex _mutex;
    TimeWheel _time_wheel;

public:
    static int CreateEventFd() {
        int evtfd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
        if (evtfd < 0) {
            LOG_ERROR("EVENTFD CREATE ERROR");
            abort();
        }
        return evtfd;
    }

public:
    EventLoop()
            : _thread_id(std::this_thread::get_id())
            , _event_fd(CreateEventFd())
            , _poller()
            , _event_channel(new Channel(_event_fd, this))
            , _time_wheel(this) {
        _event_channel->SetReadCallback(std::bind(&EventLoop::ReadEventFd, this));
        _event_channel->EnableRead();
    }
    void ReadEventFd()  // 读取_event_fd中的数据,防止一直触发可读事件
    {
        uint64_t value;
        ssize_t ret = read(_event_fd, &value, sizeof(value));
        if (ret < 0) {
            if (errno == EAGAIN)
                return;

            LOG_ERROR("EVENTFD READ ERROR");
            abort();
        }
    }
    void WeakUpEventFd()  // 写入数据到_event_fd,唤醒事件循环
    {
        uint64_t value = 1;
        int ret = write(_event_fd, &value, sizeof(value));
        if (ret < 0) {
            if (errno == EAGAIN)
                return;
            LOG_ERROR("EVENTFD WRITE ERROR");
            abort();
        }
    }
    // 事件监控 -> 就绪事件处理 -> 执行任务
    void Start() {
        while (true) {
            std::vector<Channel *> actives;
            _poller.Poll(&actives);
            for (auto channel : actives) { channel->Handler(); }
            RunAllTasks();
        }
    }
    void RunAllTasks() {
        std::vector<Functor> tasks;
        {
            std::unique_lock<std::mutex> lock(_mutex);
            tasks.swap(_task);
        }
        if (tasks.empty())
            return;
        for (const Functor &task : tasks) { task(); }
    }
    void RunInLoop(const Functor cb) {
        if (IsInLoop()) {
            return cb();
        }
        return QueueInLoop(cb);
    }
    // 将任务压入任务队列
    void QueueInLoop(const Functor cb) {
        {
            std::unique_lock<std::mutex> lock(_mutex);
            _task.push_back(cb);
        }
        // 有可能因为没有事件触发导致epoll一直等待,因此我们需要使用EventFd唤醒它
        // 使用EventFd写入数据来触发可读事件
        WeakUpEventFd();
    }
    bool IsInLoop()  // 判断是否属于EventLoop线程
    {
        return _thread_id == std::this_thread::get_id();
    }
    void AssertInLoop() { return assert(_thread_id == std::this_thread::get_id()); }
    void UpdateEvent(Channel *channel)  // 更新Channel事件
    {
        _poller.UpdateEvent(channel);
    }
    void RemoveEvent(Channel *channel)  // 移除Channel事件
    {
        _poller.RemoveEvent(channel);
    }
    void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb) { _time_wheel.TimerAdd(id, delay, cb); }
    void TimerRefresh(uint64_t id) { _time_wheel.TimerRefresh(id); }
    void TimerCancel(uint64_t id) { _time_wheel.TimerCancel(id); }
    bool TimerCheck(uint64_t id) { return _time_wheel.TimerCheck(id); }
    bool HasTimer(uint64_t id) { return _time_wheel.TimerCheck(id); }
};

EventLoop模块内部关系图

Connection模块

Connection模块是对单个连接的全面管理,对连接的所有操作都是通过Connection对象完成。

对连接管理主要分为下面几个方面:

1.对连接套接字的管理,能够进行套接字操作

2.对连接事件的管理,方便进行监听和采取对应的回调函数进行处理

3.对连接接收和发送缓冲区的管理,可以接收连接对端发送的数据和向对端发送数据

4.协议上下文的管理,记录请求数据的处理过程(上层从接收缓冲区拿去数据并进行处理,如果数据不完整则会继续等待下一次数据直到数据完整在进行处理)

5.对连接回调函数的管理,包括连接建立回调函数(连接建立后需要对连接进行什么处理,由组件使用者决定)、消息到达回调函数(从对应连接接收到数据后应该如何处理,由组件使用者决定)、连接关闭回调函数(连接关闭的时候应该另外做什么处理,又组件使用者决定)、任何事件触发回调函数(任意事件触发了是否还需要处理什么业务,由组件使用者决定)。

cpp 复制代码
typedef enum { CONNECTED, CONNECTING, DISCONNECTED, DISCONNECTING } ConnectionStatus;
class Connection : public std::enable_shared_from_this<Connection> {
    using PTRConnection = std::shared_ptr<Connection>;

private:
    uint64_t _conn_id;              // 连接的id,用来唯一标识该连接
    uint64_t _timer_id;             // 定时器的id,用来唯一标识定时器,为了方便我们将其与_conn_id的值设置一致
    bool _enable_inactive_release;  // 是否启用不活跃连接释放
    int _sockfd;                    // 连接对应的文件描述符
    EventLoop *_loop;               // 连接关联的对应EventLoop线程
    ConnectionStatus _conn_status;  // 连接的当前状态
    Socket _socket;                 // 管理连接套接字
    Channel _channel;               // 管理连接事件
    Buffer _in_buffer;              // 接收缓冲区
    Buffer _out_buffer;             // 发送缓冲区
    Any _context;                   // 请求的上下文
    // 由组件使用者自定义的回调函数
    using ConnectionCallback = std::function<void(const PTRConnection &)>;
    using MessageCallback = std::function<void(const PTRConnection &, Buffer *)>;
    using CloseCallback = std::function<void(const PTRConnection &)>;
    using AnyEventCallback = std::function<void(const PTRConnection &)>;
    ConnectionCallback _conn_cb;  // 连接建立回调函数
    MessageCallback _msg_cb;      // 消息到达回调函数
    CloseCallback _close_cb;      // 连接关闭回调函数
    AnyEventCallback _event_cb;   // 任何事件触发回调函数
    // 组件内关闭连接所用的回调函数,服务器将所有连接管理起来,当连接关闭时需要将其从服务器的连接列表中移除
    CloseCallback _server_close_cb;

private:
    // 处理从socket套接字读取数据事件,需要将数据从socket中取出然后放入到接收缓冲区中
    void HandleRead() {
        char buffer[65536];
        int ret = _socket.NonBlockRecv(buffer, sizeof(buffer));
        if (ret <= 0) {
            if (ret == 0)
                return Release();
            if (ret == -1)
                return;
            return ShutdownInLoop();
        }
        // 将读取到的数据放入到接收缓冲区中
        buffer[ret] = '\0';
        _in_buffer.WriteAndPush(buffer, ret);
        // 调用消息到达回调函数
        if (_msg_cb)
            _msg_cb(shared_from_this(), &_in_buffer);
    }
    // 处理向socket套接字写入数据事件,需要将发送缓冲区中的数据发送到socket中
    void HandleWrite() {
        ssize_t ret = _socket.NonBlockSend(_out_buffer.ReadPosition(), _out_buffer.ReadAbleSize());
        if (ret < 0) {
            if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                // 无法发送数据,等待下一次写事件触发
                return;
            }
            // 发送数据出错,关闭连接
            return Release();
        }
        _out_buffer.MoveReadOffset(ret);
        // 判断发送缓冲区中是否还有数据需要发送
        if (_out_buffer.ReadAbleSize() == 0) {
            // 关闭写事件的监控,否则会一致触发写事件因为存在空闲区域
            _channel.DisableWrite();
            if (_conn_status == DISCONNECTING) {
                // 说明是用户调用了关闭连接操作,并且数据已经发送完毕,可以关闭连接了
                return Release();
            }
        }
    }
    void HandleClose() {
        // 处理剩余未处理数据
        if (_in_buffer.ReadAbleSize() > 0) {
            _msg_cb(shared_from_this(), &_in_buffer);
        }
        // 调用EventLoop内释放函数
        return Release();
    }
    void HandleError() { return HandleClose(); }
    void HandleEvent() {
        // 刷新事件活跃度
        if (_enable_inactive_release == true)
            _loop->TimerRefresh(_timer_id);
        // 调用用户设置的任意事件回调函数
        if (_event_cb)
            _event_cb(shared_from_this());
    }
    void EstablishInLoop() {
        // 1.修改连接状态 2.启动读事件 3.调用连接建立回调函数
        _conn_status = CONNECTED;
        _channel.EnableRead();
        if (_conn_cb)
            _conn_cb(shared_from_this());
    }
    // 真正的释放连接
    void ReleaseInLoop() {
        // 1.修改连接状态 2.移除连接的事件监控 3.关闭socket文件描述符 4.如果有定时器就将其移除 5.调用连接关闭回调函数
        // 6.调用服务器的关闭连接回调函数
        if (_conn_status == DISCONNECTED)
            return;
        _conn_status = DISCONNECTED;
        _channel.Remove();
        _socket.Close();
        if (_loop->HasTimer(_timer_id))
            _loop->TimerCancel(_timer_id);
        if (_close_cb)
            _close_cb(shared_from_this());
        if (_server_close_cb) {  // _server_close_cb
            _server_close_cb(shared_from_this());
        }
        // 在这里打印
    }
    // 不是实际的发送,而是将数据放入发送缓冲区并且启动写事件监控
    void SendInLoop(const char *data, size_t len) {
        // 1.将数据放入发送缓冲区 2.启动写事件监控
        // 防止数据是临时数据,及时将数据进行保存
        Buffer buf;
        buf.WriteAndPush(data, len);
        _out_buffer.WriteBufferAndPush(buf);
        _channel.EnableWrite();
    }
    // 不是实际的关闭连接,而是修改连接状态并且等待发送缓冲区的数据发送完毕和接收缓冲区的数据处理完毕后再关闭连接
    void ShutdownInLoop() {
        // 1.设置连接状态 2.处理接收缓冲区数据 3.处理发送缓冲区数据
        if (_conn_status == DISCONNECTING || _conn_status == DISCONNECTED)
            return;
        _conn_status = DISCONNECTING;
        if (_in_buffer.ReadAbleSize() > 0)
            _msg_cb(shared_from_this(), &_in_buffer);
        if (_out_buffer.ReadAbleSize() == 0)
            Release();  // 说明发送缓冲区没有数据可以发送,直接关闭连接
        else
            _channel.EnableWrite();  // 说明发送缓冲区还有数据需要发送,启动写事件监控
    }
    // 启动非活跃连接释放定时器
    void EnableInactiveReleaseInLoop(int sec) {
        // 1.修改对应的标志位 2.如果存在定时器就刷新 3.如果不存在就直接新增
        _enable_inactive_release = true;
        if (_loop->TimerCheck(_timer_id)) {
            _loop->TimerRefresh(_timer_id);
        } else {
            _loop->TimerAdd(_timer_id, sec, std::bind(&Connection::ReleaseInLoop, this));
        }
    }
    void CancelInactiveReleaseInLoop() {
        // 1.修改对应的标志位 2.如果存在定时器就取消
        _enable_inactive_release = false;
        if (_loop->TimerCheck(_timer_id)) {
            _loop->TimerCancel(_timer_id);
        }
    }
    void UpgradeInLoop(
        const Any &context, const ConnectionCallback &conn_cb, const MessageCallback &msg_cb,
        const CloseCallback &close_cb, const AnyEventCallback &event_cb) {
        _context = context;
        _conn_cb = conn_cb;
        _msg_cb = msg_cb;
        _close_cb = close_cb;
        _event_cb = event_cb;
    }

public:
    Connection(EventLoop *loop, uint64_t conn_id, int sockfd)
            : _conn_id(conn_id)
            , _timer_id(conn_id)
            , _enable_inactive_release(false)
            , _sockfd(sockfd)
            , _loop(loop)
            , _conn_status(CONNECTING)
            , _socket(sockfd)
            , _channel(sockfd, loop) {
        _channel.SetReadCallback(std::bind(&Connection::HandleRead, this));
        _channel.SetWriteCallback(std::bind(&Connection::HandleWrite, this));
        _channel.SetCloseCallback(std::bind(&Connection::HandleClose, this));
        _channel.SetErrorCallback(std::bind(&Connection::HandleError, this));
        _channel.SetEventCallback(std::bind(&Connection::HandleEvent, this));
    }
    ~Connection() { LOG_INFO("Connection %lu destructed", _conn_id); }
    int Fd() { return _sockfd; }
    uint64_t GetConnId() { return _conn_id; }
    bool Connected() { return _conn_status == CONNECTED; }
    void SetContext(const Any &context) { _context = context; }
    Any *GetContext() { return &_context; }
    void SetConnectionCallback(const ConnectionCallback &cb) { _conn_cb = cb; }
    void SetMessageCallback(const MessageCallback &cb) { _msg_cb = cb; }
    void SetCloseCallback(const CloseCallback &cb) { _close_cb = cb; }
    void SetAnyEventCallback(const AnyEventCallback &cb) { _event_cb = cb; }
    void SetSvrCloseCallback(const CloseCallback &cb) { _server_close_cb = cb; }
    void Established() { _loop->RunInLoop(std::bind(&Connection::EstablishInLoop, this)); }
    void Send(const char *data, size_t len) {
        _loop->RunInLoop(std::bind(&Connection::SendInLoop, this, data, len));
    }
    void ShutDown()  // 关闭连接,但需要检查是否有数据未发送完毕
    {
        _loop->RunInLoop(std::bind(&Connection::ShutdownInLoop, this));
    }
    // 确保连接释后事件事件不在会被触发,防止前面的事件为定时器事件然后导致连接被释放从而影响后面连接事件触发时候访问非法内存
    void Release() { _loop->QueueInLoop(std::bind(&Connection::ReleaseInLoop, this)); }
    void EnableInactiveRelease(int sec) {
        _loop->RunInLoop(std::bind(&Connection::EnableInactiveReleaseInLoop, this, sec));
    }
    void CancelInactiveRelease() {
        _loop->RunInLoop(std::bind(&Connection::CancelInactiveReleaseInLoop, this));
    }
    // 协议切换------重置上下文和回调函数,该函数必须在对应的EventLoop线程中调用
    // 防备新事件触发时候切换任务还未执行,从而导致数据和协议不对等!
    void Upgrade(
        const Any &context, const ConnectionCallback &conn_cb, const MessageCallback &msg_cb,
        const CloseCallback &close_cb, const AnyEventCallback &event_cb) {
        _loop->AssertInLoop();
        _loop->RunInLoop(
            std::bind(&Connection::UpgradeInLoop, this, context, conn_cb, msg_cb, close_cb, event_cb));
    }
};

Acceptor模块

该模块负责管理监听套接字。主要功能是创建监听套接字、启动对应描述符的可读事件、处理新达到的连接、对新到达的连接进行处理。

cpp 复制代码
class Acceptor {
private:
    Socket _listenfd;
    EventLoop *_loop;
    Channel _channel;

    using AccptorCallback = std::function<void(int)>;
    AccptorCallback _accept_cb;

private:
    int CreateSvrFd(uint16_t port) {
        bool ret = _listenfd.CreateServer(port);
        assert(ret == true);
        return _listenfd.Fd();
    }
    void HandleRead() {
        while (true) {
            //_listenfd.SetNonBlock();
            int newfd = _listenfd.Accept();
            if (newfd < 0) {
                if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                    // 没有新的连接请求
                    break;
                }
                LOG_ERROR("ACCEPT ERROR,%s", strerror(errno));
                break;
            }
            if (_accept_cb)
                _accept_cb(newfd);
        }
    }

public:
    Acceptor(EventLoop *loop, uint16_t port)
            : _listenfd(CreateSvrFd(port))
            , _loop(loop)
            , _channel(_listenfd.Fd(), loop) {
        _channel.SetReadCallback(std::bind(&Acceptor::HandleRead, this));
    }
    void SetAcceptCallback(AccptorCallback cb) { _accept_cb = cb; }
    // Listen函数是确保在设置好回调函数后才开始监听连接请求,防止因为没有设置好回调函数而错过连接请求并且会有资源浪费(新的连接没有close)
    void Listen() { _channel.EnableRead(); }
};

LoopThread模块

该模块是为了更好的实现one thread one loop思想,让EventLoop和线程一一对应。

为什么不采用先创建EventLoop对象然后传入对应的从属线程中?

C++ 的"传入线程"不等于"安全迁移":要么需要移动/拷贝对象(很多包含 mutex、fd、unique_ptr 的类型不可拷贝/不宜移动),要么传指针/引用共享同一个对象(那就必须额外保证生命周期、发布/可见性、并发访问不数据竞争)。这些复杂度在 reactor 里一般不值得。因此我们使用先创建线程,然后在线程函数内执行创建EventLoop对象的方式,这样更可控。

cpp 复制代码
class LoopThread {
private:
    std::mutex _mutex;              // 互斥锁
    std::condition_variable _cond;  // 条件变量
    EventLoop *_loop;               // 事件循环指针
    std::thread _thread;            // 线程对象 新线程的入口函数
public:
    void ThreadFunc() {
        EventLoop loop;
        {
            std::unique_lock<std::mutex> lock(_mutex);
            _loop = &loop;
            // 有可能_loop还没有实例化就GetLoop被调用了,因此需要通知其创建成功
            _cond.notify_one();  // 通知创建成功
        }
        loop.Start();  // 启动事件循环
    }

public:
    LoopThread()
            : _loop(nullptr)
            , _thread(std::thread(&LoopThread::ThreadFunc, this)) {}
    EventLoop *GetLoop() {
        EventLoop *loop = nullptr;
        {
            std::unique_lock<std::mutex> lock(_mutex);
            _cond.wait(lock, [&]() { return _loop != nullptr; });
            loop = _loop;
        }
        return loop;
    }
};

LoopThreadPool模块

该模块是针对LoopThread创建的一个线程池,可以对所有的从属线程LoopThread进行管理。主要功能是以主线程的 EventLoop 作为 base_loop,在需要多线程时创建若干 LoopThread(每个线程内部各自运行一个独立的 EventLoop),并通过 GetNextLoop() 以 RR 轮询的方式选择下一个从属 EventLoop,从而把新连接/任务分发到不同的事件循环线程上处理,实现 Reactor 的多线程扩展;当线程数为 0 时则退化为所有工作都由主 EventLoop 处理。

cpp 复制代码
class LoopThreadPool {
private:
    int _thread_count;
    EventLoop *_base_loop;               // 主线程的事件循环
    std::vector<LoopThread *> _threads;  // 存储子线程对象
    std::vector<EventLoop *> _loops;     // 存储子线程的事件循环
    int _next_idx;                       // 下一个被选中的线程索引
public:
    LoopThreadPool(EventLoop *base_loop)
            : _base_loop(base_loop)
            , _thread_count(0)
            , _next_idx(0) {}
    void SetThreadCount(int thread_count) { _thread_count = thread_count; }
    // 创建从属线程
    void CreateDependedLoops() {
        if (_thread_count > 0) {
            _threads.resize(_thread_count);
            _loops.resize(_thread_count);
            for (int i = 0; i < _thread_count; ++i) {
                _threads[i] = new LoopThread();
                _loops[i] = _threads[i]->GetLoop();
            }
        }
    }
    // RR轮转获取下一个从属线程
    EventLoop *GetNextLoop() {
        if (_thread_count == 0) {
            return _base_loop;
        }
        _next_idx = (_next_idx + 1) % _thread_count;
        return _loops[_next_idx];
    }
};

TcpServer模块

TcpServer 把 TCP 的"接入、事件监听、线程分发、连接状态、收发缓冲、超时与生命周期"都收敛到 Reactor 框架里,而把协议/业务通过回调注入进去,从而实现高内聚的网络内核 + 可插拔的上层逻辑。

主从 Reactor 分工:主线程持有一个 base loop只负责监听/接入连接,accept 到新 fd 后把它"分发"给某个工作 EventLoop;工作线程(EventLoop)负责该连接后续所有 I/O 事件处理与回调执行。

连接抽象成 Connection 对象:每条 TCP 连接对应一个 Connection,内部把 "socket + channel(epoll 事件) +接收和发送缓冲区+ 状态机" 封装起来;读写/关闭/错误等事件被 Channel 触发后,转到 Connection 的HandleRead/HandleWrite/HandleClose等方法处理,避免业务层直接碰 epoll 细节。

回调驱动、业务与网络解耦:TcpServer 持有用户设置的连接建立/消息到达/连接关闭/任意事件回调(ConnectionCallback/MessageCallback/CloseCallback/AnyEventCallback),新连接建立时把这些回调注入到 Connection;这样框架只负责事件分发与资源管理,协议/业务逻辑由回调提供。

线程间任务投递靠 RunInLoop/QueueInLoop + eventfd 唤醒:EventLoop 里维护任务队列,跨线程调用通过 QueueInLoop 入队,并写 eventfd 唤醒 epoll_wait,保证"所有实际状态修改在所属 loop 线程执行",减少并发数据竞争风险。

连接生命周期统一管理(可选超时释放):TcpServer 用 unordered_map 维护所有活跃连接;Connection 释放时通过 _server_close_cb 回调回到 TcpServer 删除记录;同时通过 TimeWheel实现"非活跃连接定时释放"(连接任何事件触发会 TimerRefresh),把超时控制也收敛到 loop 线程语义内。

cpp 复制代码
class TcpServer {
    using PTRConnection = std::shared_ptr<Connection>;

private:
    int _next_id;
    int _thread_count;
    int _timeout;  // 非活跃连接释放时间
    bool _enable_inactive_release;
    uint16_t _port;
    EventLoop _base_loop;
    Acceptor _acceptor;
    LoopThreadPool _pool;
    std::unordered_map<uint64_t, PTRConnection> _conns;
    using ConnectionCallback = std::function<void(const PTRConnection &)>;
    using MessageCallback = std::function<void(const PTRConnection &, Buffer *)>;
    using CloseCallback = std::function<void(const PTRConnection &)>;
    using AnyEventCallback = std::function<void(const PTRConnection &)>;
    ConnectionCallback _conn_cb;  // 连接建立回调函数
    MessageCallback _msg_cb;      // 消息到达回调函数
    CloseCallback _close_cb;      // 连接关闭回调函数
    AnyEventCallback _event_cb;   // 任何事件触发回调函数
    using Functor = std::function<void()>;

private:
    void RunAfterInLoop(const Functor &cb, uint32_t delay) {
        _next_id++;
        _base_loop.TimerAdd(0, delay, cb);
    }
    void HandleNewConnection(int sockfd) {
        uint64_t conn_id = _next_id++;
        EventLoop *loop = _pool.GetNextLoop();
        PTRConnection conn(new Connection(loop, conn_id, sockfd));
        conn->SetConnectionCallback(_conn_cb);
        conn->SetMessageCallback(_msg_cb);
        conn->SetCloseCallback(_close_cb);
        conn->SetAnyEventCallback(_event_cb);
        conn->SetSvrCloseCallback(std::bind(&TcpServer::RemoveConnectionInLoop, this, std::placeholders::_1));
        if (_enable_inactive_release)
            conn->EnableInactiveRelease(_timeout);
        _conns.emplace(conn_id, conn);
        conn->Established();
    }
    void RemoveConnectionInLoop(const PTRConnection &conn) {
        LOG_DEBUG("RemoveConnectionInLoop called");
        uint64_t id = conn->GetConnId();
        if (_conns.find(id) == _conns.end()) {
            return;
        }
        _conns.erase(id);
    }
    void RemoveConnection(const PTRConnection &conn) {
        LOG_DEBUG("RemoveConnection called");
        _base_loop.RunInLoop(std::bind(&TcpServer::RemoveConnectionInLoop, this, conn));
    }

public:
    TcpServer(uint16_t port)
            : _next_id(0)
            , _thread_count(0)
            , _timeout(0)
            , _enable_inactive_release(false)
            , _port(port)
            , _base_loop()
            , _acceptor(&_base_loop, port)
            , _pool(&_base_loop) {
        _acceptor.SetAcceptCallback(std::bind(&TcpServer::HandleNewConnection, this, std::placeholders::_1));
    }
    void SetThreadCount(int thread_count) {
        _thread_count = thread_count;
        _pool.SetThreadCount(thread_count);
    }
    void SetInactiveReleaseTime(int timeout) {
        _enable_inactive_release = true;
        _timeout = timeout;
    }
    void SetConnectionCallback(const ConnectionCallback &cb) { _conn_cb = cb; }
    void SetMessageCallback(const MessageCallback &cb) { _msg_cb = cb; }
    void SetCloseCallback(const CloseCallback &cb) { _close_cb = cb; }
    void SetAnyEventCallback(const AnyEventCallback &cb) { _event_cb = cb; }
    void EnableInactiveRelease(int timeout) {
        _timeout = timeout;
        _enable_inactive_release = true;
    }
    // 添加定时任务
    void RunAfter(const Functor &cb, uint32_t delay) {
        _base_loop.RunInLoop(std::bind(&TcpServer::RunAfterInLoop, this, cb, delay));
    }
    void Start() {
        LOG_INFO("Server starting...");
        _pool.CreateDependedLoops();
        _acceptor.Listen();
        _base_loop.Start();
    }
};

HTTP模块

http模块主要分为以下几块:

1.Util

工具类集合:提供 URL 编解码、字符串分割、文件读写、状态码描述(StatusDesc)、根据后缀推导 MIME(ExternMime)、判断路径/文件类型(IsDir/IsRegular)、以及路径合法性校验(ValidPath,用于防目录穿越)。

2.HttpRequest

HTTP 请求数据模型:保存方法/URL/版本/正文、请求头(_headers)、查询参数(_params),并提供 ContentLength()、Close() 等便捷判断;_matches 用于路由正则匹配后提取分组结果。

3.HttpResponse

HTTP 响应数据模型:保存状态码、重定向信息、响应正文、响应头;提供 SetContent()(填充 body 并设置长度/类型)、SetRedirect()(设置 Location 等)、Close()(是否短连接)等。

4.HttpContext

"连接级别"的 HTTP 解析上下文/状态机:把一条 TCP 连接上的字节流按 RECV_STATE_LINE/HEAD/BODY/OVER 分阶段解析成一个完整 HttpRequest;同时维护解析失败时的响应状态码(如 400/414),并支持 ReSet() 以便复用在 keep-alive 的下一次请求解析。

5.HttpServer

基于底层 TcpServer 的 HTTP 应用层封装:通过 OnConnected() 给每个连接绑定 HttpContext。通过 OnMessage() 驱动 HttpContext 解析缓冲区数据:解析未完成就继续等,解析完成就进行路由与响应构造,并用 WriteResponse() 按 HTTP 格式拼响应后发送。提供路由注册接口 Get/Post/Put/Delete:内部用"正则 + handler"的表做匹配分发。支持静态资源:SetBaseDir() 后,GET/HEAD 且路径合法时走 FileHandler() 读取文件并设置 MIME。连接管理策略:根据请求/响应的 Connection 头决定是否 ShutDown();并启用底层的非活跃连接超时释放。

Util模块

本模块主要是为HTTP解析与静态资源服务提供工具函数,包括字符串处理、URL编码与解码、文件读写、状态码映射、路径合法性校验等。

cpp 复制代码
class Util {
public:
    static int HexToI(int c) {
        if (c >= '0' && c <= '9') {
            return c - '0';
        } else if (c >= 'a' && c <= 'f') {
            return c - 'a' + 10;
        } else if (c >= 'A' && c <= 'F') {
            return c - 'A' + 10;
        } else {
            return -1;
        }
    }
    // 字符串分割
    // 将str按照sep进行分割,结果存入out中,返回分割出的字符串个数
    static size_t SplitString(const std::string str, const std::string sep, std::vector<std::string> *out) {
        size_t offset = 0;
        while (offset < str.size()) {
            size_t pos = str.find(sep, offset);
            if (pos == std::string::npos) {
                // 说明找到了字符串的末尾
                out->push_back(str.substr(offset, str.size() - offset));
                return out->size();
            }
            if (pos != offset)  // 如果内容为空就不加入到结果中
                out->push_back(str.substr(offset, pos - offset));
            offset = pos + sep.size();  // 移动偏移量
        }
        return out->size();
    }
    // 读取文件内容 将文件的内容放入一个Buffer对象中
    static bool ReadFile(const std::string &filename, std::string *buf) {
        std::ifstream ifs(filename, std::ios::binary);
        if (!ifs.is_open()) {
            LOG_INFO("ReadFile %s failed", filename.c_str());
            return false;
        }
        // 获取文件的大小然后一次性读取
        ifs.seekg(0, ifs.end);            // 将读写位置跳转到文件末尾
        size_t file_size = ifs.tellg();   // 获取当前位置与文件开头的偏移量即为文件大小
        ifs.seekg(0, ifs.beg);            // 将读写位置跳转到文件开头
        buf->resize(file_size);           // 预留空间
        ifs.read(&(*buf)[0], file_size);  // 读取文件内容
        if (ifs.good() == false) {
            LOG_INFO("ReadFile %s failed during read", filename.c_str());
            ifs.close();
            return false;
        }
        ifs.close();
        return true;
    }
    // 向文件中写入内容
    static bool WriteFile(const std::string &filename, const std::string &buf) {
        std::ofstream ofs(filename, std::ios::binary | std::ios::trunc);  // 以二进制写入且清空原有内容
        if (ofs.is_open() == false) {
            LOG_INFO("WriteFile %s failed", filename.c_str());
            return false;
        }
        ofs.write(buf.data(), buf.size());
        if (ofs.good() == false) {
            LOG_INFO("WriteFile %s failed during write", filename.c_str());
            ofs.close();
            return false;
        }
        ofs.close();
        return true;
    }
    // Url编码
    // RFC3986文档 绝对不编码字符 数字 大小写字母 - _ . ~
    // W3C标准规定 空格需要转化为+
    // RFC2396文档规定Url保留字符需要转化为%HH格式
    static std::string UrlEncode(const std::string &url, bool convert_sep_to_plus) {
        std::string result;
        for (auto &c : url) {
            if (c == '-' || c == '_' || c == '.' || c == '~'
                || isalnum(c))  // isalnum可以判断字符是否为数字或字母
            {
                // 不编码字符
                result += c;
                continue;
            } else if (c == ' ' && convert_sep_to_plus) {
                // 空格转化为+
                result += '+';
            } else {
                // 其他字符转化为%HH格式
                char buf[4] = {0};
                snprintf(buf, 4, "%%%02X", c);
                result += buf;
            }
        }
        return result;
    }
    // Url解码
    static std::string UrlDecode(const std::string &url, bool convert_plus_to_space) {
        std::string result;
        for (int i = 0; i < url.size(); i++) {
            if (url[i] == '%' && i + 2 < url.size()) {
                int high = HexToI(url[i + 1]);
                int low = HexToI(url[i + 2]);
                if (high != -1 && low != -1) {
                    char decoded_char = (high << 4) | low;
                    result += decoded_char;
                    i += 2;  // 跳过已经处理的两个字符
                }
            } else if (url[i] == '+' && convert_plus_to_space == true) {
                // + 转化为空格
                result += ' ';
            } else {
                result += url[i];
            }
        }
        return result;
    }
    // 获取文件状态
    static std::string StatusDesc(int code) {
        std::unordered_map<int, std::string> status_map = {
            {100, "Continue"},
            {101, "Switching Protocols"},
            {200, "OK"},
            {201, "Created"},
            {202, "Accepted"},
            {203, "Non-Authoritative Information"},
            {204, "No Content"},
            {205, "Reset Content"},
            {206, "Partial Content"},
            {300, "Multiple Choices"},
            {301, "Moved Permanently"},
            {302, "Found"},
            {303, "See Other"},
            {304, "Not Modified"},
            {305, "Use Proxy"},
            {307, "Temporary Redirect"},
            {400, "Bad Request"},
            {401, "Unauthorized"},
            {402, "Payment Required"},
            {403, "Forbidden"},
            {404, "Not Found"},
            {405, "Method Not Allowed"},
            {406, "Not Acceptable"},
            {407, "Proxy Authentication Required"},
            {408, "Request Time-out"},
            {409, "Conflict"},
            {410, "Gone"},
            {411, "Length Required"},
            {412, "Precondition Failed"},
            {413, "Request Entity Too Large"},
            {414, "Request-URI Too Large"},
            {415, "Unsupported Media Type"},
            {416, "Requested range not satisfiable"},
            {417, "Expectation Failed"},
            {500, "Internal Server Error"},
            {501, "Not Implemented"},
            {502, "Bad Gateway"},
            {503, "Service Unavailable"},
            {504, "Gateway Time-out"},
            {505, "HTTP Version not supported"}};
        auto it = status_map.find(code);
        if (it == status_map.end()) {
            return "Unknown Status";
        }
        return it->second;
    }
    // 获取文件Mime
    static std::string ExternMime(const std::string &filename) {
        size_t pos = filename.find_last_of('.');
        if (pos == std::string::npos)
            return "application/octet-stream";  // 默认二进制流
        std::string ext = filename.substr(pos + 1);
        static std::unordered_map<std::string, std::string> mime_map = {
            // 文本与文档类型
            {"txt", "text/plain"},
            {"html", "text/html"},
            {"htm", "text/html"},
            {"css", "text/css"},
            {"csv", "text/csv"},
            {"xml", "text/xml"},

            // 图像类型
            {"jpg", "image/jpeg"},
            {"jpeg", "image/jpeg"},
            {"png", "image/png"},
            {"gif", "image/gif"},
            {"bmp", "image/bmp"},
            {"webp", "image/webp"},
            {"svg", "image/svg+xml"},
            {"ico", "image/x-icon"},

            // 应用程序与二进制文件
            {"pdf", "application/pdf"},
            {"json", "application/json"},
            {"zip", "application/zip"},
            {"rar", "application/vnd.rar"},
            {"7z", "application/x-7z-compressed"},
            {"tar", "application/x-tar"},
            {"gz", "application/gzip"},
            {"exe", "application/x-msdownload"},
            {"doc", "application/msword"},
            {"docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"},
            {"xls", "application/vnd.ms-excel"},
            {"xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"},
            {"ppt", "application/vnd.ms-powerpoint"},
            {"pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation"},
            {"odt", "application/vnd.oasis.opendocument.text"},

            // 音频与视频类型
            {"mp3", "audio/mpeg"},
            {"wav", "audio/wav"},
            {"ogg", "audio/ogg"},
            {"aac", "audio/aac"},
            {"mp4", "video/mp4"},
            {"avi", "video/x-msvideo"},
            {"mpeg", "video/mpeg"},
            {"mpg", "video/mpeg"},
            {"webm", "video/webm"},
            {"ogv", "video/ogg"},

            // 字体类型
            {"ttf", "font/ttf"},
            {"otf", "font/otf"},
            {"woff", "font/woff"},
            {"woff2", "font/woff2"},

            // JavaScript
            {"js", "application/javascript"},
            {"mjs", "application/javascript"}  // 可以根据需要添加更多的扩展名和对应的MIME类型
        };
        auto it = mime_map.find(ext);
        if (it == mime_map.end())
            return "application/octet-stream";  // 默认二进制流
        return it->second;
    }
    // 判断该文件是否是目录
    static bool IsDir(const std::string &filename) {
        struct stat path_stat;
        if (stat(filename.c_str(), &path_stat) != 0) {
            return false;  // 无法获取文件状态,可能文件不存在
        }
        return S_ISDIR(path_stat.st_mode);
    }
    // 判断该文件是否是普通文件
    static bool IsRegular(const std::string &filename) {
        struct stat path_stat;
        if (stat(filename.c_str(), &path_stat) != 0) {
            return false;  // 无法获取文件状态,可能文件不存在
        }
        return S_ISREG(path_stat.st_mode);
    }
    // 判断文件路径是否有效
    // 用户只能在相对于某个根目录的范围内访问文件,防止目录遍历攻击
    static bool ValidPath(const std::string &path) {
        // 按照/来分割路径
        std::vector<std::string> subdirs;
        int ret = SplitString(path, "/", &subdirs);
        int level = 0;
        for (auto &dir : subdirs) {
            if (dir == "..") {
                level--;
                if (level < 0)
                    return false;  // 试图访问根目录之外
            } else if (dir == "." || dir.empty()) {
                // 当前目录或空字符串,忽略
                continue;
            } else
                level++;
        }
        return true;
    }
};

HttpRequest模块

该模块的职责是解析Http请求并将其对应的字段解析并存储起来,提供的方法包括头部字段的查询、插入和获取,Query参数的查询、插入和获取,正文长度获取,是否是长连接等。

cpp 复制代码
class HttpRequest {
public:
    std::string _method;
    std::string _url;
    std::string _version;
    std::string _body;
    std::smatch _matches;
    std::unordered_map<std::string, std::string> _headers;
    std::unordered_map<std::string, std::string> _params;

public:
    HttpRequest()
            : _version("HTTP/1.1") {}
    // 重置
    void Reset() {
        _method.clear();
        _url.clear();
        _version = "HTTP/1.1";
        _body.clear();
        std::smatch tmp;
        _matches.swap(tmp);
        _headers.clear();
        _params.clear();
    }
    // 插入头部字段
    void SetHeader(const std::string &key, const std::string &value) { _headers[key] = value; }
    // 检查头部字段
    bool HasHeader(const std::string &key) const {
        auto it = _headers.find(key);
        if (it == _headers.end())
            return false;
        return true;
    }
    // 获取头部字段的值
    std::string GetHeader(const std::string &key) const {
        auto it = _headers.find(key);
        if (it != _headers.end())
            return it->second;
        return "";
    }
    // 插入查询字符串
    void SetParam(const std::string &key, const std::string &value) { _params[key] = value; }
    // 检查查询字符串
    bool HasParam(const std::string &key) const {
        auto it = _params.find(key);
        if (it == _params.end())
            return false;
        return true;
    }
    // 获取查询字符串的值
    std::string GetParam(const std::string &key) const {
        auto it = _params.find(key);
        if (it != _params.end())
            return it->second;
        return "";
    }
    // 获取正文长度
    size_t ContentLength() const {
        if (HasHeader("Content-Length") == false) {
            return 0;
        }
        return std::stoul(GetHeader("Content-Length"));
    }
    // 判断是否是长连接
    bool Close() {
        if (HasHeader("Connection") && GetHeader("Connection") == "keep-alive")
            return false;  // 长连接
        return true;
    }
};

HttpResponse模块

该模块负责存储HTTP响应的要素(响应状态码、受否重定向、重定向url、正文),提供简单的功能接口,包括头部字段的设置、检查和获取,正文设置,重定向url设置和长短连接判断。

cpp 复制代码
class HttpResponse {
public:
    int _status_code;           // 相应状态码
    bool _redirect_flag;        // 是否重定向
    std::string _redirect_url;  // 重定向地址
    std::string _body;          // 相应正文
    std::unordered_map<std::string, std::string> _headers;

public:
    HttpResponse()
            : _status_code(200)
            , _redirect_flag(false) {}
    HttpResponse(int code)
            : _status_code(code)
            , _redirect_flag(false) {}
    void ReSet() {
        _status_code = 200;
        _redirect_flag = false;
        _redirect_url.clear();
        _body.clear();
        _headers.clear();
    }
    void SetHeader(const std::string &key, const std::string &value) { _headers[key] = value; }
    bool HasHeader(const std::string &key) {
        auto it = _headers.find(key);
        if (it == _headers.end())
            return false;
        return true;
    }
    std::string GetHeader(const std::string &key) {
        auto it = _headers.find(key);
        if (it != _headers.end())
            return it->second;
        return "";
    }
    void SetContent(const std::string &body, const std::string &content_type) {
        _body = body;
        SetHeader("Content-Length", std::to_string(_body.size()));
        SetHeader("Content-Type", content_type);
    }
    void SetRedirect(const std::string &url, int status_code = 302) {
        _redirect_flag = true;
        _redirect_url = url;
        _status_code = status_code;
        SetHeader("Location", url);
    }
    bool Close() {
        if (HasHeader("Connection") && GetHeader("Connection") == "keep-alive")
            return false;  // 长连接
        return true;
    }
};

HttpContext模块

本模块是处理对端发送给服务端的数据,采取状态机的策略来把处理分为四个截断,分别是处理请求行、处理请求头、处理请求正文和处理完毕。只有完成上一个阶段才可以进入下一个阶段,否则在下个阶段会被检测处上个截断并未处理完成从而返回处理错误。并且该模块实现了粘包/半包友好,如果数据不够可以等待数组足够再重新进行处理。

cpp 复制代码
typedef enum {
    RECV_STATE_ERROR,
    RECV_STATE_LINE,
    RECV_STATE_HEAD,
    RECV_STATE_BODY,
    RECV_STATE_OVER
} HttpRecvState;
#define MAX_LINE 8192
class HttpContext {
private:
    int _resp_state;            // 响应状态码
    HttpRecvState _recv_state;  // 接收状态码
    HttpRequest _request;

private:
    bool RecvHttpLine(Buffer *buf)  // 从Buffer中读取数据
    {
        std::string line = buf->GetLine();
        if (line.size() == 0)  // 说明没有读取到数据
        {
            // 存在数据但是没有换行符
            if (buf->ReadAbleSize() > MAX_LINE) {
                // 请求行过于长,不符合要求
                _resp_state = 414;
                _recv_state = RECV_STATE_ERROR;
                return false;
            }
        }
        buf->MoveReadOffset(line.size());
        if (line.size() > MAX_LINE) {
            // 说明请求行过长
            _resp_state = 414;
            _recv_state = RECV_STATE_ERROR;
            return false;
        }
        return ParseHttpLine(line);
    }
    bool ParseHttpLine(std::string &line) {
        std::smatch matches;
        // std::regex e("(GET|HEAD|POST|PUT|DELETE) ([^?]*)(?:\\?(.*))? (HTTP/1.\\.[01])(?:\n|\r\n)?");  //
        // 忽略大小写
        std::regex e(
            R"(^(GET|HEAD|POST|PUT|DELETE)\s+([^\s\?]+)(?:\?([^\s#]*))?\s+(HTTP/1\.[01])\r?\n?$)",
            std::regex::icase);
        bool ret = regex_match(line, matches, e);
        if (ret == false) {
            _resp_state = 400;
            _recv_state = RECV_STATE_ERROR;
            return false;
        }
        // 0 原url
        // 1 GET 方法
        // 2 url
        // 3 key=val&key=val......
        // 4 HTTP版本
        _request._method = matches[1];
        std::transform(_request._method.begin(), _request._method.end(), _request._method.begin(), ::toupper);
        _request._url = Util::UrlDecode(matches[2], false);
        _request._version = matches[4];
        std::vector<std::string> query_string_array;
        std::string query_string = matches[3];
        // 将键值对分组方便分离key val
        Util::SplitString(query_string, "&", &query_string_array);
        for (auto &str : query_string_array) {
            size_t pos = str.find("=");
            if (pos == std::string::npos) {
                _resp_state = 400;
                _recv_state = RECV_STATE_ERROR;
                return false;
            }
            std::string key = Util::UrlDecode(str.substr(0, pos), true);
            std::string val = Util::UrlDecode(str.substr(pos + 1), true);
            _request.SetParam(key, val);
        }
        _recv_state = RECV_STATE_HEAD;
        return true;
    }
    bool RecvHttpHead(Buffer *buf) {
        // 一行一行取出数据即可
        if (_recv_state != RECV_STATE_HEAD)
            return false;
        while (true) {
            std::string line = buf->GetLine();
            if (line.size() == 0)  // 说明没有读取到数据
            {
                // 存在数据但是没有换行符
                if (buf->ReadAbleSize() > MAX_LINE) {
                    // 请求行过于长,不符合要求
                    _resp_state = 414;
                    _recv_state = RECV_STATE_ERROR;
                    return false;
                }
            }
            buf->MoveReadOffset(line.size());
            if (line.size() > MAX_LINE) {
                // 说明请求行过长
                _resp_state = 414;
                _recv_state = RECV_STATE_ERROR;
                return false;
            }
            if (line == "\n" || line == "\r\n") {
                _recv_state = RECV_STATE_BODY;
                return true;  // 说明头部读取完毕
            }
            int ret = ParseHttpHead(line);
            if (ret == false)
                return false;
        }
        return true;
    }
    // 解析请求头
    bool ParseHttpHead(std::string &line) {
        if (line.back() == '\n')
            line.pop_back();
        if (line.back() == '\r')
            line.pop_back();
        size_t pos = line.find(": ");
        if (pos == std::string::npos) {
            _resp_state = 400;
            _recv_state = RECV_STATE_ERROR;
            return false;
        }
        std::string key = line.substr(0, pos);
        std::string val = line.substr(pos + 2);
        _request.SetHeader(key, val);
        //_recv_state = RECV_STATE_BODY;
        return true;
    }
    bool RecvHttpBody(Buffer *buf) {
        if (_recv_state != RECV_STATE_BODY)
            return false;
        // 1.判断该请求是否有正文
        if (_request.ContentLength() == 0) {
            _recv_state = RECV_STATE_OVER;
            return true;
        }
        // 2.存在正文 需要判断Buffer中是否存在足够的数据读取
        // 用ContentLength对应的长度减去目前正文就是还需要读取的正文长度
        size_t rest_len = _request.ContentLength() - _request._body.size();
        if (buf->ReadAbleSize() >= rest_len) {
            // Buffer中的数据大小大于需要读取的正文长度,说明可以读完
            _request._body.append(buf->ReadPosition(), rest_len);
            buf->MoveReadOffset(rest_len);
            _recv_state = RECV_STATE_OVER;
            return true;
        }
        // 说明正文还没有完整读取,需要等待下次数据到达,接收状态保持为RECV_STATE_BODY
        _request._body.append(buf->ReadPosition(), buf->ReadAbleSize());
        buf->MoveReadOffset(buf->ReadAbleSize());
        return true;
    }

public:
    HttpContext()
            : _resp_state(200)
            , _recv_state(RECV_STATE_LINE) {}
    void ReSet() {
        _resp_state = 200;
        _recv_state = RECV_STATE_LINE;
        _request.Reset();
    }
    int RespState() { return _resp_state; }            // 获取请求的响应状态码
    HttpRecvState RecvState() { return _recv_state; }  // 获取当前处理进度
    HttpRequest &Request() { return _request; }        // 获取http请求
    // 接收并解析http请求
    bool RecvHttpRequest(Buffer *buf) {
        // 无需使用break,因为需要顺序向下执行
        switch (_recv_state) {
        case RECV_STATE_LINE:
            RecvHttpLine(buf);
        case RECV_STATE_HEAD:
            RecvHttpHead(buf);
        case RECV_STATE_BODY:
            RecvHttpBody(buf);
        case RECV_STATE_OVER:
            break;
        case RECV_STATE_ERROR:
            return false;
        }
        return true;
    }
};

HttpServer模块

本模块把 TCP 连接上的字节流包装成 HTTP 服务"的顶层组件:负责连接生命周期、HTTP 解析驱动、路由分发、静态资源托管,以及把 HttpResponse 序列化成标准 HTTP/1.1 响应发送出去。

cpp 复制代码
#define DEFAULT_TIMEOUT 10
class HttpServer {
    using PtrConnection = std::shared_ptr<Connection>;

private:
    TcpServer _server;
    using Handler = const std::function<void(const HttpRequest &, HttpResponse *)>;
    using Handlers = std::vector<std::pair<std::regex, Handler>>;
    Handlers _get_route;
    Handlers _post_route;
    Handlers _put_route;
    Handlers _delete_route;
    std::string _base_dir;

private:
    void ErrorHandler(const HttpRequest &req, HttpResponse *resp) {
        // 1.组织错误页面内容
        std::string body;
        body += "<html>";
        body += "<head>";
        body += "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=UTF-8\" />";
        body += "</head>";
        body += "<body>";
        body += "<h1>";
        body += std::to_string(resp->_status_code) + " " + Util::StatusDesc(resp->_status_code);
        body += "</h1>";
        body += "<hr/>";
        body += "</body>";
        body += "</html>";
        // 2.为响应对象填充内容
        resp->SetContent(body, "text/html");
    }
    // 对功能性事件进行派发
    void Dispatcher(HttpRequest &req, HttpResponse *resp, Handlers &handlers) {
        // 在对应的方法路由表中进行查找,如果存在匹配的路由则调用对应的处理函数,反之则返回404
        for (auto &handler : handlers) {
            const std::regex re = handler.first;
            const Handler &functor = handler.second;
            bool ret = std::regex_match(req._url, req._matches, re);  // 将匹配结果存入req._matches中
            if (ret == false)
                continue;
            return functor(req, resp);
        }
        resp->_status_code = 404;  // 未找到
        return;
    }
    // 路由
    void Route(HttpRequest &req, HttpResponse *resp) {
        // 1.对请求进行分辨,判断是静态资源请求还是功能性请求
        if (IsFileHandler(req) == true) {
            return FileHandler(req, resp);
        }
        if (req._method == "GET") {
            return Dispatcher(req, resp, _get_route);
        } else if (req._method == "POST") {
            return Dispatcher(req, resp, _post_route);
        } else if (req._method == "PUT") {
            return Dispatcher(req, resp, _put_route);
        } else if (req._method == "DELETE") {
            return Dispatcher(req, resp, _delete_route);
        }

        resp->_status_code = 405;  // 方法不被允许
        return;
    }
    bool IsFileHandler(const HttpRequest &req) {
        // 静态资源的路径必须已经设置
        if (_base_dir.empty() == true)
            return false;
        // 判断请求的方法
        if (req._method != "GET" && req._method != "HEAD")
            return false;
        // 判断请求的url是否合法
        if (Util::ValidPath(req._url) == false)
            return false;
        std::string real_path = _base_dir + req._url;
        if (!real_path.empty() && real_path.back() == '/')
            real_path += "index.html";
        // 判断该路径是否存在且是一个普通文件
        if (Util::IsRegular(real_path) == false)
            return false;

        return true;
    }
    // 对静态资源获取
    void FileHandler(HttpRequest &req, HttpResponse *resp) {
        std::string req_path = _base_dir + req._url;
        if (req._url.back() == '/')
            req_path += "index.html";
        req._url = req_path;
        bool ret = Util::ReadFile(req._url, &resp->_body);
        if (ret == false) {
            resp->_status_code = 404;  // 未找到
            return;
        }
        std::string mime = Util::ExternMime(req._url);
        resp->SetHeader("Content-Type", mime);
        return;
    }
    // 生成http Response格式的内容进行发送
    void WriteResponse(const PtrConnection &conn, HttpRequest &req, HttpResponse *resp) {
        // 1.完善头部字段
        if (req.Close() == true) {
            resp->SetHeader("Connection", "close");
        } else {
            resp->SetHeader("Connection", "keep-alive");
        }
        if (resp->HasHeader("Content-Length") == false) {
            resp->SetHeader("Content-Length", std::to_string(resp->_body.size()));
        }
        if (resp->_body.empty() == false && resp->HasHeader("Content-Type") == false) {
            resp->SetHeader("Content-Type", "application/octet-stream");
        }
        if (resp->_redirect_flag == true && resp->HasHeader("Location") == false) {
            resp->SetHeader("Location", resp->_redirect_url);
        }
        // 2.根据http协议的格式来组织内容
        std::stringstream resp_str;
        resp_str << "HTTP/1.1 " << resp->_status_code << " " << Util::StatusDesc(resp->_status_code)
                 << "\r\n";
        for (auto &header : resp->_headers) { resp_str << header.first << ": " << header.second << "\r\n"; }
        resp_str << "\r\n";  // 头部和正文之间需要有一个空行
        resp_str << resp->_body;
        LOG_INFO("Response:%s", resp_str.str().c_str());
        // 3.发送Response
        conn->Send(resp_str.str().c_str(), resp_str.str().size());
    }
    // 设置上下文
    void OnConnected(const PtrConnection &conn) {
        conn->SetContext(HttpContext());
        LOG_INFO("A new Connection:%p", conn.get());
    }
    // 对缓冲区数据进行分析和处理
    void OnMessage(const PtrConnection &conn, Buffer *buf) {
        // 1.获取上下文
        HttpContext *context = conn->GetContext()->get<HttpContext>();
        // 2.通过上下文对缓冲区的数据进行解析
        context->RecvHttpRequest(buf);

        HttpResponse resp(context->RespState());
        HttpRequest &req = context->Request();
        if (context->RespState() >= 400) {
            // 说明出错
            ErrorHandler(req, &resp);  // 为错误页面填充信息
            WriteResponse(conn, req, &resp);
            context->ReSet();//重置上下文,防止下次请求受到影响
            buf->MoveReadOffset(buf->ReadAbleSize()); // 清空缓冲区数据
            conn->ShutDown();
            return;
        }
        if (context->RecvState() != RECV_STATE_OVER) {
            // 说明数据还未接收完毕,还需要继续接收新数据
            return;
        }
        // 3.请求路由 + 业务分配
        Route(req, &resp);
        // 4.对HttpResponse进行发送
        WriteResponse(conn, req, &resp);
        // 5.重置上下文
        context->ReSet();
        // 6.判断长短连接
        if (resp.Close() == true)
            conn->ShutDown();
    }

public:
    HttpServer(uint16_t port, int timeout = DEFAULT_TIMEOUT)
            : _server(port) {
        _server.SetConnectionCallback(std::bind(&HttpServer::OnConnected, this, std::placeholders::_1));
        _server.SetMessageCallback(
            std::bind(&HttpServer::OnMessage, this, std::placeholders::_1, std::placeholders::_2));
        _server.EnableInactiveRelease(timeout);
    }
    void Get(const std::string &pattern, Handler &Handler) {
        _get_route.push_back(make_pair(std::regex(pattern), Handler));
    }
    void Post(const std::string &pattern, Handler &Handler) {
        _post_route.push_back(make_pair(std::regex(pattern), Handler));
    }
    void Put(const std::string &pattern, Handler &Handler) {
        _put_route.push_back(make_pair(std::regex(pattern), Handler));
    }
    void Delete(const std::string &pattern, Handler &Handler) {
        _delete_route.push_back(make_pair(std::regex(pattern), Handler));
    }
    void SetBaseDir(const std::string &filepath) { _base_dir = filepath; }
    void SetThreadCount(int count) { _server.SetThreadCount(count); }
    void Listen() { _server.Start(); }
};
相关推荐
yesyesyoucan2 小时前
安全工具集:一站式密码生成、文件加密与二维码生成解决方案
服务器·mysql·安全
小豆子范德萨2 小时前
cursor连接远程window服务器的WSL-ubuntu
运维·服务器·ubuntu
lizz317 小时前
C++模板编程:从入门到精通
java·开发语言·c++
Queenie_Charlie8 小时前
HASH表
数据结构·c++·哈希算法
Xの哲學8 小时前
Linux grep命令:文本搜索的艺术与科学
linux·服务器·算法·架构·边缘计算
superman超哥9 小时前
仓颉语言中锁的实现机制深度剖析与并发实践
c语言·开发语言·c++·python·仓颉
夜月yeyue9 小时前
Linux 调度类(sched_class)
linux·运维·c语言·单片机·性能优化
郝学胜-神的一滴9 小时前
OpenGL的glDrawElements函数详解
开发语言·c++·程序人生·游戏·图形渲染
WBluuue9 小时前
AtCoder Beginner Contest 436(ABCDEF)
c++·算法