本文功能实现较为简陋,demo内容仅供参考,有不足之处还请指正。
背景
一个小项目,作为微信小程序的服务端,需要支持小程序用户之间的一对一聊天。
实现功能
WebSocket通信、心跳检测、消息持久化、离线消息存储
Netty配置类
java
/**
* @author Aseubel
*/
@Component
@Slf4j
@EnableConfigurationProperties(NettyServerConfigProperties.class)
public class NettyServerConfig {
private ChannelFuture serverChannelFuture;
// 心跳间隔(秒)
private static final int HEARTBEAT_INTERVAL = 15;
// 读超时时间
private static final int READ_TIMEOUT = HEARTBEAT_INTERVAL * 2;
// 使用线程池管理
private final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
private final EventLoopGroup workerGroup = new NioEventLoopGroup();
private final NettyServerConfigProperties properties;
// 由于在后面的handler中有依赖注入类,所以要通过springboot的ApplicationContext来获取Bean实例
@Autowired
private ApplicationContext applicationContext;
public NettyServerConfig(NettyServerConfigProperties properties) {
this.properties = properties;
}
@PostConstruct
public void startNettyServer() {
// 使用独立线程启动Netty服务
new Thread(() -> {
try {
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
SSLContext sslContext = SslUtil.createSSLContext("PKCS12",
properties.getSslPath(), properties.getSslPassword());
// SSLEngine 此类允许使用ssl安全套接层协议进行安全通信
SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(false);
pipeline.addLast(new SslHandler(engine)); // 设置SSL
pipeline.addLast(new HttpServerCodec());
pipeline.addLast(new HttpObjectAggregator(10 * 1024 * 1024));// 最大10MB
pipeline.addLast(new ChunkedWriteHandler());
pipeline.addLast(new HttpHandler());
// 只有text和binarytext的帧能经过WebSocketServerProtocolHandler,所以心跳检测这两个都得放前面
pipeline.addLast(new IdleStateHandler(READ_TIMEOUT, 0, 0, TimeUnit.SECONDS));
pipeline.addLast(new HeartbeatHandler());
pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true, 10 * 1024 * 1024));
pipeline.addLast(applicationContext.getBean(MessageHandler.class));
pipeline.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
// 统一处理所有未被前面handler捕获的异常
log.error("全局异常捕获: {}", cause.getMessage());
ctx.channel().close();
}
});
}
});
serverChannelFuture = bootstrap.bind(properties.getPort()).sync();
// 保持通道开放
serverChannelFuture.channel().closeFuture().sync();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}).start();
}
@PreDestroy
public void stopNettyServer() {
// 优雅关闭
if (serverChannelFuture != null) {
serverChannelFuture.channel().close();
}
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
Handler
心跳检测
java
/**
 * Sends a WebSocket ping every {@value #HEARTBEAT_INTERVAL} seconds and closes the connection
 * when a reader-idle event arrives after {@value #MAX_MISSED_HEARTBEATS} or more pings went
 * unanswered by a pong.
 * <p>
 * Reader-idle events must come from an {@code IdleStateHandler} placed BEFORE this handler in
 * the pipeline; events only flow toward the tail.
 * <p>
 * Not {@code @Sharable}: one instance per channel. All state is touched from the channel's
 * inbound callbacks and from a task scheduled on {@code ctx.executor()}. When the handler is
 * added without a separate executor group, these all run on the channel's event loop.
 *
 * @author Aseubel
 */
public class HeartbeatHandler extends ChannelInboundHandlerAdapter {
    private static final int HEARTBEAT_INTERVAL = 15; // heartbeat interval (seconds)
    private static final int MAX_MISSED_HEARTBEATS = 2; // pings allowed to go unanswered
    // Pings sent since the last pong was received on this channel.
    private int missedHeartbeats;
    // Periodic ping task; cancelled when the channel goes inactive so it does not run forever.
    private java.util.concurrent.ScheduledFuture<?> heartbeatTask;

    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        scheduleHeartbeat(ctx);
        // Propagate so handlers further down the pipeline also see channelActive.
        ctx.fireChannelActive();
    }

    /** Schedules the periodic ping; each ping sent counts as missed until a pong arrives. */
    private void scheduleHeartbeat(ChannelHandlerContext ctx) {
        heartbeatTask = ctx.executor().scheduleAtFixedRate(() -> {
            if (ctx.channel().isActive()) {
                ctx.writeAndFlush(new PingWebSocketFrame(Unpooled.copiedBuffer("HEARTBEAT", CharsetUtil.UTF_8)));
                missedHeartbeats++;
            }
        }, HEARTBEAT_INTERVAL, HEARTBEAT_INTERVAL, TimeUnit.SECONDS);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof PongWebSocketFrame) {
            // A pong proves the peer is alive: reset the counter.
            missedHeartbeats = 0;
        }
        // Every message, pong included, continues down the pipeline.
        ctx.fireChannelRead(msg);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
        if (evt instanceof IdleStateEvent) {
            if (missedHeartbeats >= MAX_MISSED_HEARTBEATS) {
                // Too many unanswered pings: close the connection.
                System.out.println("连接超时,关闭连接" + ctx.channel().id().asLongText());
                ctx.close();
                cleanOfflineResources(ctx.channel());
            }
        } else {
            // Do not swallow unrelated events (e.g. TLS handshake completion).
            ctx.fireUserEventTriggered(evt);
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        cancelHeartbeat();
        ctx.fireChannelInactive();
    }

    /** Stops the periodic ping task, if one was scheduled. */
    private void cancelHeartbeat() {
        if (heartbeatTask != null) {
            heartbeatTask.cancel(false);
            heartbeatTask = null;
        }
    }

    /** Unregisters the channel from MessageHandler's user map and resets local heartbeat state. */
    private void cleanOfflineResources(Channel channel) {
        MessageHandler.removeUserChannel(channel);
        missedHeartbeats = 0;
    }
}
处理HTTP请求,建立连接
java
/**
 * Handles the WebSocket upgrade request.
 * <p>
 * Stores the {@code code} query parameter in the channel attribute {@link #WS_TOKEN_KEY}. It then
 * forwards the request once with its original URI; this first pass is intended for MessageHandler,
 * which authenticates the user. Finally it rewrites the URI to {@code /ws} and forwards the same
 * request again so the WebSocket protocol handler can perform the handshake.
 * <p>
 * Any other message is passed through unchanged.
 *
 * @author Aseubel
 * @date 2025-02-21 15:34
 */
