Spring Boot 整合 Netty 简单实现一对一聊天(微信小程序服务端)

本文功能实现较为简陋,demo内容仅供参考,有不足之处还请指正。

背景

一个小项目:作为微信小程序的服务端,需要支持小程序用户之间进行一对一聊天。

实现功能

WebSocket 通信、心跳检测、消息持久化、离线消息存储(离线消息暂存于服务端内存中)

Netty配置类

java 复制代码
/**
 * @author Aseubel
 */
@Component
@Slf4j
@EnableConfigurationProperties(NettyServerConfigProperties.class)
public class NettyServerConfig {

    private ChannelFuture serverChannelFuture;

    // 心跳间隔(秒)
    private static final int HEARTBEAT_INTERVAL = 15;
    // 读超时时间
    private static final int READ_TIMEOUT = HEARTBEAT_INTERVAL * 2;

    // 使用线程池管理
    private final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
    private final EventLoopGroup workerGroup = new NioEventLoopGroup();

    private final NettyServerConfigProperties properties;

    // 由于在后面的handler中有依赖注入类,所以要通过springboot的ApplicationContext来获取Bean实例
    @Autowired
    private ApplicationContext applicationContext;

    public NettyServerConfig(NettyServerConfigProperties properties) {
        this.properties = properties;
    }

    @PostConstruct
    public void startNettyServer() {
        // 使用独立线程启动Netty服务
        new Thread(() -> {
            try {
                ServerBootstrap bootstrap = new ServerBootstrap();
                bootstrap.group(bossGroup, workerGroup)
                        .channel(NioServerSocketChannel.class)
                        .childHandler(new ChannelInitializer<Channel>() {
                            @Override
                            protected void initChannel(Channel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();

                                SSLContext sslContext = SslUtil.createSSLContext("PKCS12",
                                        properties.getSslPath(), properties.getSslPassword());
                                // SSLEngine 此类允许使用ssl安全套接层协议进行安全通信
                                SSLEngine engine = sslContext.createSSLEngine();
                                engine.setUseClientMode(false);

                                pipeline.addLast(new SslHandler(engine)); // 设置SSL
                                pipeline.addLast(new HttpServerCodec());
                                pipeline.addLast(new HttpObjectAggregator(10 * 1024 * 1024));// 最大10MB
                                pipeline.addLast(new ChunkedWriteHandler());
                                pipeline.addLast(new HttpHandler());
                                // 只有text和binarytext的帧能经过WebSocketServerProtocolHandler,所以心跳检测这两个都得放前面
                                pipeline.addLast(new IdleStateHandler(READ_TIMEOUT, 0, 0, TimeUnit.SECONDS));
                                pipeline.addLast(new HeartbeatHandler());
                                pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true, 10 * 1024 * 1024));
                                pipeline.addLast(applicationContext.getBean(MessageHandler.class));
                                pipeline.addLast(new ChannelInboundHandlerAdapter() {
                                    @Override
                                    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                                        // 统一处理所有未被前面handler捕获的异常
                                        log.error("全局异常捕获: {}", cause.getMessage());
                                        ctx.channel().close();
                                    }
                                });
                            }
                        });

                serverChannelFuture = bootstrap.bind(properties.getPort()).sync();

                // 保持通道开放
                serverChannelFuture.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }).start();
    }

    @PreDestroy
    public void stopNettyServer() {
        // 优雅关闭
        if (serverChannelFuture != null) {
            serverChannelFuture.channel().close();
        }
        bossGroup.shutdownGracefully();
        workerGroup.shutdownGracefully();
    }

}

Handler

心跳检测

java 复制代码
/**
 * Server-side WebSocket heartbeat for a single connection.
 *
 * <p>Every {@code HEARTBEAT_INTERVAL} seconds a {@link PingWebSocketFrame} is sent and a per-channel
 * "missed" counter is incremented; any {@link PongWebSocketFrame} resets it. When an
 * {@link IdleStateEvent} arrives from an upstream {@link IdleStateHandler} and at least
 * {@code MAX_MISSED_HEARTBEATS} pings are unanswered, the connection is closed.
 *
 * <p>Holds per-connection state (the scheduled ping task), so a new instance must be added to each
 * channel's pipeline, after an {@link IdleStateHandler}.
 *
 * @author Aseubel
 */
