Spring Boot 整合 Netty 简单实现 1 对 1 聊天(微信小程序服务端)

本文功能实现较为简陋,demo内容仅供参考,有不足之处还请指正。

背景

一个小项目,用于微信小程序的服务端,需要实现小程序端可以和他人1对1聊天

实现功能

WebSocket、心跳检测、消息持久化、离线消息存储

Netty配置类

java 复制代码
/**
 * @author Aseubel
 */
@Component
@Slf4j
@EnableConfigurationProperties(NettyServerConfigProperties.class)
public class NettyServerConfig {

    private ChannelFuture serverChannelFuture;

    // 心跳间隔(秒)
    private static final int HEARTBEAT_INTERVAL = 15;
    // 读超时时间
    private static final int READ_TIMEOUT = HEARTBEAT_INTERVAL * 2;

    // 使用线程池管理
    private final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
    private final EventLoopGroup workerGroup = new NioEventLoopGroup();

    private final NettyServerConfigProperties properties;

    // 由于在后面的handler中有依赖注入类,所以要通过springboot的ApplicationContext来获取Bean实例
    @Autowired
    private ApplicationContext applicationContext;

    public NettyServerConfig(NettyServerConfigProperties properties) {
        this.properties = properties;
    }

    @PostConstruct
    public void startNettyServer() {
        // 使用独立线程启动Netty服务
        new Thread(() -> {
            try {
                ServerBootstrap bootstrap = new ServerBootstrap();
                bootstrap.group(bossGroup, workerGroup)
                        .channel(NioServerSocketChannel.class)
                        .childHandler(new ChannelInitializer<Channel>() {
                            @Override
                            protected void initChannel(Channel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();

                                SSLContext sslContext = SslUtil.createSSLContext("PKCS12",
                                        properties.getSslPath(), properties.getSslPassword());
                                // SSLEngine 此类允许使用ssl安全套接层协议进行安全通信
                                SSLEngine engine = sslContext.createSSLEngine();
                                engine.setUseClientMode(false);

                                pipeline.addLast(new SslHandler(engine)); // 设置SSL
                                pipeline.addLast(new HttpServerCodec());
                                pipeline.addLast(new HttpObjectAggregator(10 * 1024 * 1024));// 最大10MB
                                pipeline.addLast(new ChunkedWriteHandler());
                                pipeline.addLast(new HttpHandler());
                                // 只有text和binarytext的帧能经过WebSocketServerProtocolHandler,所以心跳检测这两个都得放前面
                                pipeline.addLast(new IdleStateHandler(READ_TIMEOUT, 0, 0, TimeUnit.SECONDS));
                                pipeline.addLast(new HeartbeatHandler());
                                pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true, 10 * 1024 * 1024));
                                pipeline.addLast(applicationContext.getBean(MessageHandler.class));
                                pipeline.addLast(new ChannelInboundHandlerAdapter() {
                                    @Override
                                    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                                        // 统一处理所有未被前面handler捕获的异常
                                        log.error("全局异常捕获: {}", cause.getMessage());
                                        ctx.channel().close();
                                    }
                                });
                            }
                        });

                serverChannelFuture = bootstrap.bind(properties.getPort()).sync();

                // 保持通道开放
                serverChannelFuture.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }).start();
    }

    @PreDestroy
    public void stopNettyServer() {
        // 优雅关闭
        if (serverChannelFuture != null) {
            serverChannelFuture.channel().close();
        }
        bossGroup.shutdownGracefully();
        workerGroup.shutdownGracefully();
    }

}

Handler

心跳检测

java 复制代码
/**
 * Sends periodic WebSocket pings on a connection and closes it when too many
 * pings go unanswered by the time a read-idle event arrives.
 *
 * <p>This handler does not install its own {@code IdleStateHandler}: it reacts to
 * {@code IdleStateEvent}s produced by an idle handler placed <em>before</em> it in
 * the pipeline. An idle handler appended to the pipeline tail would fire its events
 * past this handler and never be seen here.
 *
 * @author Aseubel
 */
public class HeartbeatHandler extends ChannelInboundHandlerAdapter {
    private static final int HEARTBEAT_INTERVAL = 15; // heartbeat interval (seconds)
    private static final int MAX_MISSED_HEARTBEATS = 2; // tolerated unanswered pings
    // Number of pings sent since the last pong, keyed by channel id
    private final Map<ChannelId, Integer> missedHeartbeats = new ConcurrentHashMap<>();

