IOCP实现UDP Server

IOCP实现UDP Server

1、IOCP原理图

参考文献1:IOCP详解-阿里云开发者社区 (aliyun.com)

参考文献2:IOCP编程之基本原理 - 史D芬周 - 博客园 (cnblogs.com)

原理图

同步以及异步

2、UDP Server代码以及测试代码

c++ 复制代码
// iocpudpdemo.cpp : This file contains the "main" function. Program execution begins and ends there.
//

// UDP Server
// RIOTest.cpp : Defines the entry point for the console application.
//
#pragma comment(lib, "ws2_32.lib")

#include <WS2tcpip.h>

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <thread>
#include <utility>

using namespace std;

// The UDP socket shared by the whole program. Assigned by CreateSocket() and
// used when posting receives. Starts out as INVALID_SOCKET so that use before
// creation is detectable instead of silently targeting socket value 0.
SOCKET g_s = INVALID_SOCKET;
// The I/O completion port handle; assigned by CreateIOCP().
HANDLE g_hIOCP = NULL;
// Only reported by SetupTiming(); nothing in this file changes it.
long g_workIterations = 0;
// Performance counter frequency, filled in by SetupTiming().
LARGE_INTEGER g_frequency;
// Performance counter snapshots taken by StartTiming() / StopTiming().
LARGE_INTEGER g_startCounter;
LARGE_INTEGER g_stopCounter;

// Number of datagram completions counted by main().
volatile long g_packets = 0;

// Size in bytes of a full data datagram; main() uses it to classify completions.
static const DWORD EXPECTED_DATA_SIZE = 8192;
// Only reported by SetupTiming().
static const DWORD RIO_MAX_RESULTS = 1000;
// Affinity mask handed to SetThreadAffinityMask() by SetupTiming().
static const DWORD TIMING_THREAD_AFFINITY_MASK = 1;
// UDP port the server binds to.
static const unsigned short PORT = 8081;

// Per-operation context for an overlapped receive: the OVERLAPPED base that is
// handed to WSARecvFrom plus the WSABUF describing the buffer that receive
// fills. OVERLAPPED pointers dequeued from the completion port are cast back
// to this type to reach the buffer.
struct EXTENDED_OVERLAPPED : OVERLAPPED
{
    WSABUF buf;
};

// Reports a failed call together with its error code and terminates the process.
//
// pFunction - name of the function that failed (printed verbatim).
// lastError - error code to report.
//
// Does not return. The process exits with a non-zero status so that scripts
// and parent processes can tell the run failed.
inline void ErrorExit(
    const char* pFunction,
    const DWORD lastError)
{
    cout << "Error: " << pFunction << " failed: " << lastError << endl;
    exit(EXIT_FAILURE);
}

inline void ErrorExit(
    const char* pFunction)
{
    const DWORD lastError = ::GetLastError();

    ErrorExit(pFunction, lastError);
}

inline void SetupTiming(
    const char* pProgramName,
    const bool lockToThreadForTiming = true)
{
    cout << pProgramName << endl;
    cout << "Work load: " << g_workIterations << endl;
    cout << "Max results: " << RIO_MAX_RESULTS << endl;
    if (lockToThreadForTiming)
    {
        HANDLE hThread = ::GetCurrentThread();

        if (0 == ::SetThreadAffinityMask(hThread, TIMING_THREAD_AFFINITY_MASK))
        {
            ErrorExit("SetThreadAffinityMask");
        }
    }
    if (!::QueryPerformanceFrequency(&g_frequency))
    {
        ErrorExit("QueryPerformanceFrequency");
    }
}