public class HeartbeatHandler extends ChannelInboundHandlerAdapter {
    private static final int HEARTBEAT_INTERVAL = 15; // heartbeat interval (seconds)
    private static final int MAX_MISSED_HEARTBEATS = 2; // unanswered pings tolerated before closing
    // Number of unanswered pings per connection
    private final Map<ChannelId, Integer> missedHeartbeats = new ConcurrentHashMap<>();
    // Periodic ping task for this connection; cancelled once the channel goes inactive
    private java.util.concurrent.ScheduledFuture<?> heartbeatTask;

    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        scheduleHeartbeat(ctx);
        // Propagate so downstream handlers also observe activation
        ctx.fireChannelActive();
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        // Stop pinging a dead connection; otherwise the task would stay on the event loop forever
        cancelHeartbeat();
        missedHeartbeats.remove(ctx.channel().id());
        ctx.fireChannelInactive();
    }

    /** Schedules a ping every HEARTBEAT_INTERVAL seconds and counts it as missed until a pong arrives. */
    private void scheduleHeartbeat(ChannelHandlerContext ctx) {
        heartbeatTask = ctx.executor().scheduleAtFixedRate(() -> {
            if (ctx.channel().isActive()) {
                ctx.writeAndFlush(new PingWebSocketFrame(Unpooled.copiedBuffer("HEARTBEAT", CharsetUtil.UTF_8)));
                missedHeartbeats.merge(ctx.channel().id(), 1, Integer::sum);
            }
        }, HEARTBEAT_INTERVAL, HEARTBEAT_INTERVAL, TimeUnit.SECONDS);
    }

    /** Cancels the periodic ping task, if one was scheduled. */
    private void cancelHeartbeat() {
        if (heartbeatTask != null) {
            heartbeatTask.cancel(false);
            heartbeatTask = null;
        }
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof PongWebSocketFrame) {
            // A pong proves the peer is alive: reset the missed counter
            missedHeartbeats.remove(ctx.channel().id());
        }
        // Every message, pong included, continues down the pipeline
        ctx.fireChannelRead(msg);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
        if (evt instanceof IdleStateEvent
                && missedHeartbeats.getOrDefault(ctx.channel().id(), 0) >= MAX_MISSED_HEARTBEATS) {
            // Too many unanswered pings: close the connection
            System.out.println("连接超时,关闭连接" + ctx.channel().id().asLongText());
            ctx.close();
            cleanOfflineResources(ctx.channel());
            return;
        }
        // Do not swallow events we do not consume (SSL events, non-fatal idle events, ...)
        ctx.fireUserEventTriggered(evt);
    }

    /** Unregisters the channel from the chat registry and drops its heartbeat state. */
    private void cleanOfflineResources(Channel channel) {
        cancelHeartbeat();
        MessageHandler.removeUserChannel(channel);
        missedHeartbeats.remove(channel.id());
    }
}

处理 HTTP 升级请求,建立 WebSocket 连接

java 复制代码
/**
 * Handles the WebSocket upgrade request: stores its {@code code} query parameter on the channel
 * under {@link #WS_TOKEN_KEY} and forwards the request so the connection can be established.
 *
 * <p>The upgrade request is dispatched downstream twice. The first pass keeps the original uri so
 * MessageHandler can authenticate the user. The second pass rewrites the uri to {@code /ws}; per the
 * original author, the WebSocket handler otherwise does not complete the handshake (presumably it
 * compares the path exactly, query string included). All other messages pass through unchanged.
 *
 * @author Aseubel
 * @date 2025-02-21 15:34
 */
public class HttpHandler extends ChannelInboundHandlerAdapter {

    public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest request) {
            String code = extractCode(request);
            if (code == null) {
                // The request will not be forwarded, so we own it: release before rejecting
                request.release();
                throw new AppException("非法的websocket连接请求");
            }
            ctx.channel().attr(WS_TOKEN_KEY).set(code);
            // First pass: forward the FullHttpRequest to MessageHandler
            ctx.fireChannelRead(request);
            // Rewrite the uri so the second pass (below) reaches the websocket handshake
            request.setUri("/ws");
        }
        // Every message (including the rewritten request) goes on to the next handler
        super.channelRead(ctx, msg);
    }

    /** Returns the first {@code code} query parameter of the request, or {@code null} if absent or empty. */
    private static String extractCode(FullHttpRequest request) {
        var values = new QueryStringDecoder(request.uri()).parameters().get("code");
        if (values == null || values.isEmpty()) {
            return null;
        }
        String code = values.get(0);
        return (code == null || code.isEmpty()) ? null : code;
    }

}

消息处理

