cpp
复制代码
#include <liburing.h>
#include <linux/if_tun.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/uio.h>
#include <fcntl.h>
#include <unistd.h>
#include <cstring>
#include <atomic>
#include <thread>
#include <functional>
#include <memory>
#include <chrono>
#include <cstdint>
#include <mutex>
#include <condition_variable>
#include <queue>
#include <cstdio>
#include <cassert>
#include "blockingconcurrentqueue.h"
using Byte = uint8_t;
class IoUringTap {
public:
explicit IoUringTap(int user_tun_fd) noexcept;
~IoUringTap() noexcept;
bool Send(std::shared_ptr<Byte[]> packet, int length) noexcept;
bool Send(const void* packet, int length) noexcept;
void SetPacketInput(std::function<void(void* packet, int packet_length)> callback) noexcept;
void Dispose() noexcept;
[[nodiscard]] bool IsValid() const noexcept {
return ring_initialized_.load(std::memory_order_acquire) && !disposed_.load(std::memory_order_acquire);
}
private:
IoUringTap(const IoUringTap&) = delete;
IoUringTap& operator=(const IoUringTap&) = delete;
void Loopback() noexcept;
bool SubmitRead(size_t idx) noexcept;
bool SubmitReadWithRetry(size_t idx, int max_retries = 5) noexcept;
void SubmitWrites() noexcept;
static constexpr size_t RECV_BUFFER_SIZE = 4096;
static constexpr size_t NUM_RECV_BUFFERS = 512;
static constexpr size_t QUEUE_DEPTH = 16384;
static constexpr int MAX_INFLIGHT_WRITES = 1024;
static constexpr size_t SEND_BATCH_SIZE = 128;
static constexpr int IDLE_SLEEP_US = 10;
static constexpr int CLEANUP_TIMEOUT_MS = 5000;
static constexpr size_t MAX_SEND_QUEUE_SIZE = 10000;
static constexpr size_t PENDING_MAX = 64;
struct RecvBuffer { char data[RECV_BUFFER_SIZE]; };
RecvBuffer recv_buffers_[NUM_RECV_BUFFERS];
struct WriteContext {
std::shared_ptr<Byte[]> data;
size_t len;
};
moodycamel::BlockingConcurrentQueue<WriteContext*> send_queue_;
std::queue<size_t> pending_reads_;
struct io_uring ring_{};
int tun_fd_ = -1; // 私有 fd
std::atomic<bool> disposed_{false};
std::thread ring_thread_;
std::function<void(void*, int)> packet_input_;
std::mutex packet_input_mutex_;
struct iovec recv_iovecs_[NUM_RECV_BUFFERS];
std::atomic<bool> ring_initialized_{false};
bool fixed_buffers_registered_ = false;
size_t read_inflight_ = 0;
int inflight_writes_ = 0;
std::atomic<bool> ready_{false};
std::condition_variable cv_;
std::mutex init_mutex_;
// 使用 data64 避免指针标记冲突
static constexpr uint64_t READ_FLAG = 1ULL << 63;
static inline uint64_t EncodeRead(size_t idx) noexcept { return READ_FLAG | idx; }
static inline size_t DecodeRead(uint64_t data) noexcept { return data & ~READ_FLAG; }
static inline bool IsRead(uint64_t data) noexcept { return (data & READ_FLAG) != 0; }
};
IoUringTap::IoUringTap(int user_tun_fd) noexcept {
if (user_tun_fd < 0) {
disposed_.store(true, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
return;
}
tun_fd_ = dup(user_tun_fd);
if (tun_fd_ == -1) {
disposed_.store(true, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
return;
}
try {
ring_thread_ = std::thread(&IoUringTap::Loopback, this);
} catch (...) {
if (tun_fd_ != -1) {
close(tun_fd_);
tun_fd_ = -1;
}
disposed_.store(true, std::memory_order_release);
ring_initialized_.store(false, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
return;
}
std::unique_lock<std::mutex> lock(init_mutex_);
cv_.wait(lock, [this] { return ready_.load(std::memory_order_acquire); });
}
IoUringTap::~IoUringTap() noexcept {
Dispose(); // 会关闭 fd 并等待线程结束
}
bool IoUringTap::Send(std::shared_ptr<Byte[]> packet, int length) noexcept {
if (disposed_.load(std::memory_order_acquire) || !packet || length <= 0 || length > 65535)
return false;
if (send_queue_.size_approx() >= MAX_SEND_QUEUE_SIZE)
return false;
auto* ctx = new (std::nothrow) WriteContext{std::move(packet), static_cast<size_t>(length)};
if (!ctx) return false;
if (!send_queue_.enqueue(ctx)) {
delete ctx;
return false;
}
return true;
}
bool IoUringTap::Send(const void* packet, int length) noexcept {
if (disposed_.load(std::memory_order_acquire) || !packet || length <= 0 || length > 65535)
return false;
Byte* raw = new (std::nothrow) Byte[static_cast<size_t>(length)];
if (!raw) return false;
std::memcpy(raw, packet, static_cast<size_t>(length));
return Send(std::shared_ptr<Byte[]>(raw, std::default_delete<Byte[]>()), length);
}
void IoUringTap::SetPacketInput(std::function<void(void*, int)> callback) noexcept {
std::lock_guard<std::mutex> lock(packet_input_mutex_);
packet_input_ = std::move(callback);
// 回调必须在本帧内深拷贝数据,因为 io_uring 复用缓冲区
}
void IoUringTap::Dispose() noexcept {
bool expected = false;
if (!disposed_.compare_exchange_strong(expected, true, std::memory_order_acq_rel))
return;
// 关闭 fd,强制未完成的 IO 快速返回错误,加速清理
if (tun_fd_ != -1) {
close(tun_fd_);
tun_fd_ = -1;
}
if (ring_thread_.joinable())
ring_thread_.join();
}
bool IoUringTap::SubmitRead(size_t idx) noexcept {
if (disposed_.load(std::memory_order_acquire)) return false;
struct io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
if (!sqe) return false;
if (fixed_buffers_registered_) {
io_uring_prep_read_fixed(sqe, tun_fd_, recv_buffers_[idx].data, RECV_BUFFER_SIZE, 0, static_cast<int>(idx));
} else {
io_uring_prep_read(sqe, tun_fd_, recv_buffers_[idx].data, RECV_BUFFER_SIZE, 0);
}
io_uring_sqe_set_data64(sqe, EncodeRead(idx));
read_inflight_++;
return true;
}
bool IoUringTap::SubmitReadWithRetry(size_t idx, int max_retries) noexcept {
for (int r = 0; r < max_retries; ++r) {
if (disposed_.load(std::memory_order_acquire)) return false;
if (SubmitRead(idx)) return true;
std::this_thread::sleep_for(std::chrono::microseconds(IDLE_SLEEP_US));
}
return false;
}
void IoUringTap::SubmitWrites() noexcept {
if (inflight_writes_ >= MAX_INFLIGHT_WRITES) return;
WriteContext* batch[SEND_BATCH_SIZE];
size_t max_can = static_cast<size_t>(MAX_INFLIGHT_WRITES - inflight_writes_);
size_t num = send_queue_.try_dequeue_bulk(batch, std::min(SEND_BATCH_SIZE, max_can));
if (num == 0) return;
for (size_t i = 0; i < num; ++i) {
struct io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
if (!sqe) {
// 回滚未提交的 ctx
for (size_t j = i; j < num; ++j)
send_queue_.enqueue(batch[j]);
return;
}
io_uring_prep_write(sqe, tun_fd_, batch[i]->data.get(), batch[i]->len, 0);
io_uring_sqe_set_data64(sqe, reinterpret_cast<uint64_t>(batch[i]));
inflight_writes_++;
}
}
void IoUringTap::Loopback() noexcept {
bool ring_inited = false;
struct io_uring_params params{};
params.flags = IORING_SETUP_SINGLE_ISSUER | IORING_SETUP_COOP_TASKRUN;
int ret = io_uring_queue_init_params(QUEUE_DEPTH, &ring_, ¶ms);
if (ret < 0) {
params.flags = 0;
ret = io_uring_queue_init_params(QUEUE_DEPTH, &ring_, ¶ms);
if (ret < 0) {
fprintf(stderr, "[IoUringTap] io_uring init failed: %d\n", -ret);
goto init_fail;
}
}
ring_inited = true;
for (size_t i = 0; i < NUM_RECV_BUFFERS; ++i) {
recv_iovecs_[i].iov_base = recv_buffers_[i].data;
recv_iovecs_[i].iov_len = RECV_BUFFER_SIZE;
}
fixed_buffers_registered_ = (io_uring_register_buffers(&ring_, recv_iovecs_, NUM_RECV_BUFFERS) == 0);
if (!fixed_buffers_registered_)
fprintf(stderr, "[IoUringTap] Fixed buffers fallback (still fast & safe).\n");
for (size_t i = 0; i < NUM_RECV_BUFFERS; ++i) {
if (!SubmitReadWithRetry(i)) {
if (pending_reads_.size() < PENDING_MAX)
pending_reads_.push(i);
}
}
if (io_uring_submit(&ring_) < 0) {
fprintf(stderr, "[IoUringTap] Initial submit failed\n");
goto partial_fail;
}
ring_initialized_.store(true, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
while (!disposed_.load(std::memory_order_acquire)) {
struct io_uring_cqe* cqe;
unsigned head, completed = 0;
io_uring_for_each_cqe(&ring_, head, cqe) {
++completed;
uint64_t data = io_uring_cqe_get_data64(cqe);
ssize_t res = cqe->res;
if (IsRead(data)) { // 读路径
size_t idx = DecodeRead(data);
if (read_inflight_ > 0) read_inflight_--;
std::function<void(void*, int)> cb;
{
std::lock_guard<std::mutex> lock(packet_input_mutex_);
cb = packet_input_;
}
if (res > 0 && cb)
cb(recv_buffers_[idx].data, static_cast<int>(res));
else if (res < 0)
fprintf(stderr, "[IoUringTap] Read err: %zd\n", res);
if (!disposed_.load(std::memory_order_acquire)) {
if (!SubmitRead(idx) && pending_reads_.size() < PENDING_MAX)
pending_reads_.push(idx);
}
} else { // 写路径
auto* ctx = reinterpret_cast<WriteContext*>(static_cast<uintptr_t>(data));
if (res < 0)
fprintf(stderr, "[IoUringTap] Write err: %zd\n", res);
delete ctx;
if (inflight_writes_ > 0) inflight_writes_--;
}
}
io_uring_cq_advance(&ring_, completed);
if (!disposed_.load(std::memory_order_acquire)) {
size_t attempts = 0;
while (!pending_reads_.empty() && attempts < 64) {
size_t idx = pending_reads_.front();
if (SubmitRead(idx))
pending_reads_.pop();
else
break;
++attempts;
}
SubmitWrites();
} else {
while (!pending_reads_.empty())
pending_reads_.pop();
}
int sub = io_uring_submit(&ring_);
if (sub < 0 && sub != -EBUSY)
fprintf(stderr, "[IoUringTap] submit err: %d\n", sub);
bool idle = (inflight_writes_ == 0 && send_queue_.size_approx() == 0 &&
pending_reads_.empty() && io_uring_sq_ready(&ring_) == 0);
if (idle) {
std::this_thread::sleep_for(std::chrono::microseconds(IDLE_SLEEP_US));
} else {
std::this_thread::yield();
}
}
// ----- 清理阶段 -----
// 1. 清空发送队列
WriteContext* ctx;
while (send_queue_.try_dequeue(ctx))
delete ctx;
// 2. 提交所有剩余的 SQE(如果有)
if (io_uring_sq_ready(&ring_) > 0)
io_uring_submit(&ring_);
// 3. 等待所有飞行中的操作完成(fd 已关闭,会快速返回错误)
auto start = std::chrono::steady_clock::now();
while ((read_inflight_ > 0 || inflight_writes_ > 0) &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start).count() < CLEANUP_TIMEOUT_MS) {
unsigned head, done = 0;
io_uring_for_each_cqe(&ring_, head, cqe) {
++done;
uint64_t data = io_uring_cqe_get_data64(cqe);
if (IsRead(data)) {
if (read_inflight_ > 0) --read_inflight_;
} else {
delete reinterpret_cast<WriteContext*>(static_cast<uintptr_t>(data));
if (inflight_writes_ > 0) --inflight_writes_;
}
}
io_uring_cq_advance(&ring_, done);
if (done == 0)
break; // 没有新完成,但可能还有未完成的,下次循环再试
io_uring_submit(&ring_); // 促使内核推进
std::this_thread::yield();
}
// 如果超时后仍有未完成的,忽略(进程退出时会回收,但理论上不应发生)
if (fixed_buffers_registered_)
io_uring_unregister_buffers(&ring_);
if (ring_inited)
io_uring_queue_exit(&ring_);
return;
partial_fail:
init_fail:
if (fixed_buffers_registered_)
io_uring_unregister_buffers(&ring_);
if (ring_inited)
io_uring_queue_exit(&ring_);
disposed_.store(true, std::memory_order_release);
ring_initialized_.store(false, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
}