// Linux io_uring ring-buffer TUN/TAP driver (part 2)
#include <liburing.h>
#include <linux/if_tun.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/uio.h>
#include <fcntl.h>
#include <unistd.h>
#include <cstring>
#include <atomic>
#include <thread>
#include <functional>
#include <memory>
#include <chrono>
#include <cstdint>
#include <mutex>
#include <condition_variable>
#include <queue>
#include <cstdio>
#include <cassert>
#include "blockingconcurrentqueue.h"

using Byte = uint8_t;

// RAII wrapper that drives a TUN/TAP file descriptor through io_uring:
// a single background thread keeps a pool of reads in flight and drains
// a lock-free send queue. Received packets are delivered via a callback
// that must deep-copy the payload before returning (buffers are reused).
class IoUringTap {
public:
    // Duplicates user_tun_fd (caller keeps ownership of its fd) and starts
    // the ring thread; blocks until the thread reports success or failure.
    // Check IsValid() after construction.
    explicit IoUringTap(int user_tun_fd) noexcept;
    ~IoUringTap() noexcept;

    // Queues a packet for transmission. Zero-copy: the shared buffer is
    // held until the kernel completes the write. Returns false when
    // disposed, on invalid length (must be 1..65535), or when the send
    // queue is at MAX_SEND_QUEUE_SIZE.
    bool Send(std::shared_ptr<Byte[]> packet, int length) noexcept;
    // Copying overload: duplicates `packet` into a freshly allocated buffer.
    bool Send(const void* packet, int length) noexcept;
    // Installs the receive callback. It runs on the ring thread; the
    // pointed-to data is only valid for the duration of the call.
    void SetPacketInput(std::function<void(void* packet, int packet_length)> callback) noexcept;
    // Idempotent shutdown: closes the fd and joins the ring thread.
    void Dispose() noexcept;
    // True once the ring thread finished initialization and Dispose() has
    // not been called.
    [[nodiscard]] bool IsValid() const noexcept {
        return ring_initialized_.load(std::memory_order_acquire) && !disposed_.load(std::memory_order_acquire);
    }

private:
    IoUringTap(const IoUringTap&) = delete;
    IoUringTap& operator=(const IoUringTap&) = delete;

    void Loopback() noexcept;                                    // ring-thread main loop
    bool SubmitRead(size_t idx) noexcept;                        // queue one read SQE for buffer idx
    bool SubmitReadWithRetry(size_t idx, int max_retries = 5) noexcept;
    void SubmitWrites() noexcept;                                // drain send_queue_ into write SQEs

    static constexpr size_t RECV_BUFFER_SIZE = 4096;             // per-buffer receive capacity
    static constexpr size_t NUM_RECV_BUFFERS = 512;              // reads kept in flight
    static constexpr size_t QUEUE_DEPTH = 16384;                 // io_uring SQ/CQ depth
    static constexpr int MAX_INFLIGHT_WRITES = 1024;
    static constexpr size_t SEND_BATCH_SIZE = 128;               // writes dequeued per loop pass
    static constexpr int IDLE_SLEEP_US = 10;
    static constexpr int CLEANUP_TIMEOUT_MS = 5000;              // shutdown reap budget
    static constexpr size_t MAX_SEND_QUEUE_SIZE = 10000;         // soft backpressure cap
    static constexpr size_t PENDING_MAX = 64;                    // deferred-read queue cap

    struct RecvBuffer { char data[RECV_BUFFER_SIZE]; };
    RecvBuffer recv_buffers_[NUM_RECV_BUFFERS];

    // One heap-allocated context per queued outbound packet; deleted when
    // its write CQE is reaped.
    struct WriteContext {
        std::shared_ptr<Byte[]> data;
        size_t len;
    };

    moodycamel::BlockingConcurrentQueue<WriteContext*> send_queue_;
    std::queue<size_t> pending_reads_;  // buffer indices waiting for a free SQE (ring thread only)

    struct io_uring ring_{};
    int tun_fd_ = -1;  // private dup()'d fd; closed by Dispose()
    std::atomic<bool> disposed_{false};
    std::thread ring_thread_;
    std::function<void(void*, int)> packet_input_;
    std::mutex packet_input_mutex_;     // guards packet_input_
    struct iovec recv_iovecs_[NUM_RECV_BUFFERS];
    std::atomic<bool> ring_initialized_{false};
    bool fixed_buffers_registered_ = false;  // registered-buffer fast path available
    size_t read_inflight_ = 0;          // ring thread only
    int inflight_writes_ = 0;           // ring thread only
    std::atomic<bool> ready_{false};    // init handshake flag (with cv_/init_mutex_)
    std::condition_variable cv_;
    std::mutex init_mutex_;

