使用IO完成端口实现简单回显服务器

说明

使用IO完成端口实现简单回显服务器,因为是测试用的,所以代码很粗糙。

  • 提醒

    本示例用ReadFile、WriteFile实现重叠IO(Overlapped IO);正式场合应使用WSARecv、WSASend,原因见《Windows网络编程技术》8.2.5节。

    (原文此处为一张截图,推测为上述8.2.5节的相关内容。)

  • 技术要点记录

    当IO以同步方式立即完成时,系统默认仍会把完成通知投递到IO完成端口的队列中。这样做是为了方便编码:无论IO是立即完成还是挂起后才完成,都可以统一在调用GetQueuedCompletionStatus的工作线程中处理。

    调用SetFileCompletionNotificationModes并传入FILE_SKIP_COMPLETION_PORT_ON_SUCCESS,可以告诉系统:当IO以同步方式立即完成时,不要将该事件投递到IO完成端口队列中。设置该标志后,调用方必须自行处理立即完成的IO。本示例没有设置该标志,因此ReadFile/WriteFile立即返回成功时,完成通知仍会交给工作线程处理。

    参看《Windows核心编程》第10章 10.5.4

代码

cpp
#include <iostream>
#define _WINSOCK_DEPRECATED_NO_WARNINGS
#include <WinSock2.h>
#include <memory>
#include <mutex>
#include <process.h>
#include <set>
#include <string>

#pragma comment(lib, "Ws2_32.lib")

// An OVERLAPPED paired with a flag saying which kind of operation it was
// issued for, so a completion handler that only receives the LPOVERLAPPED
// can recover this object (via CONTAINING_RECORD) and tell reads from writes.
class MyOverlapped
{
public:
	MyOverlapped() = default;

	// Starts out all-zero so no stale Offset/hEvent values reach ReadFile/WriteFile.
	OVERLAPPED m_Overlapped{};
	// true: this OVERLAPPED belongs to a read; false: to a write.
	bool m_bIsRead = false;
};


// Per-connection state for one accepted client: its socket, peer IP, the
// receive buffer, and one OVERLAPPED each for the read and the write side.
// The buffer is filled by reads (nRecvSize) and then echoed back by writes
// (nWriteOffset counts bytes already sent).
struct ClientSocketItem
{
	ClientSocketItem()
	{
		// Tag each OVERLAPPED so the completion handler knows which side finished.
		readOverlapped.m_bIsRead = true;
		writeOverlapped.m_bIsRead = false;
	}

	// INVALID_SOCKET (not 0/NULL) is the Winsock "no socket" sentinel.
	SOCKET hSocket = INVALID_SOCKET;
	std::string strIp;

	MyOverlapped readOverlapped;
	char szRecv[1024] = {};
	unsigned int nRecvSize = 0;      // number of bytes received into szRecv

	MyOverlapped writeOverlapped;
	unsigned int nWriteOffset = 0;   // number of bytes of szRecv already echoed back

	bool bFinished = false;          // NOTE(review): not read anywhere in this file
};
// Clients currently being served: inserted when a connection is accepted and
// erased when it is closed. std::set is not thread-safe by itself.
std::set<ClientSocketItem*> g_Clients;

// The IO completion port every client socket is attached to (created in main).
HANDLE g_hIoCompletionPort = nullptr;