public class HttpHandler extends ChannelInboundHandlerAdapter {
    public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            FullHttpRequest request = (FullHttpRequest) msg;
            String code = extractCode(request.uri());
            if (code == null) {
                // The request is not forwarded on this path, so release it here to avoid a leak.
                request.release();
                throw new AppException("非法的websocket连接请求");
            }
            ctx.channel().attr(WS_TOKEN_KEY).set(code);
            // First pass: deliver the request (original URI) downstream to MessageHandler.
            ctx.fireChannelRead(request);
            // Rewrite the URI to the exact websocket path so the protocol handler accepts it
            // on the second pass below; otherwise the handshake does not happen.
            request.setUri("/ws");
        }
        // Pass the message (or the rewritten upgrade request) to the next handler.
        super.channelRead(ctx, msg);
    }

    /** Returns the first {@code code} query parameter of {@code uri}, or null if absent or undecodable. */
    private static String extractCode(String uri) {
        try {
            var codes = new QueryStringDecoder(uri).parameters().get("code");
            return (codes == null || codes.isEmpty()) ? null : codes.get(0);
        } catch (IllegalArgumentException e) {
            // Malformed percent-encoding in the query string.
            return null;
        }
    }
}
消息处理
java
/**
 * Handles one-to-one chat traffic over WebSocket.
 * <p>
 * Authentication: receives the upgrade {@link FullHttpRequest} before the handshake and exchanges
 * the {@code code} channel attribute for a WeChat openid. The user's channel is then registered
 * under that id.
 * <p>
 * Messaging: validates each {@link TextWebSocketFrame} and persists it asynchronously. It then
 * forwards the frame to the recipient's active channel, or queues it in memory until the
 * recipient next connects.
 * <p>
 * A single Spring bean shared by all channels ({@code @Sharable}); cross-channel state lives in
 * the static concurrent maps.
 *
 * @author Aseubel
 * @date 2025-02-21 15:33
 */