// Prints the elapsed time between StartTiming() and StopTiming(), the number
// of datagrams counted in g_packets, and the resulting datagram rate.
//
// pDirection - prefix printed before the datagram count.
//
// The rate is computed in floating point from raw counter ticks, so it is not
// truncated by integer division. Nothing is divided when the frequency or the
// elapsed tick count is zero.
inline void PrintTimings(
    const char* pDirection = "Received ")
{
    const long long elapsedTicks = g_stopCounter.QuadPart - g_startCounter.QuadPart;
    const long long ticksPerSecond = g_frequency.QuadPart;

    // Multiply before dividing to keep millisecond precision; guard against a
    // frequency that was never initialised.
    const long long elapsedMs = (ticksPerSecond > 0) ? (elapsedTicks * 1000) / ticksPerSecond : 0;

    cout << "Complete in " << elapsedMs << "ms" << endl;
    cout << pDirection << g_packets << " datagrams" << endl;

    if (elapsedTicks > 0 && ticksPerSecond > 0)
    {
        const double elapsedSeconds = static_cast<double>(elapsedTicks) / static_cast<double>(ticksPerSecond);
        const double perSec = static_cast<double>(g_packets) / elapsedSeconds;

        cout << perSec << " datagrams per second" << endl;
    }
}

inline void InitialiseWinsock()
{
    WSADATA data;
    WORD wVersionRequested = 0x202;
    if (0 != ::WSAStartup(wVersionRequested, &data))
    {
        ErrorExit("WSAStartup");
    }
}

// Creates the IPv4 UDP socket used by the program.
//
// flags - passed through as the WSASocket flags argument.
//
// The new socket is stored in g_s and also returned. Terminates the process
// through ErrorExit when the socket cannot be created.
inline SOCKET CreateSocket(
    const DWORD flags = 0)
{
    const SOCKET created = ::WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, flags);
    g_s = created;
    if (INVALID_SOCKET == created)
    {
        ErrorExit("WSASocket");
    }
    return created;
}

// Creates a new I/O completion port that is not yet associated with any handle.
//
// The port handle is stored in g_hIOCP and also returned. Terminates the
// process through ErrorExit when the port cannot be created.
inline HANDLE CreateIOCP()
{
    const HANDLE port = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 0);
    g_hIOCP = port;
    if (port == 0)
    {
        ErrorExit("CreateIoCompletionPort");
    }
    return port;
}

// Binds socket s to the given port on all local IPv4 addresses (INADDR_ANY).
//
// s    - socket to bind.
// port - port number in host byte order.
//
// Terminates the process through ErrorExit when bind fails.
inline void Bind(
    SOCKET s,
    const unsigned short port)
{
    // Zero the whole structure first so the padding member (sin_zero) does not
    // carry uninitialised stack contents into bind().
    sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    addr.sin_addr.s_addr = INADDR_ANY;
    if (SOCKET_ERROR == ::bind(s, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)))
    {
        ErrorExit("bind");
    }
}

// Returns the largest multiple of Multiple that does not exceed Value, for
// non-negative integer arguments. Multiple must not be zero.
template <typename TV, typename TM>
inline TV RoundDown(TV Value, TM Multiple)
{
    const auto wholeMultiples = Value / Multiple;
    return wholeMultiples * Multiple;
}

// Returns the smallest multiple of Multiple that is not less than Value, for
// non-negative integer arguments. Multiple must not be zero.
template <typename TV, typename TM>
inline TV RoundUp(TV Value, TM Multiple)
{
    const TV floored = RoundDown(Value, Multiple);
    const bool hasRemainder = (Value % Multiple) > 0;
    return floored + (hasRemainder ? Multiple : 0);
}

inline void StartTiming()
{
    if (!::QueryPerformanceCounter(&g_startCounter))
    {
        ErrorExit("QueryPerformanceCounter");
    }

    cout << "Timing started" << endl;
}

inline void StopTiming()
{
    if (!::QueryPerformanceCounter(&g_stopCounter))
    {
        ErrorExit("QueryPerformanceCounter");
    }

    cout << "Timing stopped" << endl;
}

