说明
使用IO完成端口（IOCP）实现一个简单的回显服务器。因为只是测试用途，所以代码比较粗糙。
-
提醒
本示例使用ReadFile、WriteFile来实现Overlapped IO；正式场合应该使用WSARecv、WSASend，原因参见《Windows网络编程技术》8.2.5节。
（此处原有一张图片，可能是书中该节内容的截图，图片已缺失。）
-
技术要点记录
默认情况下，即使IO以同步方式立即完成，系统也会把完成通知投递到IO完成端口的队列中。这样做是为了方便用户编码：所有完成事件都可以在同一处（工作线程）统一处理。
调用SetFileCompletionNotificationModes并传入FILE_SKIP_COMPLETION_PORT_ON_SUCCESS，可以告诉系统：IO以同步方式立即完成时，不要再将该完成事件投递到IO完成端口队列中。
本示例没有设置该标志，因此同步完成的IO同样会在工作线程中处理。
参见《Windows核心编程》第10章 10.5.4节。
代码
cpp
#define _WINSOCK_DEPRECATED_NO_WARNINGS
#include <WinSock2.h>
#include <process.h>

#include <climits>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <string>

#pragma comment(lib, "Ws2_32.lib")
class MyOverlapped
{
public:
MyOverlapped()
:m_bIsRead(false)
{
memset(&m_Overlapped, 0, sizeof(OVERLAPPED));
}
OVERLAPPED m_Overlapped;
bool m_bIsRead;
};
// Per-connection state of one echo client.
// The instance's address is used as the completion key when the socket is
// attached to the IO completion port, and the embedded OVERLAPPED structures
// are handed to the kernel. The object must therefore stay at a fixed address
// for the whole connection, so copying is disabled.
struct ClientSocketItem
{
    ClientSocketItem()
    {
        readOverlapped.m_bIsRead = true;
        writeOverlapped.m_bIsRead = false;
    }
    ClientSocketItem(const ClientSocketItem&) = delete;
    ClientSocketItem& operator=(const ClientSocketItem&) = delete;

