// io_uring-backed asynchronous TUN/TAP I/O wrapper.
#include <fcntl.h>
#include <linux/if_tun.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/uio.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

#include <liburing.h>

#include "blockingconcurrentqueue.h"
using Byte = uint8_t;
// Asynchronous TUN/TAP packet pump built on io_uring.
//
// Owns a private dup() of the caller's TUN fd and runs a dedicated ring
// thread (Loopback) that keeps up to NUM_RECV_BUFFERS reads in flight and
// drains a lock-free send queue. Received packets are delivered through the
// callback set via SetPacketInput(); the buffer pointer handed to the
// callback is valid only for the duration of that call (io_uring reuses the
// receive buffers). Send()/SetPacketInput()/Dispose() may be called from any
// thread; Loopback-internal state (read_inflight_, pending_reads_, ...) is
// touched only by the ring thread.
class IoUringTap {
public:
// Duplicates user_tun_fd and starts the ring thread; blocks until the
// thread reports success or failure (check IsValid() afterwards).
explicit IoUringTap(int user_tun_fd) noexcept;
// Stops the ring thread and closes the private fd.
~IoUringTap() noexcept;
// Queues a packet for transmission, sharing ownership of the buffer.
// Returns false when disposed, on bad arguments, on backpressure
// (MAX_SEND_QUEUE_SIZE), or on allocation failure.
bool Send(std::shared_ptr<Byte[]> packet, int length) noexcept;
// Copying overload: duplicates the packet bytes, then queues the copy.
bool Send(const void* packet, int length) noexcept;
// Installs the receive callback. The callback must deep-copy the packet
// within the call — the pointer is only valid for that invocation.
void SetPacketInput(std::function<void(void* packet, int packet_length)> callback) noexcept;
// Idempotent shutdown: marks disposed and joins the ring thread.
void Dispose() noexcept;
// True once the ring thread finished setup and Dispose() has not run.
[[nodiscard]] bool IsValid() const noexcept {
return ring_initialized_.load(std::memory_order_acquire) && !disposed_.load(std::memory_order_acquire);
}
private:
IoUringTap(const IoUringTap&) = delete;
IoUringTap& operator=(const IoUringTap&) = delete;
// Ring-thread entry point: setup, event loop, drain, teardown.
void Loopback() noexcept;
// Posts one read SQE for recv buffer idx; false if the SQ is full.
bool SubmitRead(size_t idx) noexcept;
bool SubmitReadWithRetry(size_t idx, int max_retries = 5) noexcept;
// Moves queued WriteContexts from send_queue_ onto the SQ (batched).
void SubmitWrites() noexcept;
static constexpr size_t RECV_BUFFER_SIZE = 4096;
static constexpr size_t NUM_RECV_BUFFERS = 512;
static constexpr size_t QUEUE_DEPTH = 16384;
static constexpr int MAX_INFLIGHT_WRITES = 1024;
static constexpr size_t SEND_BATCH_SIZE = 128;
static constexpr int IDLE_SLEEP_US = 10;
static constexpr int CLEANUP_TIMEOUT_MS = 5000;
static constexpr size_t MAX_SEND_QUEUE_SIZE = 10000;
static constexpr size_t PENDING_MAX = 64;
struct RecvBuffer { char data[RECV_BUFFER_SIZE]; };
RecvBuffer recv_buffers_[NUM_RECV_BUFFERS];
// Heap-allocated per-write state; its pointer doubles as the CQE user_data.
struct WriteContext {
std::shared_ptr<Byte[]> data;
size_t len;
};
moodycamel::BlockingConcurrentQueue<WriteContext*> send_queue_;
std::queue<size_t> pending_reads_;  // buffer indices awaiting a free SQE (ring thread only)
struct io_uring ring_{};
int tun_fd_ = -1; // our own dup()'d private fd — the caller keeps theirs
std::atomic<bool> disposed_{false};
std::thread ring_thread_;
std::function<void(void*, int)> packet_input_;
std::mutex packet_input_mutex_;
struct iovec recv_iovecs_[NUM_RECV_BUFFERS];
std::atomic<bool> ring_initialized_{false};
bool fixed_buffers_registered_ = false;
size_t read_inflight_ = 0;
int inflight_writes_ = 0;
std::atomic<bool> ready_{false};  // flipped (success or failure) to release the waiting ctor
std::condition_variable cv_;
std::mutex init_mutex_;
// Bit 63 set in CQE user_data marks a read completion; heap pointers
// (write completions) have bit 63 clear on user-space Linux.
static constexpr uintptr_t READ_TAG_MASK = uintptr_t(1) << (sizeof(uintptr_t) * 8 - 1);
};
// ====================== Construction / Destruction ======================
// Duplicates the caller's TUN fd (so this object owns a private descriptor),
// starts the ring thread, and blocks until that thread signals that setup
// either succeeded or failed. On any failure the object is left disposed
// (IsValid() == false); the constructor never throws.
IoUringTap::IoUringTap(int user_tun_fd) noexcept {
    // Marks the object dead and releases anyone waiting on cv_.
    auto mark_failed = [this]() noexcept {
        disposed_.store(true, std::memory_order_release);
        ready_.store(true, std::memory_order_release);
        cv_.notify_one();
    };
    if (user_tun_fd < 0) {
        mark_failed();
        return;
    }
    tun_fd_ = dup(user_tun_fd); // private ownership — the caller's fd stays untouched
    if (tun_fd_ == -1) {
        mark_failed();
        return;
    }
    try {
        ring_thread_ = std::thread(&IoUringTap::Loopback, this);
    } catch (...) { // std::thread construction may throw std::system_error
        close(tun_fd_);
        tun_fd_ = -1;
        ring_initialized_.store(false, std::memory_order_release);
        mark_failed();
        return;
    }
    // Wait for Loopback() to finish (or fail) ring setup.
    // FIX: the original predicate `this noexcept {...}` was not a valid
    // lambda — it lacked the capture brackets and parameter list.
    std::unique_lock<std::mutex> lock(init_mutex_);
    cv_.wait(lock, [this]() noexcept { return ready_.load(std::memory_order_acquire); });
}
// Stops the ring thread (via Dispose) and releases the private dup()'d fd.
IoUringTap::~IoUringTap() noexcept {
    Dispose();
    if (tun_fd_ == -1) return;
    close(tun_fd_);
    tun_fd_ = -1;
}
// ====================== Public interface (no-throw) ======================
// Queues a packet for transmission, sharing ownership of the buffer.
// Returns false when disposed, on invalid arguments (null / non-positive /
// larger than 65535 bytes), on backpressure, or on allocation failure.
bool IoUringTap::Send(std::shared_ptr<Byte[]> packet, int length) noexcept {
    if (disposed_.load(std::memory_order_acquire)) return false;
    if (!packet || length <= 0 || length > 65535) return false;
    // Soft backpressure: approximate size is fine here, we only need a cap.
    if (send_queue_.size_approx() >= MAX_SEND_QUEUE_SIZE) return false;
    auto* ctx = new (std::nothrow) WriteContext{std::move(packet), static_cast<size_t>(length)};
    if (ctx == nullptr) return false;
    if (send_queue_.enqueue(ctx)) return true;
    delete ctx;
    return false;
}
bool IoUringTap::Send(const void* packet, int length) noexcept {
if (disposed_.load(std::memory_order_acquire) || !packet || length <= 0 || length > 65535)
return false;
Byte* raw = new (std::nothrow) Byte[static_cast<size_t>(length)];
if (!raw) return false;
std::memcpy(raw, packet, static_cast<size_t>(length));
return Send(std::shared_ptr<Byte[]>(raw, std::default_delete<Byte[]>()), length);
}
// Installs (or replaces) the receive callback.
// CONTRACT: the callback must deep-copy within the call — the packet pointer
// is valid only for that invocation, because io_uring reuses the buffer.
void IoUringTap::SetPacketInput(std::function<void(void*, int)> callback) noexcept {
    std::lock_guard<std::mutex> guard(packet_input_mutex_);
    packet_input_ = std::move(callback);
}
// Idempotent shutdown: the first caller flips disposed_ and joins the ring
// thread; every later (or concurrent losing) caller returns immediately.
void IoUringTap::Dispose() noexcept {
    bool was_disposed = false;
    const bool won_race =
        disposed_.compare_exchange_strong(was_disposed, true, std::memory_order_acq_rel);
    if (!won_race) return;
    if (ring_thread_.joinable()) ring_thread_.join();
}
// ====================== Private methods (ring thread) ======================
// Posts one read SQE targeting recv buffer idx.
// Returns false when disposed or the submission queue is full (the caller
// parks idx in pending_reads_ and retries later). Ring thread only.
bool IoUringTap::SubmitRead(size_t idx) noexcept {
    if (disposed_.load(std::memory_order_acquire)) return false;
    struct io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
    if (sqe == nullptr) return false; // SQ exhausted
    char* buf = recv_buffers_[idx].data;
    if (fixed_buffers_registered_)
        io_uring_prep_read_fixed(sqe, tun_fd_, buf, RECV_BUFFER_SIZE, 0, static_cast<int>(idx));
    else
        io_uring_prep_read(sqe, tun_fd_, buf, RECV_BUFFER_SIZE, 0);
    // Bit 63 tags the completion as a read; write completions carry a heap
    // pointer whose bit 63 is clear.
    io_uring_sqe_set_data(sqe, reinterpret_cast<void*>(idx | READ_TAG_MASK));
    ++read_inflight_;
    return true;
}
// SubmitRead with bounded retries, sleeping briefly between attempts so the
// completion side can free SQ slots. Gives up early once disposed.
bool IoUringTap::SubmitReadWithRetry(size_t idx, int max_retries) noexcept {
    int attempt = 0;
    while (attempt++ < max_retries) {
        if (disposed_.load(std::memory_order_acquire)) return false;
        if (SubmitRead(idx)) return true;
        std::this_thread::sleep_for(std::chrono::microseconds(IDLE_SLEEP_US));
    }
    return false;
}
// Dequeues up to SEND_BATCH_SIZE queued packets (bounded by the in-flight
// write cap) and posts one write SQE per packet. If the SQ fills mid-batch,
// the remaining packets are pushed back onto the queue. Ring thread only.
void IoUringTap::SubmitWrites() noexcept {
    if (inflight_writes_ >= MAX_INFLIGHT_WRITES) return;
    WriteContext* batch[SEND_BATCH_SIZE];
    const size_t room = static_cast<size_t>(MAX_INFLIGHT_WRITES - inflight_writes_);
    const size_t num = send_queue_.try_dequeue_bulk(batch, std::min(SEND_BATCH_SIZE, room));
    if (num == 0) return;
    for (size_t i = 0; i < num; ++i) {
        struct io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
        if (!sqe) {
            // SQ full: roll back the untransmitted tail. FIX: the original
            // leaked the WriteContext when re-enqueue itself failed — free
            // it instead (the packet is dropped, but no memory is lost).
            for (size_t j = i; j < num; ++j) {
                if (!send_queue_.enqueue(batch[j])) delete batch[j];
            }
            return;
        }
        io_uring_prep_write(sqe, tun_fd_, batch[i]->data.get(), batch[i]->len, 0);
        io_uring_sqe_set_data(sqe, batch[i]); // heap pointer: bit 63 clear, never read-tagged
        ++inflight_writes_;
    }
}
void IoUringTap::Loopback() noexcept {
bool ring_inited = false;
struct io_uring_params params{};
params.flags = IORING_SETUP_SINGLE_ISSUER | IORING_SETUP_COOP_TASKRUN;
int ret = io_uring_queue_init_params(QUEUE_DEPTH, &ring_, ¶ms);
if (ret < 0) {
params.flags = 0;
ret = io_uring_queue_init_params(QUEUE_DEPTH, &ring_, ¶ms);
if (ret < 0) {
fprintf(stderr, "[IoUringTap] io_uring init failed: %d\n", -ret);
goto init_fail;
}
}
ring_inited = true;
for (size_t i = 0; i < NUM_RECV_BUFFERS; ++i) {
recv_iovecs_[i].iov_base = recv_buffers_[i].data;
recv_iovecs_[i].iov_len = RECV_BUFFER_SIZE;
}
fixed_buffers_registered_ = (io_uring_register_buffers(&ring_, recv_iovecs_, NUM_RECV_BUFFERS) == 0);
if (!fixed_buffers_registered_)
fprintf(stderr, "[IoUringTap] Fixed buffers fallback (still 极快 & 安全).\n");
for (size_t i = 0; i < NUM_RECV_BUFFERS; ++i) {
if (!SubmitReadWithRetry(i)) {
if (pending_reads_.size() < PENDING_MAX) pending_reads_.push(i);
}
}
if (io_uring_submit(&ring_) < 0) {
fprintf(stderr, "[IoUringTap] Initial submit failed\n");
goto partial_fail;
}
// FIXED: tag 防御断言
static_assert(sizeof(uintptr_t) >= 8, "tag safety requires 64-bit");
assert((reinterpret_cast<uintptr_t>(new WriteContext{}) & READ_TAG_MASK) == 0 && "WriteContext pointer must have bit63=0");
ring_initialized_.store(true, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
while (!disposed_.load(std::memory_order_acquire)) {
struct io_uring_cqe* cqe;
unsigned head, completed = 0;
io_uring_for_each_cqe(&ring_, head, cqe) {
++completed;
void* ud = io_uring_cqe_get_data(cqe);
ssize_t res = cqe->res;
uintptr_t tag = reinterpret_cast<uintptr_t>(ud);
if (tag & READ_TAG_MASK) { // 读路径
size_t idx = tag & ~READ_TAG_MASK;
if (read_inflight_ > 0) read_inflight_--;
std::function<void(void*, int)> cb;
{
std::lock_guard<std::mutex> lock(packet_input_mutex_);
cb = packet_input_;
}
if (res > 0 && cb) cb(recv_buffers_[idx].data, static_cast<int>(res));
else if (res < 0) fprintf(stderr, "[IoUringTap] Read err: %zd\n", res);
if (!disposed_.load(std::memory_order_acquire)) {
if (!SubmitRead(idx) && pending_reads_.size() < PENDING_MAX)
pending_reads_.push(idx);
}
} else { // 写路径
auto* ctx = static_cast<WriteContext*>(ud);
if (res < 0) fprintf(stderr, "[IoUringTap] Write err: %zd\n", res);
delete ctx;
if (inflight_writes_ > 0) inflight_writes_--;
}
}
io_uring_cq_advance(&ring_, completed);
if (!disposed_.load(std::memory_order_acquire)) {
size_t attempts = 0;
while (!pending_reads_.empty() && attempts < 64) {
size_t idx = pending_reads_.front();
if (SubmitRead(idx)) pending_reads_.pop();
else break;
++attempts;
}
SubmitWrites();
} else {
while (!pending_reads_.empty()) pending_reads_.pop();
}
int sub = io_uring_submit(&ring_);
if (sub < 0 && sub != -EBUSY) fprintf(stderr, "[IoUringTap] submit err: %d\n", sub);
bool idle = (inflight_writes_ == 0 && send_queue_.size_approx() == 0 &&
pending_reads_.empty() && io_uring_sq_ready(&ring_) == 0);
if (idle) {
std::this_thread::sleep_for(std::chrono::microseconds(IDLE_SLEEP_US));
} else {
std::this_thread::yield();
}
}
cleanup:
if (io_uring_sq_ready(&ring_) > 0) io_uring_submit(&ring_);
// FIXED: 更狠的清理
WriteContext* ctx;
while (send_queue_.try_dequeue(ctx)) delete ctx;
auto start = std::chrono::steady_clock::now();
int drain_round = 0;
while ((read_inflight_ > 0 || inflight_writes_ > 0) &&
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start).count() < CLEANUP_TIMEOUT_MS &&
drain_round++ < 300) { // 增加轮次
unsigned head, done = 0;
io_uring_for_each_cqe(&ring_, head, cqe) {
++done;
if (reinterpret_cast<uintptr_t>(io_uring_cqe_get_data(cqe)) & READ_TAG_MASK) {
if (read_inflight_ > 0) --read_inflight_;
} else {
delete static_cast<WriteContext*>(io_uring_cqe_get_data(cqe));
if (inflight_writes_ > 0) --inflight_writes_;
}
}
io_uring_cq_advance(&ring_, done);
if (done == 0) break;
io_uring_submit(&ring_); // FIXED: 每轮强制 submit
std::this_thread::yield();
}
if (fixed_buffers_registered_) io_uring_unregister_buffers(&ring_);
if (ring_inited) io_uring_queue_exit(&ring_);
return;
partial_fail:
init_fail:
if (fixed_buffers_registered_) io_uring_unregister_buffers(&ring_);
if (ring_inited) io_uring_queue_exit(&ring_);
disposed_.store(true, std::memory_order_release);
ring_initialized_.store(false, std::memory_order_release);
ready_.store(true, std::memory_order_release);
cv_.notify_one();
}