    // CQE user_data encoding: bit 63 tags reads (low bits = buffer index);
    // anything else is a WriteContext*. data64 avoids pointer-tagging clashes.
    static constexpr uint64_t READ_FLAG = 1ULL << 63;
    static inline uint64_t EncodeRead(size_t idx) noexcept { return READ_FLAG | idx; }
    static inline size_t DecodeRead(uint64_t data) noexcept { return data & ~READ_FLAG; }
    static inline bool IsRead(uint64_t data) noexcept { return (data & READ_FLAG) != 0; }
};

// Duplicates the caller's fd, spawns the ring thread, and blocks until it
// signals readiness (success or failure) via ready_/cv_.
IoUringTap::IoUringTap(int user_tun_fd) noexcept {
    // Marks construction as failed and releases any would-be waiter.
    const auto abort_init = [this] {
        disposed_.store(true, std::memory_order_release);
        ready_.store(true, std::memory_order_release);
        cv_.notify_one();
    };

    if (user_tun_fd < 0) {
        abort_init();
        return;
    }

    // Work on a private duplicate so the caller keeps ownership of its fd.
    tun_fd_ = dup(user_tun_fd);
    if (tun_fd_ == -1) {
        abort_init();
        return;
    }

    try {
        ring_thread_ = std::thread(&IoUringTap::Loopback, this);
    } catch (...) {
        // Thread creation failed: release the duplicate and mark dead.
        close(tun_fd_);
        tun_fd_ = -1;
        disposed_.store(true, std::memory_order_release);
        ring_initialized_.store(false, std::memory_order_release);
        ready_.store(true, std::memory_order_release);
        cv_.notify_one();
        return;
    }

    // Wait for Loopback() to finish (or fail) its io_uring setup.
    std::unique_lock<std::mutex> lock(init_mutex_);
    cv_.wait(lock, [this] { return ready_.load(std::memory_order_acquire); });
}

// Destructor delegates all teardown to Dispose().
IoUringTap::~IoUringTap() noexcept {
    Dispose();  // closes the fd and waits for the ring thread to exit
}

// Queues a shared packet buffer for transmission by the ring thread.
// Returns false when disposed, on an invalid payload, or under backpressure.
bool IoUringTap::Send(std::shared_ptr<Byte[]> packet, int length) noexcept {
    const bool unusable = disposed_.load(std::memory_order_acquire) ||
                          !packet || length <= 0 || length > 65535;
    if (unusable)
        return false;

    // Soft cap: bound memory held by not-yet-submitted packets.
    if (send_queue_.size_approx() >= MAX_SEND_QUEUE_SIZE)
        return false;

    auto* ctx = new (std::nothrow) WriteContext{std::move(packet), static_cast<size_t>(length)};
    if (ctx == nullptr)
        return false;

    if (send_queue_.enqueue(ctx))
        return true;

    delete ctx;  // queue rejected it; reclaim the context
    return false;
}

bool IoUringTap::Send(const void* packet, int length) noexcept {
    if (disposed_.load(std::memory_order_acquire) || !packet || length <= 0 || length > 65535)
        return false;

    Byte* raw = new (std::nothrow) Byte[static_cast<size_t>(length)];
    if (!raw) return false;
    std::memcpy(raw, packet, static_cast<size_t>(length));
    return Send(std::shared_ptr<Byte[]>(raw, std::default_delete<Byte[]>()), length);
}

// Installs (or replaces) the receive callback under the callback mutex.
// The callback must deep-copy the payload within the call: io_uring
// recycles the receive buffers as soon as the handler returns.
void IoUringTap::SetPacketInput(std::function<void(void*, int)> callback) noexcept {
    std::lock_guard<std::mutex> guard(packet_input_mutex_);
    packet_input_ = std::move(callback);
}

// Idempotent shutdown: flips disposed_, closes the private fd, and joins
// the ring thread. Safe to call repeatedly and from the destructor; must
// NOT be called from the ring thread itself (it would self-join).
void IoUringTap::Dispose() noexcept {
    bool expected = false;
    if (!disposed_.compare_exchange_strong(expected, true, std::memory_order_acq_rel))
        return;

    // Close the fd so outstanding I/O fails fast and cleanup speeds up.
    // NOTE(review): tun_fd_ is a plain int that the ring thread also reads
    // (SubmitRead/SubmitWrites) without synchronization — presumably benign
    // because the thread checks disposed_ first, but strictly a data race;
    // consider std::atomic<int> or closing after join. TODO confirm.
    if (tun_fd_ != -1) {
        close(tun_fd_);
        tun_fd_ = -1;
    }

    if (ring_thread_.joinable())
        ring_thread_.join();
}