java 复制代码
/**
 * Handles WebSocket chat traffic.
 *
 * <ul>
 *   <li>On the upgrade {@link FullHttpRequest} forwarded by HttpHandler, it resolves the caller's
 *       WeChat openid from the {@code code} channel attribute and registers the channel under it.</li>
 *   <li>For each {@link TextWebSocketFrame}, it validates and persists the message, then forwards
 *       the frame to the recipient or queues it in memory while the recipient is offline.</li>
 * </ul>
 *
 * <p>One instance is shared by all channels ({@code @Sharable}); per-user state lives in static
 * concurrent maps. Offline messages are kept in memory only, so they are lost on restart and are
 * unbounded.
 *
 * @author Aseubel
 * @date 2025-02-21 15:33
 */
@Component
@Slf4j
@Sharable
public class MessageHandler extends SimpleChannelInboundHandler<WebSocketFrame> {

    public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");
    public static final AttributeKey<String> WS_USER_ID_KEY = AttributeKey.valueOf("userId");

    // Delay before flushing offline messages, so the handshake finishes first
    private static final long OFFLINE_FLUSH_DELAY_MS = 50;

    // userId -> frames waiting for that user. A queue is only mutated inside ConcurrentHashMap.compute
    // and is taken out with an atomic remove, so a plain LinkedList is safe here.
    private static final Map<String, Queue<WebSocketFrame>> OFFLINE_MSGS = new ConcurrentHashMap<>();

    // userId -> that user's currently registered channel
    private static final Map<String, Channel> userChannels = new ConcurrentHashMap<>();

    @Autowired
    private ThreadPoolTaskExecutor threadPoolExecutor;

    @Resource
    private IMessageRepository messageRepository;

    /** Unregisters {@code channel} from whichever user it is mapped to, if any. */
    public static void removeUserChannel(Channel channel) {
        userChannels.values().remove(channel);
    }

    /** Returns whether {@code userId} currently has a registered channel. */
    public static boolean containsUser(String userId) {
        return userChannels.containsKey(userId);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object req) throws Exception {
        if (req instanceof FullHttpRequest) {
            // Deliberately not released: HttpHandler re-dispatches this same request (uri rewritten
            // to /ws) so the WebSocket handler can complete the handshake.
            registerUser(ctx);
        } else if (req instanceof TextWebSocketFrame) {
            // Go through SimpleChannelInboundHandler so the frame is released after channelRead0
            super.channelRead(ctx, req);
        } else {
            ctx.fireChannelRead(req);
        }
    }

    /** Resolves the caller's openid, maps it to this channel and schedules offline-message delivery. */
    private void registerUser(ChannelHandlerContext ctx) {
        String code = getCodeFromRequest(ctx); // code stored on the channel by HttpHandler
        // NOTE(review): getOpenid performs a blocking HTTP call on the event loop; consider offloading it.
        String userId = getOpenid(APPID, SECRET, code); // exchange code for openid

        userChannels.put(userId, ctx.channel());
        ctx.channel().attr(WS_USER_ID_KEY).set(userId);

        log.info("客户端连接成功,用户id:{}", userId);
        // The handshake is still in progress here, so delay delivery slightly; run it on the
        // channel's own event loop instead of a throwaway thread.
        ctx.executor().schedule(() -> flushOfflineMessages(ctx, userId),
                OFFLINE_FLUSH_DELAY_MS, java.util.concurrent.TimeUnit.MILLISECONDS);
    }

    /** Sends and dequeues every message stored for {@code userId}; keeps them if the channel is already gone. */
    private void flushOfflineMessages(ChannelHandlerContext ctx, String userId) {
        if (!ctx.channel().isActive()) {
            return; // keep them queued for the next connection
        }
        Queue<WebSocketFrame> pending = OFFLINE_MSGS.remove(userId);
        if (pending != null) {
            pending.forEach(ctx::writeAndFlush);
        }
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
        if (frame instanceof TextWebSocketFrame textFrame) {
            MessageEntity message = validateMessage(ctx.channel().attr(WS_USER_ID_KEY).get(), textFrame);
            saveMessage(message);
            sendOrStoreMessage(message.getToUserId(), frame);
        } else {
            ctx.close();
        }
    }