inline char* AllocateBufferSpace(
    const DWORD recvBufferSize,
    const DWORD pendingRecvs,
    DWORD& bufferSize,
    DWORD& receiveBuffersAllocated)
{
    const DWORD preferredNumaNode = 0;
    const SIZE_T largePageMinimum = 0;
    SYSTEM_INFO systemInfo;
    ::GetSystemInfo(&systemInfo);
    systemInfo.dwAllocationGranularity;
    const unsigned __int64 granularity = (largePageMinimum == 0 ? systemInfo.dwAllocationGranularity : largePageMinimum);
    const unsigned __int64 desiredSize = recvBufferSize * pendingRecvs;
    unsigned __int64 actualSize = RoundUp(desiredSize, granularity);
    if (actualSize > (std::numeric_limits<DWORD>::max)())
    {
        actualSize = ((std::numeric_limits<DWORD>::max)() / granularity) * granularity;
    }
    receiveBuffersAllocated = std::min<DWORD>(pendingRecvs, static_cast<DWORD>(actualSize / recvBufferSize));
    bufferSize = static_cast<DWORD>(actualSize);
    char* pBuffer = reinterpret_cast<char*>(VirtualAllocExNuma(GetCurrentProcess(), 0, bufferSize, MEM_COMMIT | MEM_RESERVE | (largePageMinimum != 0 ? MEM_LARGE_PAGES : 0), PAGE_READWRITE, preferredNumaNode));
    if (pBuffer == 0)
    {
        ErrorExit("VirtualAlloc");
    }
    return pBuffer;
}

inline char* AllocateBufferSpace(
    const DWORD recvBufferSize,
    const DWORD pendingRecvs,
    DWORD& receiveBuffersAllocated)
{
    DWORD notUsed;
    return AllocateBufferSpace(recvBufferSize, pendingRecvs, notUsed, receiveBuffersAllocated);
}

inline void PostIOCPRecvs(
    const DWORD recvBufferSize,
    const DWORD pendingRecvs)
{
    DWORD totalBuffersAllocated = 0;

    while (totalBuffersAllocated < pendingRecvs)
    {
        DWORD receiveBuffersAllocated = 0;
        char* pBuffer = AllocateBufferSpace(recvBufferSize, pendingRecvs, receiveBuffersAllocated);
        totalBuffersAllocated += receiveBuffersAllocated;
        DWORD offset = 0;
        const DWORD recvFlags = 0;
        EXTENDED_OVERLAPPED* pBufs = new EXTENDED_OVERLAPPED[receiveBuffersAllocated];
        DWORD bytesRecvd = 0;
        DWORD flags = 0;
        for (DWORD i = 0; i < receiveBuffersAllocated; ++i)
        {
            EXTENDED_OVERLAPPED* pOverlapped = pBufs + i;
            ZeroMemory(pOverlapped, sizeof(EXTENDED_OVERLAPPED));
            pOverlapped->buf.buf = pBuffer + offset;
            pOverlapped->buf.len = recvBufferSize;
            offset += recvBufferSize;
            if (SOCKET_ERROR == ::WSARecvFrom(g_s, &(pOverlapped->buf), 1, &bytesRecvd, &flags, NULL, NULL, pOverlapped, 0))
            {
                const DWORD lastError = ::GetLastError();

                if (lastError != ERROR_IO_PENDING)
                {
                    ErrorExit("WSARecv", lastError);
                }
            }
        }

        if (totalBuffersAllocated != pendingRecvs)
        {
            cout << pendingRecvs << " receives pending" << endl;
        }
    }

    cout << totalBuffersAllocated << " total receives pending" << endl;
}