    /** Starts the ping cycle and lets later handlers see the activation event. */
    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        scheduleHeartbeat(ctx);
        ctx.fireChannelActive();
    }

    /**
     * Schedules the next ping. The task re-arms itself only while the channel is
     * active, so nothing keeps running on the event loop after the connection closes.
     */
    private void scheduleHeartbeat(ChannelHandlerContext ctx) {
        ctx.executor().schedule(() -> {
            if (!ctx.channel().isActive()) {
                return; // connection gone: stop the cycle
            }
            ctx.writeAndFlush(new PingWebSocketFrame(Unpooled.copiedBuffer("HEARTBEAT", CharsetUtil.UTF_8)));
            // Count the ping as missed until a pong resets the counter
            missedHeartbeats.merge(ctx.channel().id(), 1, Integer::sum);
            scheduleHeartbeat(ctx);
        }, HEARTBEAT_INTERVAL, TimeUnit.SECONDS);
    }

    /** Resets the missed-ping counter on a pong; every message is passed on unchanged. */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof PongWebSocketFrame) {
            missedHeartbeats.remove(ctx.channel().id());
        }
        ctx.fireChannelRead(msg);
    }

    /**
     * On an idle event, closes the connection if at least
     * {@code MAX_MISSED_HEARTBEATS} pings are unanswered. Every event that does not
     * lead to a close is forwarded to the next handler.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
        if (evt instanceof IdleStateEvent) {
            int missed = missedHeartbeats.getOrDefault(ctx.channel().id(), 0);
            if (missed >= MAX_MISSED_HEARTBEATS) {
                System.out.println("连接超时,关闭连接" + ctx.channel().id().asLongText());
                ctx.close();
                cleanOfflineResources(ctx.channel());
                return;
            }
        }
        ctx.fireUserEventTriggered(evt);
    }

    /** Drops the per-channel counter once the connection is gone. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        missedHeartbeats.remove(ctx.channel().id());
        ctx.fireChannelInactive();
    }

    /** Unregisters the channel from the user registry and forgets its counter. */
    private void cleanOfflineResources(Channel channel) {
        MessageHandler.removeUserChannel(channel);
        missedHeartbeats.remove(channel.id());
    }
}

处理http请求,建立连接

java 复制代码
/**
 * Handles the HTTP request that opens a WebSocket connection: extracts the
 * {@code code} query parameter and stores it as a channel attribute.
 *
 * @author Aseubel
 */
public class HttpHandler extends ChannelInboundHandlerAdapter {

    public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");

    /**
     * For a {@code FullHttpRequest}, records the {@code code} parameter on the channel
     * and then forwards the request <em>twice</em> on purpose: once with its original
     * URI, and once more after the URI has been rewritten to {@code /ws}. Any other
     * message is forwarded once, untouched.
     *
     * @throws AppException if the request carries no usable {@code code} parameter;
     *                      the request buffer is released before throwing
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            FullHttpRequest request = (FullHttpRequest) msg;
            String code = extractCode(request);
            if (code == null) {
                // The request goes no further, so this handler must release it.
                request.release();
                throw new AppException("非法的websocket连接请求");
            }
            ctx.channel().attr(WS_TOKEN_KEY).set(code);

            // First pass: the request is forwarded with its original URI
            // (per the original author, so that it reaches MessageHandler).
            ctx.fireChannelRead(request);
            // Second pass (below): rewrite the URI to the bare WebSocket path so the
            // WebSocket handler accepts it; otherwise the connection cannot be established.
            request.setUri("/ws");
        }
        super.channelRead(ctx, msg);
    }

    /**
     * Returns the first {@code code} query parameter of the request URI, or
     * {@code null} when it is absent or the query string cannot be decoded.
     */
    private static String extractCode(FullHttpRequest request) {
        try {
            // List<String>; null when the parameter is not present
            var codes = new QueryStringDecoder(request.uri()).parameters().get("code");
            return (codes == null || codes.isEmpty()) ? null : codes.get(0);
        } catch (IllegalArgumentException malformedQuery) {
            return null;
        }
    }

}

消息处理

java 复制代码
/**
 * Handles WebSocket traffic for 1-to-1 chat: registers a user's channel when the
 * connection request arrives, relays text frames to the recipient, and queues
 * frames for recipients that cannot be reached.
 *
 * <p>The handler is {@code @Sharable}; the user registry and the offline queues are
 * static and shared by all connections.
 *
 * @author Aseubel
 */
@Component
@Slf4j
@Sharable
public class MessageHandler extends SimpleChannelInboundHandler<WebSocketFrame> {