// Prepares one read SQE targeting receive buffer `idx`. Returns false
// when disposed or when the submission ring has no free entry; the
// caller decides whether and how to retry.
bool IoUringTap::SubmitRead(size_t idx) noexcept {
    if (disposed_.load(std::memory_order_acquire))
        return false;

    struct io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
    if (sqe == nullptr)
        return false;  // SQ ring is full

    // Prefer the registered-buffer fast path when registration succeeded.
    if (fixed_buffers_registered_)
        io_uring_prep_read_fixed(sqe, tun_fd_, recv_buffers_[idx].data,
                                 RECV_BUFFER_SIZE, 0, static_cast<int>(idx));
    else
        io_uring_prep_read(sqe, tun_fd_, recv_buffers_[idx].data,
                           RECV_BUFFER_SIZE, 0);

    io_uring_sqe_set_data64(sqe, EncodeRead(idx));
    ++read_inflight_;
    return true;
}

bool IoUringTap::SubmitReadWithRetry(size_t idx, int max_retries) noexcept {
    for (int r = 0; r < max_retries; ++r) {
        if (disposed_.load(std::memory_order_acquire)) return false;
        if (SubmitRead(idx)) return true;
        std::this_thread::sleep_for(std::chrono::microseconds(IDLE_SLEEP_US));
    }
    return false;
}

void IoUringTap::SubmitWrites() noexcept {
    if (inflight_writes_ >= MAX_INFLIGHT_WRITES) return;

    WriteContext* batch[SEND_BATCH_SIZE];
    size_t max_can = static_cast<size_t>(MAX_INFLIGHT_WRITES - inflight_writes_);
    size_t num = send_queue_.try_dequeue_bulk(batch, std::min(SEND_BATCH_SIZE, max_can));
    if (num == 0) return;

    for (size_t i = 0; i < num; ++i) {
        struct io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
        if (!sqe) {
            // 回滚未提交的 ctx
            for (size_t j = i; j < num; ++j)
                send_queue_.enqueue(batch[j]);
            return;
        }
        io_uring_prep_write(sqe, tun_fd_, batch[i]->data.get(), batch[i]->len, 0);
        io_uring_sqe_set_data64(sqe, reinterpret_cast<uint64_t>(batch[i]));
        inflight_writes_++;
    }
}