int main(int argc, char* argv[])
{
    std::map<std::size_t, std::pair<std::size_t, std::shared_ptr<char>>> packets;
    SetupTiming("IOCP UDP");
    InitialiseWinsock();
    SOCKET s = CreateSocket(WSA_FLAG_OVERLAPPED);
    HANDLE hIOCP = CreateIOCP();
    Bind(s, PORT);
    if (0 == ::CreateIoCompletionPort(reinterpret_cast<HANDLE>(s), hIOCP, 0, 0))
    {
        ErrorExit("CreateIoCompletionPort");
    }
    struct sockaddr_in sname;
    int snamesize = sizeof(struct sockaddr_in);
    ::getsockname(s, (struct sockaddr*)&sname, &snamesize);
    std::cout << sname.sin_port << std::endl;
    std::cout << ntohs(sname.sin_port) << std::endl;
    PostIOCPRecvs(8192, 2000);
    bool done = false;
    DWORD numberOfBytes = 0;
    ULONG_PTR completionKey = 0;
    OVERLAPPED* pOverlapped = 0;
    if (!::GetQueuedCompletionStatus(hIOCP, &numberOfBytes, &completionKey, &pOverlapped, INFINITE))
    {
        ErrorExit("GetQueuedCompletionStatus");
    }
    StartTiming();

    //std::thread killIOCP([&]() {
    //    std::cout << "iocp kill start" << std::endl;
    //    std::this_thread::sleep_for(std::chrono::seconds(5));
    //    std::cout << "kill iocp" << std::endl;
    //    CloseHandle(hIOCP);
    //    });
    //killIOCP.detach();

    DWORD bytesRecvd = 0;
    DWORD flags = 0;
    std::size_t times = 0;
    do
    {
        if (numberOfBytes == EXPECTED_DATA_SIZE || numberOfBytes == 100)
        {
            g_packets++;
            EXTENDED_OVERLAPPED* pExtOverlapped = static_cast<EXTENDED_OVERLAPPED*>(pOverlapped);
            if (SOCKET_ERROR == ::WSARecvFrom(g_s, &(pExtOverlapped->buf), 1, &bytesRecvd, &flags, NULL, NULL, pExtOverlapped, 0))
            {
                const DWORD lastError = ::GetLastError();
                std::shared_ptr<char> packet(new char[numberOfBytes]);
                memmove(packet.get(), pExtOverlapped->buf.buf, numberOfBytes);
                if (numberOfBytes == 100) {
                    std::cout << pExtOverlapped->buf.buf[2] << std::endl;
                }
                auto ppp = std::make_pair<std::size_t, std::shared_ptr<char>&>(numberOfBytes, packet);
                packets.insert({ g_packets,ppp });
                if (lastError != ERROR_IO_PENDING)
                {
                    ErrorExit("WSARecv", lastError);
                }
            }
        }
        else
        {
            g_packets++;
            EXTENDED_OVERLAPPED* pExtOverlapped = static_cast<EXTENDED_OVERLAPPED*>(pOverlapped);
            if (SOCKET_ERROR == ::WSARecvFrom(g_s, &(pExtOverlapped->buf), 1, &bytesRecvd, &flags, NULL, NULL, pExtOverlapped, 0))
            {
                const DWORD lastError = ::GetLastError();
                std::shared_ptr<char> packet(new char[numberOfBytes]);
                memmove(packet.get(), pExtOverlapped->buf.buf, numberOfBytes);
                auto ppp = std::make_pair<std::size_t, std::shared_ptr<char>&>(numberOfBytes, packet);
                packets.insert({ g_packets,ppp });
                std::cout << "use count:" << packet.use_count() << std::endl;
                if (lastError != ERROR_IO_PENDING)
                {
                    ErrorExit("WSARecv", lastError);
                }
            }
            std::cout << "packets size: " << packets.size() << std::endl;
            StopTiming();
            done = true;
        }
        if (!done)
        {
            if (!::GetQueuedCompletionStatus(hIOCP, &numberOfBytes, &completionKey, &pOverlapped, INFINITE))
            {
                DWORD error = GetLastError();
                if (ERROR_ABANDONED_WAIT_0 == error || ERROR_INVALID_HANDLE == error) {
                    StopTiming();
                    std::cout << error << std::endl;
                    break;
                }
                ErrorExit("GetQueuedCompletionStatus");
            }
        }
    } while (!done);
    PrintTimings();
    packets.clear();
    return 0;
}

测试代码

c++ 复制代码
#include <boost/asio.hpp>
#include <cstring>
#include <iostream>