// Starts one overlapped read of a single byte (test code: deliberately one
// character per call) into pClient->szRecv[pClient->nRecvSize].
//
// FILE_SKIP_COMPLETION_PORT_ON_SUCCESS is not set on the socket, so even when
// ReadFile completes immediately a completion packet is still queued. The
// worker thread therefore handles both the immediate and the pending case.
//
// @param pClient  client whose socket, buffer and readOverlapped are used; may be null.
// @return true if the read was issued (completed immediately or pending);
//         false if pClient is null, the receive buffer is full, or ReadFile
//         failed. On false no I/O is outstanding for this client.
bool do_read(ClientSocketItem* pClient)
{
	if (!pClient)
	{
		return false;
	}

	// Never index past the end of szRecv: a client that sends more than
	// sizeof(szRecv) bytes without a '\n' is rejected instead of overflowing.
	if (pClient->nRecvSize >= sizeof(pClient->szRecv))
	{
		std::cerr << "read failed: receive buffer is full" << std::endl;
		return false;
	}

	// For overlapped I/O the byte count is delivered with the completion
	// packet, so lpNumberOfBytesRead is passed as NULL as the ReadFile docs advise.
	if (::ReadFile(reinterpret_cast<HANDLE>(pClient->hSocket),
		&pClient->szRecv[pClient->nRecvSize],
		1, nullptr, &pClient->readOverlapped.m_Overlapped))
	{
		return true;
	}

	const DWORD dwError = ::GetLastError();
	if (ERROR_IO_PENDING == dwError)
	{
		return true;
	}

	std::cerr << "read failed with error " << dwError << std::endl;

	return false;
}


// Starts one overlapped write of the single byte pClient->szRecv[nWriteOffset]
// to the client socket (test code: one character per call).
//
// @param pClient  client whose socket, buffer and writeOverlapped are used; may be null.
// @return true if the write was issued (completed immediately or pending);
//         false if pClient is null or WriteFile failed with anything other
//         than ERROR_IO_PENDING.
bool do_write(ClientSocketItem* pClient)
{
	if (pClient == nullptr)
		return false;

	DWORD dwWritten = 0;
	const BOOL bIssued = ::WriteFile(reinterpret_cast<HANDLE>(pClient->hSocket),
		pClient->szRecv + pClient->nWriteOffset, 1, &dwWritten,
		&pClient->writeOverlapped.m_Overlapped);
	if (bIssued)
		return true;

	const DWORD dwLastError = ::GetLastError();
	if (dwLastError != ERROR_IO_PENDING)
	{
		std::cerr << "write failed with error " << dwLastError << std::endl;
		return false;
	}
	return true;
}


bool do_accept(SOCKET hListenSocket)
{
	sockaddr_in mPeerAddr = { 0 };
	int nAddrLen = sizeof(sockaddr);
	SOCKET hClientSocket = accept(hListenSocket, (sockaddr*)(&mPeerAddr), &nAddrLen);
	if (INVALID_SOCKET == hClientSocket)
	{
		std::cout << "accept failed with error "
			<< WSAGetLastError() << std::endl;
		return false;
	}
	else
	{
		unsigned long nNoBlock = 0;
		ioctlsocket(hClientSocket, FIONBIO, &nNoBlock);

		std::string strIpAddr = inet_ntoa(mPeerAddr.sin_addr);
		std::cout << "accept success, peer ip is " << strIpAddr.c_str() << std::endl;

		auto pClient = new ClientSocketItem();
		pClient->hSocket = hClientSocket;
		pClient->strIp = strIpAddr;
		g_Clients.insert(pClient);

		//附加到IO完成端口上
		if (g_hIoCompletionPort != ::CreateIoCompletionPort(HANDLE(pClient->hSocket),
			g_hIoCompletionPort, ULONG_PTR(pClient), 0))
		{
			std::cerr << "attach socket to io completion port failed with error"
				<< ::GetLastError() << std::endl;
			closesocket(hClientSocket);
			return false;
		}

		//触发读取
		if (!do_read(pClient))
		{
			closesocket(pClient->hSocket);
			g_Clients.erase(pClient);
			std::cerr << "do_read failed, close client" << std::endl;
		}

		return true;
	}
}