    public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");
    public static final AttributeKey<String> WS_USER_ID_KEY = AttributeKey.valueOf("userId");

    // Delay before flushing queued messages, giving the handshake time to finish
    private static final long OFFLINE_DELIVERY_DELAY_MS = 50;

    // Frames waiting for an unreachable recipient, keyed by recipient user id.
    // Each queue is mutated only inside compute() and drained only after remove().
    private static final Map<String, Queue<WebSocketFrame>> OFFLINE_MSGS = new ConcurrentHashMap<>();

    // Currently registered channel per user id
    private static final Map<String, Channel> userChannels = new ConcurrentHashMap<>();

    @Autowired
    private ThreadPoolTaskExecutor threadPoolExecutor;

    @Resource
    private IMessageRepository messageRepository;

    /** Removes the given channel from the user registry, whichever user it belongs to. */
    public static void removeUserChannel(Channel channel) {
        userChannels.values().remove(channel);
    }

    /** Returns whether a channel is currently registered for the given user id. */
    public static boolean containsUser(String userId) {
        return userChannels.containsKey(userId);
    }

    /**
     * Dispatches inbound messages: an HTTP request registers the user, a text frame
     * is processed as a chat message, and anything else is passed to the next handler.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object req) throws Exception {
        if (req instanceof FullHttpRequest) {
            // Not released here: HttpHandler forwards the same request object a second
            // time for the WebSocket handshake.
            registerUser(ctx);
        } else if (req instanceof TextWebSocketFrame textFrame) {
            try {
                this.channelRead0(ctx, textFrame);
            } finally {
                // This override bypasses SimpleChannelInboundHandler's auto-release,
                // so the inbound reference has to be dropped explicitly.
                textFrame.release();
            }
        } else {
            ctx.fireChannelRead(req);
        }
    }

    /**
     * Resolves the user id from the stored code, registers the channel under it and
     * schedules delivery of any queued messages.
     */
    private void registerUser(ChannelHandlerContext ctx) {
        String code = getCodeFromRequest(ctx);
        // NOTE(review): getOpenid performs an HTTP call and runs on the channel's event loop.
        String userId = getOpenid(APPID, SECRET, code);

        userChannels.put(userId, ctx.channel());
        ctx.channel().attr(WS_USER_ID_KEY).set(userId);
        log.info("客户端连接成功,用户id:{}", userId);

        // The connection is still being established at this point, so queued messages
        // are sent a little later on the channel's own event loop.
        ctx.executor().schedule(() -> deliverOfflineMessages(ctx, userId),
                OFFLINE_DELIVERY_DELAY_MS, java.util.concurrent.TimeUnit.MILLISECONDS);
    }

    /** Sends and clears the user's queued messages; leaves them queued if the channel is gone. */
    private void deliverOfflineMessages(ChannelHandlerContext ctx, String userId) {
        if (!ctx.channel().isActive()) {
            return;
        }
        // remove() hands this thread exclusive ownership of the queue.
        Queue<WebSocketFrame> pending = OFFLINE_MSGS.remove(userId);
        if (pending == null) {
            return;
        }
        for (WebSocketFrame queued : pending) {
            ctx.write(queued);
        }
        ctx.flush();
    }

    /**
     * Validates, persists and relays a text frame. Any other frame type closes the
     * connection.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
        if (!(frame instanceof TextWebSocketFrame)) {
            ctx.close();
            return;
        }
        String senderId = ctx.channel().attr(WS_USER_ID_KEY).get();
        MessageEntity message = validateMessage(senderId, (TextWebSocketFrame) frame);
        saveMessage(message);
        sendOrStoreMessage(message.getToUserId(), frame);
    }

    /** Unregisters the channel when the connection drops. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        Channel channel = ctx.channel();
        log.info("客户端断开连接,用户id:{}", channel.attr(WS_USER_ID_KEY).get());
        removeUserChannel(channel);
        ctx.fireChannelInactive();
    }

    /**
     * Parses the frame's JSON payload into a {@code MessageEntity}.
     *
     * @throws AppException if the payload is not a JSON object with string fields
     *                      {@code toUserId}, {@code content} and {@code type}, or if
     *                      {@code type} is neither {@code text} nor {@code image}
     */
    private MessageEntity validateMessage(String userId, TextWebSocketFrame textFrame) {
        String toUserId;
        String content;
        String type;
        try {
            JsonObject json = JsonParser.parseString(textFrame.text()).getAsJsonObject();
            toUserId = json.get("toUserId").getAsString();
            content = json.get("content").getAsString();
            type = json.get("type").getAsString();
        } catch (RuntimeException e) {
            log.debug("Rejected malformed message payload", e);
            throw new AppException("非法的消息格式!");
        }

        // Checked outside the try block so this error is not rewritten as a format error.
        if (!"text".equals(type) && !"image".equals(type)) {
            throw new AppException("非法的消息类型!");
        }
        return new MessageEntity(userId, toUserId, content, type);
    }