void IoUringTap::Loopback() noexcept {
    bool ring_inited = false;
    struct io_uring_params params{};
    params.flags = IORING_SETUP_SINGLE_ISSUER | IORING_SETUP_COOP_TASKRUN;
    int ret = io_uring_queue_init_params(QUEUE_DEPTH, &ring_, &params);
    if (ret < 0) {
        params.flags = 0;
        ret = io_uring_queue_init_params(QUEUE_DEPTH, &ring_, &params);
        if (ret < 0) {
            fprintf(stderr, "[IoUringTap] io_uring init failed: %d\n", -ret);
            goto init_fail;
        }
    }
    ring_inited = true;

    for (size_t i = 0; i < NUM_RECV_BUFFERS; ++i) {
        recv_iovecs_[i].iov_base = recv_buffers_[i].data;
        recv_iovecs_[i].iov_len = RECV_BUFFER_SIZE;
    }

    fixed_buffers_registered_ = (io_uring_register_buffers(&ring_, recv_iovecs_, NUM_RECV_BUFFERS) == 0);
    if (!fixed_buffers_registered_)
        fprintf(stderr, "[IoUringTap] Fixed buffers fallback (still fast & safe).\n");

    for (size_t i = 0; i < NUM_RECV_BUFFERS; ++i) {
        if (!SubmitReadWithRetry(i)) {
            if (pending_reads_.size() < PENDING_MAX)
                pending_reads_.push(i);
        }
    }

    if (io_uring_submit(&ring_) < 0) {
        fprintf(stderr, "[IoUringTap] Initial submit failed\n");
        goto partial_fail;
    }

    ring_initialized_.store(true, std::memory_order_release);
    ready_.store(true, std::memory_order_release);
    cv_.notify_one();

    while (!disposed_.load(std::memory_order_acquire)) {
        struct io_uring_cqe* cqe;
        unsigned head, completed = 0;
        io_uring_for_each_cqe(&ring_, head, cqe) {
            ++completed;
            uint64_t data = io_uring_cqe_get_data64(cqe);
            ssize_t res = cqe->res;

            if (IsRead(data)) {  // 读路径
                size_t idx = DecodeRead(data);
                if (read_inflight_ > 0) read_inflight_--;

                std::function<void(void*, int)> cb;
                {
                    std::lock_guard<std::mutex> lock(packet_input_mutex_);
                    cb = packet_input_;
                }

                if (res > 0 && cb)
                    cb(recv_buffers_[idx].data, static_cast<int>(res));
                else if (res < 0)
                    fprintf(stderr, "[IoUringTap] Read err: %zd\n", res);

                if (!disposed_.load(std::memory_order_acquire)) {
                    if (!SubmitRead(idx) && pending_reads_.size() < PENDING_MAX)
                        pending_reads_.push(idx);
                }
            } else {  // 写路径
                auto* ctx = reinterpret_cast<WriteContext*>(static_cast<uintptr_t>(data));
                if (res < 0)
                    fprintf(stderr, "[IoUringTap] Write err: %zd\n", res);
                delete ctx;
                if (inflight_writes_ > 0) inflight_writes_--;
            }
        }
        io_uring_cq_advance(&ring_, completed);

        if (!disposed_.load(std::memory_order_acquire)) {
            size_t attempts = 0;
            while (!pending_reads_.empty() && attempts < 64) {
                size_t idx = pending_reads_.front();
                if (SubmitRead(idx))
                    pending_reads_.pop();
                else
                    break;
                ++attempts;
            }
            SubmitWrites();
        } else {
            while (!pending_reads_.empty())
                pending_reads_.pop();
        }

        int sub = io_uring_submit(&ring_);
        if (sub < 0 && sub != -EBUSY)
            fprintf(stderr, "[IoUringTap] submit err: %d\n", sub);

        bool idle = (inflight_writes_ == 0 && send_queue_.size_approx() == 0 &&
                     pending_reads_.empty() && io_uring_sq_ready(&ring_) == 0);
        if (idle) {
            std::this_thread::sleep_for(std::chrono::microseconds(IDLE_SLEEP_US));
        } else {
            std::this_thread::yield();
        }
    }

    // ----- 清理阶段 -----
    // 1. 清空发送队列
    WriteContext* ctx;
    while (send_queue_.try_dequeue(ctx))
        delete ctx;

    // 2. 提交所有剩余的 SQE(如果有)
    if (io_uring_sq_ready(&ring_) > 0)
        io_uring_submit(&ring_);

    // 3. 等待所有飞行中的操作完成(fd 已关闭,会快速返回错误)
    auto start = std::chrono::steady_clock::now();
    while ((read_inflight_ > 0 || inflight_writes_ > 0) &&
           std::chrono::duration_cast<std::chrono::milliseconds>(
               std::chrono::steady_clock::now() - start).count() < CLEANUP_TIMEOUT_MS) {
        unsigned head, done = 0;
        io_uring_for_each_cqe(&ring_, head, cqe) {
            ++done;
            uint64_t data = io_uring_cqe_get_data64(cqe);
            if (IsRead(data)) {
                if (read_inflight_ > 0) --read_inflight_;
            } else {
                delete reinterpret_cast<WriteContext*>(static_cast<uintptr_t>(data));
                if (inflight_writes_ > 0) --inflight_writes_;
            }
        }
        io_uring_cq_advance(&ring_, done);
        if (done == 0)
            break;  // 没有新完成,但可能还有未完成的,下次循环再试
        io_uring_submit(&ring_);  // 促使内核推进
        std::this_thread::yield();
    }

    // 如果超时后仍有未完成的,忽略(进程退出时会回收,但理论上不应发生)
    if (fixed_buffers_registered_)
        io_uring_unregister_buffers(&ring_);
    if (ring_inited)
        io_uring_queue_exit(&ring_);
    return;

partial_fail:
init_fail:
    if (fixed_buffers_registered_)
        io_uring_unregister_buffers(&ring_);
    if (ring_inited)
        io_uring_queue_exit(&ring_);
    disposed_.store(true, std::memory_order_release);
    ring_initialized_.store(false, std::memory_order_release);
    ready_.store(true, std::memory_order_release);
    cv_.notify_one();
}