@Component
@Slf4j
@Sharable
public class MessageHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
    public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");
    public static final AttributeKey<String> WS_USER_ID_KEY = AttributeKey.valueOf("userId");
    // Delay before flushing offline messages, so they are written after the handshake response.
    private static final long OFFLINE_FLUSH_DELAY_MS = 50;
    // userId -> frames waiting for that user; each queued frame holds one retained reference.
    // Queues are only mutated inside OFFLINE_MSGS.compute(...) or after being removed from the map,
    // so a plain LinkedList is safe here.
    // NOTE(review): in-memory and unbounded; queued messages are lost on restart.
    private static final Map<String, Queue<WebSocketFrame>> OFFLINE_MSGS = new ConcurrentHashMap<>();
    // userId -> channel the user most recently authenticated on.
    private static final Map<String, Channel> userChannels = new ConcurrentHashMap<>();

    @Autowired
    private ThreadPoolTaskExecutor threadPoolExecutor;

    @Resource
    private IMessageRepository messageRepository;

    /** Removes the user mapping (at most one) whose value is {@code channel}. */
    public static void removeUserChannel(Channel channel) {
        userChannels.values().remove(channel);
    }

    /** Returns whether {@code userId} has a registered channel (which may already be inactive). */
    public static boolean containsUser(String userId) {
        return userChannels.containsKey(userId);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object req) throws Exception {
        if (req instanceof FullHttpRequest) {
            // Not released here: HttpHandler forwards the same request again for the handshake.
            authenticate(ctx);
        } else if (req instanceof TextWebSocketFrame textFrame) {
            // channelRead is overridden, so SimpleChannelInboundHandler's auto-release does not
            // apply. Release the inbound reference ourselves; forwarding/queueing retains its own.
            try {
                channelRead0(ctx, textFrame);
            } finally {
                textFrame.release();
            }
        } else {
            ctx.fireChannelRead(req);
        }
    }

    /** Resolves the user's openid, registers the channel and schedules delivery of queued messages. */
    private void authenticate(ChannelHandlerContext ctx) {
        String code = getCodeFromRequest(ctx);
        String userId = getOpenid(APPID, SECRET, code);
        userChannels.put(userId, ctx.channel());
        ctx.channel().attr(WS_USER_ID_KEY).set(userId);
        log.info("客户端连接成功,用户id:{}", userId);
        // The handshake is still in progress, so defer the offline flush. Scheduling on the
        // channel's own event loop avoids spawning a thread per connection.
        ctx.executor().schedule(() -> flushOfflineMessages(ctx, userId),
                OFFLINE_FLUSH_DELAY_MS, java.util.concurrent.TimeUnit.MILLISECONDS);
    }

    /** Atomically takes the user's offline queue and writes every frame in it to the channel. */
    private void flushOfflineMessages(ChannelHandlerContext ctx, String userId) {
        Queue<WebSocketFrame> pending = OFFLINE_MSGS.remove(userId);
        if (pending == null) {
            return;
        }
        WebSocketFrame frame;
        while ((frame = pending.poll()) != null) {
            // writeAndFlush consumes the retained reference (released after write or on failure).
            ctx.writeAndFlush(frame);
        }
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
        if (frame instanceof TextWebSocketFrame textFrame) {
            MessageEntity message = validateMessage(ctx.channel().attr(WS_USER_ID_KEY).get(), textFrame);
            saveMessage(message);
            sendOrStoreMessage(message.getToUserId(), textFrame);
        } else {
            ctx.close();
        }
    }

    /** Unregisters this channel when the connection closes, without touching a newer login. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        Channel channel = ctx.channel();
        String userId = channel.attr(WS_USER_ID_KEY).get();
        log.info("客户端断开连接,用户id:{}", userId);
        if (userId != null) {
            // Only removes the entry if it still points at THIS channel.
            userChannels.remove(userId, channel);
        }
        ctx.fireChannelInactive();
    }

    /**
     * Parses and validates a chat frame.
     *
     * @throws AppException "非法的消息格式!" if the JSON or a required field is missing/invalid,
     *                      "非法的消息类型!" if {@code type} is not {@code text} or {@code image}
     */
    private MessageEntity validateMessage(String userId, TextWebSocketFrame textFrame) {
        String toUserId;
        String content;
        String type;
        try {
            JsonObject json = JsonParser.parseString(textFrame.text()).getAsJsonObject();
            toUserId = json.get("toUserId").getAsString();
            content = json.get("content").getAsString();
            type = json.get("type").getAsString();
        } catch (RuntimeException e) {
            // Malformed JSON, non-object root, or missing/null fields.
            throw new AppException("非法的消息格式!");
        }
        // Checked outside the try so this specific error is no longer masked as a format error.
        if (!"text".equals(type) && !"image".equals(type)) {
            throw new AppException("非法的消息类型!");
        }
        return new MessageEntity(userId, toUserId, content, type);
    }

    /**
     * Forwards {@code message} to the recipient's active channel, or queues it for later delivery.
     * NOTE(review): the sender's raw JSON is forwarded as-is, so the recipient does not learn the
     * sender's id from the frame itself.
     */
    private void sendOrStoreMessage(String toUserId, WebSocketFrame message) {
        Channel targetChannel = userChannels.get(toUserId);
        if (targetChannel != null && targetChannel.isActive()) {
            targetChannel.writeAndFlush(message.retain());
            return;
        }
        // Recipient offline (or its registered channel is already dead): keep a reference.
        WebSocketFrame retained = message.retain();
        // Add inside compute so it cannot interleave with flushOfflineMessages' remove().
        OFFLINE_MSGS.compute(toUserId, (key, queue) -> {
            Queue<WebSocketFrame> target = (queue != null) ? queue : new LinkedList<>();
            target.add(retained);
            return target;
        });
    }

    /** Persists the message on the application thread pool, off the Netty event loop. */
    private void saveMessage(MessageEntity message) {
        threadPoolExecutor.execute(() -> messageRepository.saveMessage(message));
    }

    /** Reads the login {@code code} that HttpHandler stored on the channel. */
    private String getCodeFromRequest(ChannelHandlerContext ctx) {
        String code = ctx.channel().attr(WS_TOKEN_KEY).get();
        // The code parameter must be present and non-empty.
        if (code == null || code.isEmpty()) {
            throw new IllegalArgumentException("WebSocket token is missing or empty");
        }
        return code;
    }

    /**
     * Exchanges a mini-program login {@code code} for the user's openid via the WeChat login API.
     * NOTE(review): this HTTP call runs synchronously on the Netty event loop and blocks every
     * other channel on that loop while it waits.
     *
     * @throws WxException if the response contains no openid
     */
    private String getOpenid(String appid, String secret, String code) {
        Map<String, String> paramMap = new HashMap<>();
        paramMap.put("appid", appid);
        paramMap.put("secret", secret);
        paramMap.put("js_code", code);
        paramMap.put("grant_type", "authorization_code");
        String result = HttpClientUtil.doGet(WX_LOGIN, paramMap);
        JSONObject jsonObject = JSON.parseObject(result);
        String openid = jsonObject.getString("openid");
        if (StringUtils.isEmpty(openid)) {
            throw new WxException(jsonObject.getString("errcode"), jsonObject.getString("errmsg"));
        }
        return openid;
    }

    /** Logs the failure according to its type and closes the offending connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        if (cause instanceof AppException appCause) {
            log.error("AppException caught: {}", appCause.getInfo());
        } else if (cause instanceof WxException wxCause) {
            log.error("WxException caught: {}", wxCause.getMessage());
        } else {
            log.error("Exception caught: {}", cause.getMessage(), cause);
        }
        ctx.close();
    }
}
连接及消息格式:
bash
wss://127.0.0.1:21611/ws?code=xxxxxx
{
"toUserId": "1001",
"type": "text",
"content": "Hello World!"
}
消息的type目前只支持text和image两种:text表示content为纯文本,image表示content为Base64编码的图片数据。
本文功能实现较为简陋,demo内容仅供参考,可能有注释错误或设计不合理的地方