    /**
     * Writes the frame to the recipient's channel when it is registered and active;
     * otherwise queues it for later delivery. Either way one extra reference to the
     * frame is retained for the consumer.
     */
    private void sendOrStoreMessage(String toUserId, WebSocketFrame message) {
        Channel targetChannel = userChannels.get(toUserId);
        if (targetChannel != null && targetChannel.isActive()) {
            targetChannel.writeAndFlush(message.retain());
            return;
        }

        WebSocketFrame retained = message.retain();
        // compute() makes "create queue if needed + append" atomic with respect to
        // the remove() performed on delivery, so no frame lands in a discarded queue.
        OFFLINE_MSGS.compute(toUserId, (id, existing) -> {
            Queue<WebSocketFrame> queue = existing != null ? existing : new LinkedList<>();
            queue.add(retained);
            return queue;
        });
    }

    /** Persists the message on the task executor. */
    private void saveMessage(MessageEntity message) {
        threadPoolExecutor.execute(() -> messageRepository.saveMessage(message));
    }

    /**
     * Reads the code stored on the channel.
     *
     * @throws IllegalArgumentException if no code is stored or it is empty
     */
    private String getCodeFromRequest(ChannelHandlerContext ctx) {
        String code = ctx.channel().attr(WS_TOKEN_KEY).get();
        if (code == null || code.isEmpty()) {
            throw new IllegalArgumentException("WebSocket token  is missing or empty");
        }
        return code;
    }

    /**
     * Exchanges a login code for an openid via a GET request to {@code WX_LOGIN}.
     *
     * @throws WxException carrying the response's {@code errcode}/{@code errmsg} when
     *                     the response contains no openid
     */
    private String getOpenid(String appid, String secret, String code) {
        Map<String, String> paramMap = new HashMap<>();
        paramMap.put("appid", appid);
        paramMap.put("secret", secret);
        paramMap.put("js_code", code);
        paramMap.put("grant_type", "authorization_code");
        String result = HttpClientUtil.doGet(WX_LOGIN, paramMap);

        JSONObject jsonObject = JSON.parseObject(result);
        String openid = jsonObject.getString("openid");
        if (StringUtils.isEmpty(openid)) {
            throw new WxException(jsonObject.getString("errcode"), jsonObject.getString("errmsg"));
        }

        return openid;
    }

    /** Logs the failure according to its type and closes the affected connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        if (cause instanceof AppException appCause) {
            log.error("AppException caught: {}", appCause.getInfo());
        } else if (cause instanceof WxException wxCause) {
            log.error("WxException caught: {}", wxCause.getMessage());
        } else {
            log.error("Exception caught: {}", cause.getMessage(), cause);
        }
        ctx.close();
    }

}

连接及消息格式:

bash 复制代码
wss://127.0.0.1:21611/ws?code=xxxxxx

{
    "toUserId": "1001",
    "type": "text",
    "content": "Hello World!"
}

type 只允许 text 和 image 两种取值:type 为 text 时,content 是文本内容;type 为 image 时,content 是图片的 Base64 编码。

本文功能实现较为简陋,demo内容仅供参考,可能有注释错误或设计不合理的地方

相关推荐
一只叫煤球的猫25 分钟前
写代码很6,面试秒变菜鸟?不卖课,面试官视角走心探讨
前端·后端·面试
bobz9651 小时前
tcp/ip 中的多路复用
后端
bobz9651 小时前
tls ingress 简单记录
后端
皮皮林5512 小时前
IDEA 源码阅读利器,你居然还不会?
java·intellij idea
你的人类朋友2 小时前
什么是OpenSSL
后端·安全·程序员
bobz9652 小时前
mcp 直接操作浏览器
后端
前端小张同学5 小时前
服务器部署 gitlab 占用空间太大怎么办,优化思路。
后端
databook5 小时前
Manim实现闪光轨迹特效
后端·python·动效
武子康6 小时前
大数据-98 Spark 从 DStream 到 Structured Streaming:Spark 实时计算的演进
大数据·后端·spark
该用户已不存在6 小时前
6个值得收藏的.NET ORM 框架
前端·后端·.net