    SOCKET hSocket = INVALID_SOCKET;   // client socket; INVALID_SOCKET until assigned
    std::string strIp;                 // peer IPv4 address in dotted form (used for logging)
    MyOverlapped readOverlapped;       // OVERLAPPED used for read requests
    char szRecv[1024] = {};            // bytes received so far; echoed back once '\n' arrives
    unsigned int nRecvSize = 0;        // number of valid bytes in szRecv
    MyOverlapped writeOverlapped;      // OVERLAPPED used for write requests
    unsigned int nWriteOffset = 0;     // number of bytes of szRecv already echoed back
    bool bFinished = false;            // NOTE(review): not read anywhere in this file
};
// Accepted clients tracked by the server (inserted on accept, erased on close).
std::set<ClientSocketItem*> g_Clients{};
// The IO completion port shared by the accept loop and the worker threads.
HANDLE g_hIoCompletionPort{ nullptr };
// Issues one overlapped 1-byte read into pClient->szRecv[pClient->nRecvSize].
// Test only: a single character is requested per read.
// Returns true when the read is pending or completed immediately. In both
// cases the completion is delivered through the IO completion port, as long as
// FILE_SKIP_COMPLETION_PORT_ON_SUCCESS is not set on the socket.
// Returns false for a null client, a full receive buffer, or a failed ReadFile.
// When it returns false, no completion will be queued for this request.
bool do_read(ClientSocketItem* pClient)
{
    if (!pClient)
    {
        return false;
    }
    // Without this check, a client that never sends '\n' would make us write
    // past the end of szRecv.
    if (pClient->nRecvSize >= sizeof(pClient->szRecv))
    {
        std::cerr << "receive buffer full, refuse to read more" << std::endl;
        return false;
    }
    // The previous request on this OVERLAPPED has completed, so it is safe to
    // reset it before reuse.
    pClient->readOverlapped.m_Overlapped = OVERLAPPED{};
    // For overlapped I/O the byte count is reported by the completion packet,
    // so pass nullptr for lpNumberOfBytesRead as the API recommends.
    if (::ReadFile(reinterpret_cast<HANDLE>(pClient->hSocket),
                   &pClient->szRecv[pClient->nRecvSize], 1, nullptr,
                   &pClient->readOverlapped.m_Overlapped))
    {
        return true;
    }
    const DWORD dwError = ::GetLastError();
    if (ERROR_IO_PENDING == dwError)
    {
        return true;
    }
    std::cerr << "read failed with error " << dwError << std::endl;
    return false;
}
// Issues one overlapped 1-byte write of pClient->szRecv[pClient->nWriteOffset].
// Test only: a single character is sent per write.
// Returns true when the write is pending or completed immediately. In both
// cases the completion is delivered through the IO completion port, as long as
// FILE_SKIP_COMPLETION_PORT_ON_SUCCESS is not set on the socket.
// Returns false for a null client, when there is nothing left to echo, or when
// WriteFile fails.
bool do_write(ClientSocketItem* pClient)
{
    if (!pClient)
    {
        return false;
    }
    // Never send bytes beyond what has actually been received.
    if (pClient->nWriteOffset >= pClient->nRecvSize)
    {
        std::cerr << "nothing left to write" << std::endl;
        return false;
    }
    // The previous request on this OVERLAPPED has completed; reset it before reuse.
    pClient->writeOverlapped.m_Overlapped = OVERLAPPED{};
    // The completion packet carries the byte count, so lpNumberOfBytesWritten
    // is nullptr for this overlapped request.
    if (::WriteFile(reinterpret_cast<HANDLE>(pClient->hSocket),
                    &pClient->szRecv[pClient->nWriteOffset], 1, nullptr,
                    &pClient->writeOverlapped.m_Overlapped))
    {
        return true;
    }
    const DWORD dwError = ::GetLastError();
    if (ERROR_IO_PENDING == dwError)
    {
        return true;
    }
    std::cerr << "write failed with error " << dwError << std::endl;
    return false;
}
bool do_accept(SOCKET hListenSocket)
{
sockaddr_in mPeerAddr = { 0 };
int nAddrLen = sizeof(sockaddr);
SOCKET hClientSocket = accept(hListenSocket, (sockaddr*)(&mPeerAddr), &nAddrLen);
if (INVALID_SOCKET == hClientSocket)
{
std::cout << "accept failed with error "
<< WSAGetLastError() << std::endl;
return false;
}
else
{
unsigned long nNoBlock = 0;
ioctlsocket(hClientSocket, FIONBIO, &nNoBlock);
std::string strIpAddr = inet_ntoa(mPeerAddr.sin_addr);
std::cout << "accept success, peer ip is " << strIpAddr.c_str() << std::endl;
auto pClient = new ClientSocketItem();
pClient->hSocket = hClientSocket;
pClient->strIp = strIpAddr;
g_Clients.insert(pClient);
//附加到IO完成端口上
if (g_hIoCompletionPort != ::CreateIoCompletionPort(HANDLE(pClient->hSocket),
g_hIoCompletionPort, ULONG_PTR(pClient), 0))
{
std::cerr << "attach socket to io completion port failed with error"
<< ::GetLastError() << std::endl;
closesocket(hClientSocket);
return false;
}
//触发读取
if (!do_read(pClient))
{
closesocket(pClient->hSocket);
g_Clients.erase(pClient);
std::cerr << "do_read failed, close client" << std::endl;
}
return true;
}
}
//读写线程函数
unsigned __stdcall ReadWriteThreadFun(void* pParam)
{
DWORD dwTransferBytes = 0;
ULONG_PTR nCompleteKey = 0;
LPOVERLAPPED lpOverlapped = NULL;
while (::GetQueuedCompletionStatus(g_hIoCompletionPort, &dwTransferBytes,
&nCompleteKey, &lpOverlapped, INFINITE))
{
if (nCompleteKey == UINT_MAX)
{
std::cout << "user require quit" << std::endl;
break;
}
if (!lpOverlapped)
{
std::cerr << "lpOverlapped is null" << std::endl;
break;
}
ClientSocketItem* pClient = (ClientSocketItem*)nCompleteKey;
MyOverlapped* pMyOverlapped = CONTAINING_RECORD(lpOverlapped, MyOverlapped, m_Overlapped);
if (!pMyOverlapped || !pClient)
{
std::cerr << "pMyOverlapped or pClient is null" << std::endl;
break;
}
if (pMyOverlapped->m_bIsRead)//read finished notify
{
char c = pClient->szRecv[pClient->nRecvSize];
pClient->nRecvSize += dwTransferBytes;
std::cout << "read one char: " << c << std::endl;
if (c == '\n')
{
std::cout << "read finished, start to write" << std::endl;
if (!do_write(pClient))
{
std::cerr << "do_write failed, close socket" << std::endl;
closesocket(pClient->hSocket);
g_Clients.erase(pClient);
}
}
else
{
std::cout << "next char read" << std::endl;
if (!do_read(pClient))
{
std::cerr << "do_read failed, close socket" << std::endl;
closesocket(pClient->hSocket);
g_Clients.erase(pClient);
}
}
}
else //write finished notify
{
char c = pClient->szRecv[pClient->nWriteOffset];
std::cout << "send one char: " << c << std::endl;
pClient->nWriteOffset += dwTransferBytes;
if (pClient->nWriteOffset == pClient->nRecvSize)
{
std::cout << "send finished, close client(" << pClient->strIp.c_str()
<< ")" << std::endl;
closesocket(pClient->hSocket);
g_Clients.erase(pClient);
}
else
{
std::cout << "next send" << std::endl;
if (!do_write(pClient))
{
std::cerr << "do_write failed, close socket" << std::endl;
closesocket(pClient->hSocket);
g_Clients.erase(pClient);
}
}
}
}
std::cout << "thread quit" << std::endl;
return 0;
}
int main(int argc, char* argv)
{
WORD wVersionRequested = MAKEWORD(2, 2);
WSADATA wsaData = { 0 };
int err = WSAStartup(wVersionRequested, &wsaData);
if (err != 0)
{
return -1;
}
if (LOBYTE(wsaData.wVersion) != 2 ||
HIBYTE(wsaData.wVersion) != 2)
{
WSACleanup();
return -1;
}
SOCKET hListenSocket = socket(AF_INET, SOCK_STREAM, 0);
if (INVALID_SOCKET == hListenSocket)
{
std::cerr << "create socket failed with error " << WSAGetLastError()
<< std::endl;
return -1;
}
sockaddr_in mSockAddrIn = { 0 };
mSockAddrIn.sin_family = AF_INET;
mSockAddrIn.sin_port = htons((u_short)8878);
mSockAddrIn.sin_addr.S_un.S_addr = inet_addr("0.0.0.0");
if (SOCKET_ERROR == bind(hListenSocket, (sockaddr*)(&mSockAddrIn),
sizeof(sockaddr)))
{
std::cerr << "bind failed with error " << WSAGetLastError() << std::endl;
return -1;
}
if (SOCKET_ERROR == listen(hListenSocket, SOMAXCONN))
{
std::cerr << "listen failed with error " << WSAGetLastError() << std::endl;
return -1;
}
//创建完成端口
g_hIoCompletionPort = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
if (NULL == g_hIoCompletionPort)
{
std::cerr << "create io completion port failed with error "
<< ::GetLastError() << std::endl;
return -1;
}
//创建一堆服务线程
for (int i = 0; i < 4; ++i)
{
_beginthreadex(0, 0, ReadWriteThreadFun, 0, 0, nullptr);
}
while (true)
{
if (!do_accept(hListenSocket))
{
break;
}
}
::PostQueuedCompletionStatus(g_hIoCompletionPort, 0, UINT_MAX, nullptr);
Sleep(2000);
return 0;
}