// Allocates `size` bytes with malloc and fills them with a test pattern:
// every byte is '1' except the final 8296 bytes, which are '2'. When `size`
// is smaller than 8296 the whole buffer is '2'.
//
// Returns the buffer, or a null pointer when the allocation fails. The caller
// owns the memory and must release it with free().
char* makeMem(size_t size){
    const size_t tailPatternSize = 8296;

    char* mem = static_cast<char*>(malloc(size));
    if (mem == nullptr) {
        return nullptr;
    }
    memset(mem, '1', size);

    // Clamp the tail so a small buffer cannot make the offset wrap around.
    const size_t tail = (size < tailPatternSize) ? size : tailPatternSize;
    memset(mem + (size - tail), '2', tail);
    return mem;
}

// Test sender for the UDP server.
//
// Builds a 20 MiB + 100 byte pattern buffer and sends it twice to
// 10.10.1.40:8081 from local port 10120, with a 10 second pause between the
// two passes. Each pass is: one 100-byte header datagram, the buffer in
// 8192-byte datagrams, then one datagram with the remaining bytes. A final
// 10-byte trailer datagram is sent after the second pass.
//
// Returns 0 on success and 1 when the pattern buffer cannot be allocated.
int main(){
    boost::asio::io_context context;
    boost::asio::ip::udp::endpoint destEndpoint = boost::asio::ip::udp::endpoint(boost::asio::ip::make_address("10.10.1.40"),8081);

    boost::asio::ip::udp::socket transmitter = boost::asio::ip::udp::socket(context,boost::asio::ip::udp::endpoint(boost::asio::ip::udp::v4(), 10120));
    const size_t size = 20*1024*1024 + 100;// 20MB
    char* mem = makeMem(size);
    if (mem == nullptr) {
        std::cerr << "makeMem failed" << std::endl;
        return 1;
    }

    const std::size_t payload = 8192;
    const std::size_t leftLastSize = size % payload;
    const std::size_t sendTimes = size / payload;

    char headerData[100];
    memset(headerData, 3, sizeof(headerData));

    // One full pass: header, all whole payload-sized chunks, then the remainder.
    const auto sendPass = [&]() {
        const char* sendData = mem;
        transmitter.send_to(boost::asio::buffer(headerData, sizeof(headerData)), destEndpoint);
        for (size_t idx = 0; idx < sendTimes; ++idx, sendData += payload) {
            transmitter.send_to(boost::asio::buffer(sendData, payload), destEndpoint);
            std::cout << idx << std::endl;
            //std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
        transmitter.send_to(boost::asio::buffer(sendData, leftLastSize), destEndpoint);
    };

    sendPass();
    std::cout << "------------" << std::endl;
    std::this_thread::sleep_for(std::chrono::seconds(10));

    sendPass();

    char tailerData[10] = {9};
    transmitter.send_to(boost::asio::buffer(tailerData, sizeof(tailerData)), destEndpoint);

    // The buffer comes from malloc (see makeMem), so it must be released with
    // free, not delete.
    free(mem);
    return 0;
}

最后,推荐一个项目,上述代码基本来自于该项目

LenHolgate/RIO: Code that explores the Windows Registered I/O Networking Extensions (github.com)

相关推荐
巴巴_羊1 小时前
前端面经 计网 http和https区别
网络协议·http·https
LyaJpunov4 小时前
HTTPS全解析:从证书签发到TLS握手优化
网络协议·http·https
你曾经是少年4 小时前
HTTPS
网络协议·http·https
2501_915918414 小时前
多账号管理与自动化中的浏览器指纹对抗方案
websocket·网络协议·tcp/ip·http·网络安全·https·udp
IT专业服务商4 小时前
联想 SR550 服务器,配置 RAID 5教程!
运维·服务器·windows·microsoft·硬件架构
海尔辛5 小时前
学习黑客5 分钟小白弄懂Windows Desktop GUI
windows·学习
gushansanren5 小时前
基于WSL用MSVC编译ffmpeg7.1
windows·ffmpeg
伐尘6 小时前
【Qt】编译 Qt 5.15.x For Windows 基础教程 Visual Studio 2019 MSVC142 x64
windows·qt·visual studio
专注代码七年6 小时前
在Windows 境下,将Redis和Nginx注册为服务。
windows·redis·nginx
-九斤-7 小时前
http和https的区别
网络协议·http·https