    // Handle disconnects
    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        Channel channel = ctx.channel();
        String userId = channel.attr(WS_USER_ID_KEY).get();
        log.info("客户端断开连接,用户id:{}", userId);
        if (userId != null) {
            // Only drop the mapping if it still points at this channel; a reconnect may have replaced it
            userChannels.remove(userId, channel);
        } else {
            removeUserChannel(channel);
        }
        ctx.fireChannelInactive();
    }

    /**
     * Parses a chat message of the form {"toUserId": ..., "content": ..., "type": "text"|"image"}.
     *
     * @throws AppException "非法的消息格式!" if the JSON is malformed or a field is missing,
     *                      "非法的消息类型!" if {@code type} is neither text nor image
     */
    private MessageEntity validateMessage(String userId, TextWebSocketFrame textFrame) {
        String toUserId;
        String content;
        String type;
        try {
            JsonObject json = JsonParser.parseString(textFrame.text()).getAsJsonObject();
            toUserId = json.get("toUserId").getAsString();
            content = json.get("content").getAsString();
            type = json.get("type").getAsString();
        } catch (Exception e) {
            throw new AppException("非法的消息格式!");
        }
        // Checked outside the try so this specific error is not rewritten as a format error
        if (!"text".equals(type) && !"image".equals(type)) {
            throw new AppException("非法的消息类型!");
        }
        return new MessageEntity(userId, toUserId, content, type);
    }

    /**
     * Forwards {@code message} to {@code toUserId} when its channel is active, otherwise queues it.
     * A retained duplicate is used in both cases, so this method never consumes the caller's reference.
     */
    private void sendOrStoreMessage(String toUserId, WebSocketFrame message) {
        Channel targetChannel = userChannels.get(toUserId);
        if (targetChannel != null && targetChannel.isActive()) {
            targetChannel.writeAndFlush(message.retainedDuplicate());
        } else {
            // Recipient offline (or its channel already died): keep a retained copy for later
            WebSocketFrame copy = message.retainedDuplicate();
            OFFLINE_MSGS.compute(toUserId, (id, queue) -> {
                Queue<WebSocketFrame> q = (queue != null) ? queue : new LinkedList<>();
                q.add(copy);
                return q;
            });
        }
    }

    /** Persists the message asynchronously on the task executor. */
    private void saveMessage(MessageEntity message) {
        threadPoolExecutor.execute(() -> {
            messageRepository.saveMessage(message);
        });
    }

    private String getCodeFromRequest(ChannelHandlerContext ctx) {
        String code = ctx.channel().attr(WS_TOKEN_KEY).get();
        // Check that the code parameter is present and non-empty
        if (code == null || code.isEmpty()) {
            throw new IllegalArgumentException("WebSocket token  is missing or empty");
        }
        return code;
    }

    private String getOpenid(String appid, String secret, String code) {
        Map<String, String> paramMap = new HashMap<>();
        paramMap.put("appid", appid);
        paramMap.put("secret", secret);
        paramMap.put("js_code", code);
        paramMap.put("grant_type", "authorization_code");
        String result = HttpClientUtil.doGet(WX_LOGIN, paramMap);

        // Parse the response
        JSONObject jsonObject = JSON.parseObject(result);
        String openid = jsonObject.getString("openid");
        // No openid means the code exchange failed
        if (StringUtils.isEmpty(openid)) {
            throw new WxException(jsonObject.getString("errcode"), jsonObject.getString("errmsg"));
        }

        return openid;
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        if (cause instanceof AppException appCause) {
            log.error("AppException caught: {}", appCause.getInfo());
        } else if (cause instanceof WxException wxCause) {
            log.error("WxException caught: {}", wxCause.getMessage());
        } else {
            log.error("Exception caught: {}", cause.getMessage(), cause);
        }
        ctx.close(); // close the connection that raised the exception
    }

}

连接及消息格式:

bash 复制代码
wss://127.0.0.1:21611/ws?code=xxxxxx

{
    "toUserId": "1001",
    "type": "text",
    "content": "Hello World!"
}

约定 type 只有 text 和 image 两种:type 为 text 时,content 是纯文本;type 为 image 时,content 是 Base64 编码的图片数据。

本文功能实现较为简陋,demo内容仅供参考,可能有注释错误或设计不合理的地方

相关推荐
Asthenia041210 分钟前
JDBC是什么?/Driver的意义/预编译和普通statement/CallableStatement
后端
大鸡腿同学19 分钟前
人生加减法之道
后端
狗-sin狗25 分钟前
MyBatisSystemException:Parameter ‘item‘ not found.
java·mybatis
一个public的class29 分钟前
MyBatis-Plus的加载和初始化
java·mybatis·springboot
m0_7482340830 分钟前
SpringBoot(整合MyBatis + MyBatis-Plus + MyBatisX插件使用)
spring boot·tomcat·mybatis
动亦定30 分钟前
如何理解java中Stream流?
java
Lojarro31 分钟前
MyBatis-Plus(SpringBoot版)学习第一讲:简介&入门案例
spring boot·学习·mybatis
热爱技术的小曹38 分钟前
Spring6:10 数据校验-Validation
java·开发语言·spring
Asthenia041240 分钟前
深入剖析 Netty 的 ByteBuf:设计思路与 ByteBuffer 的对比
后端
Asthenia04121 小时前
Netty:EventLoop、Channel与ChannelHandller
后端