//读写线程函数
unsigned __stdcall ReadWriteThreadFun(void* pParam)
{
	DWORD dwTransferBytes = 0;
	ULONG_PTR nCompleteKey = 0;
	LPOVERLAPPED lpOverlapped = NULL;
	while (::GetQueuedCompletionStatus(g_hIoCompletionPort, &dwTransferBytes,
		&nCompleteKey, &lpOverlapped, INFINITE))
	{
		if (nCompleteKey == UINT_MAX)
		{
			std::cout << "user require quit" << std::endl;
			break;
		}

		if (!lpOverlapped)
		{
			std::cerr << "lpOverlapped is null" << std::endl;
			break;
		}

		ClientSocketItem* pClient = (ClientSocketItem*)nCompleteKey;
		MyOverlapped* pMyOverlapped = CONTAINING_RECORD(lpOverlapped, MyOverlapped, m_Overlapped);
		if (!pMyOverlapped || !pClient)
		{
			std::cerr << "pMyOverlapped or pClient is null" << std::endl;
			break;
		}

		if (pMyOverlapped->m_bIsRead)//read finished notify
		{
			char c = pClient->szRecv[pClient->nRecvSize];
			pClient->nRecvSize += dwTransferBytes;
			std::cout << "read one char: " << c << std::endl;
			if (c == '\n')
			{
				std::cout << "read finished, start to write" << std::endl;
				if (!do_write(pClient))
				{
					std::cerr << "do_write failed, close socket" << std::endl;
					closesocket(pClient->hSocket);
					g_Clients.erase(pClient);
				}
			}
			else
			{
				std::cout << "next char read" << std::endl;
				if (!do_read(pClient))
				{
					std::cerr << "do_read failed, close socket" << std::endl;
					closesocket(pClient->hSocket);
					g_Clients.erase(pClient);
				}
			}
		}
		else //write finished notify
		{
			char c = pClient->szRecv[pClient->nWriteOffset];
			std::cout << "send one char: " << c << std::endl;
			pClient->nWriteOffset += dwTransferBytes;
			if (pClient->nWriteOffset == pClient->nRecvSize)
			{
				std::cout << "send finished, close client(" << pClient->strIp.c_str()
					<< ")" << std::endl;
				closesocket(pClient->hSocket);
				g_Clients.erase(pClient);
			}
			else
			{
				std::cout << "next send" << std::endl;
				if (!do_write(pClient))
				{
					std::cerr << "do_write failed, close socket" << std::endl;
					closesocket(pClient->hSocket);
					g_Clients.erase(pClient);
				}
			}
		}
	}

	std::cout << "thread quit" << std::endl;
	return 0;
}


int main(int argc, char* argv)
{
	WORD wVersionRequested = MAKEWORD(2, 2);
	WSADATA wsaData = { 0 };
	int err = WSAStartup(wVersionRequested, &wsaData);
	if (err != 0)
	{
		return -1;
	}

	if (LOBYTE(wsaData.wVersion) != 2 ||
		HIBYTE(wsaData.wVersion) != 2)
	{
		WSACleanup();
		return -1;
	}

	SOCKET hListenSocket = socket(AF_INET, SOCK_STREAM, 0);
	if (INVALID_SOCKET == hListenSocket)
	{
		std::cerr << "create socket failed with error " << WSAGetLastError()
			<< std::endl;
		return -1;
	}

	sockaddr_in mSockAddrIn = { 0 };
	mSockAddrIn.sin_family = AF_INET;
	mSockAddrIn.sin_port = htons((u_short)8878);
	mSockAddrIn.sin_addr.S_un.S_addr = inet_addr("0.0.0.0");
	if (SOCKET_ERROR == bind(hListenSocket, (sockaddr*)(&mSockAddrIn),
		sizeof(sockaddr)))
	{
		std::cerr << "bind failed with error " << WSAGetLastError() << std::endl;
		return -1;
	}

	if (SOCKET_ERROR == listen(hListenSocket, SOMAXCONN))
	{
		std::cerr << "listen failed with error " << WSAGetLastError() << std::endl;
		return -1;
	}

	//创建完成端口
	g_hIoCompletionPort = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
	if (NULL == g_hIoCompletionPort)
	{
		std::cerr << "create io completion port failed with error " 
			<< ::GetLastError() << std::endl;
		return -1;
	}

	//创建一堆服务线程
	for (int i = 0; i < 4; ++i)
	{
		_beginthreadex(0, 0, ReadWriteThreadFun, 0, 0, nullptr);
	}

	while (true)
	{
		if (!do_accept(hListenSocket))
		{
			break;
		}
	}

	::PostQueuedCompletionStatus(g_hIoCompletionPort, 0, UINT_MAX, nullptr);
	Sleep(2000); 


	return 0;
}