Netty封装Websocket并实现动态路由

引言

关于Netty和Websocket的介绍我就不多讲了,网上一搜一大片。现如今AI的趋势发展很热门,长连接对话也是会经常接触到的,使用Websocket实现长连接,那么很多人为了快速开发快速集成就会使用spring-boot-starter-websocket 依赖快速实现,但是注意该实现是基于tomcat的,有性能瓶颈的,那么就又有人说了那我使用spring-webflux (底层容器就是netty),但是用的人很少,那能不能单独一个项目来处理长连接呢?
那肯定有的,基于netty自己实现

怎么使用?

其实怎么说呢,netty实现的websocket网上也是一大把,但是终究是个demo,网上也是很多人问:怎么实现动态多路由,像mvc一样多个路由呢?用过spring-boot-starter-websocket都知道,搭配onOpen、onMesssage、onClose注解轻松使用,使用@ServerEndpoint实现多路由,那么netty怎么实现呢(netty本身是不支持的,都是需要自己去实现)?

我们要明白netty的定位,高性能、异步事件驱动的网络应用框架​​,主要用于快速开发可维护的高性能协议服务器和客户端,提供底层网络 I/O 的抽象,全是抽象,需要自己去自定义实现。

正题开始

废话就不多说了,直接上代码,如果文章对你有帮助,请3连!!!

maven依赖,只用netty和spring,不需要web容器:

java 复制代码
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.65.Final</version>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>com.alibaba.cloud</groupId>
            <artifactId>spring-cloud-starter-alibaba-nacos-config</artifactId>
        </dependency>
        <dependency>
            <groupId>com.alibaba.cloud</groupId>
            <artifactId>spring-cloud-starter-alibaba-nacos-discovery</artifactId>
        </dependency>

1、核心接口WebSocketHandler

java 复制代码
//Websocket处理器接口,入口WebSocketGlobalIntercept将会封装连接处理逻辑和事件通知逻辑,专注实现业务
public interface WebSocketHandler {

    /**
     * 当首次握手连接成功后(升级为websocket时)将会触发,可用于连接合法性处理
     *
     * @param session 会话对象
     */
    void onOpen(WebSocketSession session);

    /**
     * 当触发websocket消息帧时,将会通知该方法
     *
     * @param session 会话对象
     * @param message 消息对象:文本、二进制数据等
     */
    void onMessage(WebSocketSession session, WebSocketFrame message);

    /**
     * 当连接关闭时将通知该方法,需要释放资源并且清理session
     *
     * @param session 会话对象
     */
    void onClose(WebSocketSession session);

    /**
     * 当连接过程中、通信过程中出现异常将通知该方法
     *
     * @param session 会话对象
     * @param error   异常信息
     */
    void onError(WebSocketSession session, Throwable error);
}

2、会话Session类

java 复制代码
public class WebSocketSession {

    /**
     * netty channelContext 对象,注意此对象不可序列化
     */
    private ChannelHandlerContext channelContext;

    /**
     * 请求路由路径
     */
    private String path;

    /**
     * 扩展参数map,如需自定义携带参数时即可用于存入
     */
    private Map<String, Object> attributes = new ConcurrentHashMap<>();

    /**
     * 只提供一个有参构造方法,channelContext和 path不能为空
     *
     * @param channelContext channel上下文
     * @param path           请求路径
     * @param attributes     扩展参数map
     */
    public WebSocketSession(ChannelHandlerContext channelContext, String path, Map<String, Object> attributes) {
        this.channelContext = channelContext;
        this.path = path;
        this.attributes = attributes;
    }

    /**
     * 提供一个静态方法获取对象
     *
     * @param channelContext channel上下文
     * @param path           请求路径
     * @param attributes     扩展参数map
     * @return
     */
    public static WebSocketSession of(ChannelHandlerContext channelContext, String path, Map<String, Object> attributes) {
        return new WebSocketSession(channelContext, path, attributes);
    }

    /**
     * 发送TextWebSocketFrame消息
     *
     * @param text 消息文本
     */
    public void sendText(String text) {
        this.channelContext.writeAndFlush(new TextWebSocketFrame(text));
    }

    /**
     * 发送BinaryWebSocketFrame 二进制消息
     *
     * @param data
     */
    public void sendBinary(ByteBuf data) {
        this.channelContext.writeAndFlush(new BinaryWebSocketFrame(data));
    }

    /**
     * 处理心跳检测ping消息,响应pong
     *
     * @param frame pong消息帧
     */
    public void sendPong(PongWebSocketFrame frame) {
        this.channelContext.writeAndFlush(frame);
    }

    /**
     * 强制关闭连接
     */
    public void close() {
        this.channelContext.close();
    }

    /**
     * 优雅关闭连接,其实就是发送了关闭协议帧
     *
     * @param frame 关闭帧
     */
    public void close(CloseWebSocketFrame frame) {
        this.channelContext.writeAndFlush(frame.retain()).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * 优雅关闭连接,其实就是发送了关闭协议帧
     *
     * @param reason 关闭原因
     */
    public void close(String reason) {
        CloseWebSocketFrame frame = new CloseWebSocketFrame(
                WebSocketCloseStatus.SERVICE_RESTART,
                reason
        );
        close(frame);
    }

    /**
     * set自定义扩展值
     *
     * @param name  名称
     * @param value 值
     */
    public void setAttribute(String name, Object value) {
        this.attributes.put(name, value);
    }

3、路由注册管理器

java 复制代码
public class WebSocketRouter {

    private static final Map<String, WebSocketHandler> HANDLES_MAP = new ConcurrentHashMap<>();

    /**
     * 添加路由
     *
     * @param path    请求路径
     * @param handler handler对象
     */
    public static void addHandler(String path, WebSocketHandler handler) {
        HANDLES_MAP.put(path, handler);
    }

    /**
     * 获取路由
     *
     * @param path 请求路径
     * @return
     */
    public static WebSocketHandler getHandler(String path) {
        return HANDLES_MAP.get(path);
    }

    /**
     * 判断路由是否存在
     *
     * @param path 请求路径
     * @return
     */
    public static boolean containsPath(String path) {
        return HANDLES_MAP.containsKey(path);
    }
}

4、路由注解

java 复制代码
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface WSHandler {

    /**
     * 不能为空,路由地址
     *
     * @return
     */
    String url();

    /**
     * 文本描述
     *
     * @return
     */
    String desc() default "";
}
java 复制代码
@Configuration
@Slf4j
public class WSHandlerAnnotationImpl {

    @Resource
    private ApplicationContext applicationContext;

    /**
     * 在 Spring 启动后扫描所有带 @WSHandler 的类并初始化
     */
    @PostConstruct
    public void initHandles() {
        // 获取所有带有 @WSHandler 注解的 Bean
        Map<String, Object> beans = applicationContext.getBeansWithAnnotation(WSHandler.class);
        beans.forEach((beanName, beanInstance) -> {
            Class<?> aClass = beanInstance.getClass();
            WSHandler annotation = AnnotationUtils.findAnnotation(aClass, WSHandler.class);
            if (annotation != null) {
                String key = annotation.url();
                WebSocketRouter.addHandler(key, (WebSocketHandler) beanInstance);
                log.info("[Register WS handle] key: {}, handle name: {} register success.", key, aClass.getName());
            }
        });
    }
}

5、netty启动类

java 复制代码
@Component
@Slf4j
public class NettyStartServer {

    /**
     * boss NioEventLoopGroup 处理连接事件
     */
    private NioEventLoopGroup bossGroup;

    /**
     * work NioEventLoopGroup 处理I/O事件
     */
    private NioEventLoopGroup workerGroup;

    /**
     * 引导类
     */
    private ServerBootstrap serverBootstrap;

    /**
     *异步I/O操作的结果
     */
    private ChannelFuture future;

    /**
     * Websocket 消息处理器
     */
    @Resource
    private WebSocketGlobalIntercept webSocketGlobalIntercept;

    @Value("${netty.port}")
    private int port;

    @PostConstruct
    public void start() throws Exception {
        //boss线程,处理连接事件,通常是1个线程
        bossGroup = new NioEventLoopGroup(1);
        //工作线程,处理io事件,默认是机器的cpu*2,但是docker部署需要指定,以免影响性能
        workerGroup = new NioEventLoopGroup();
        serverBootstrap = new ServerBootstrap();
        //初始化NioEventLoopGroup(线程池的 Executor),它将会创建Eventloop(单线程的 Executor,处理多个channel,也就是说一个线程能够处理多个请求),
        serverBootstrap.group(bossGroup, workerGroup)
                //指定I/O模型
                .channel(NioServerSocketChannel.class)
                // 连接队列大小
                .option(ChannelOption.SO_BACKLOG, 1024)
                // 禁用Nagle算法
                .childOption(ChannelOption.TCP_NODELAY, true)
                // 保持长连接
                .childOption(ChannelOption.SO_KEEPALIVE, true)
                //初始化Channel配置
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        //公共处理
                        //设置log监听器,并且日志级别为debug,方便观察运行流程,上线需调为ERROR
                        ch.pipeline().addLast("logging", new LoggingHandler("INFO"));
                        //websocket协议本身是基于Http协议的,设置解码器
                        ch.pipeline().addLast("http-codec", new HttpServerCodec());
                        //聚合 HTTP 请求(支持 WebSocket 握手),使用websocket会用到
                        ch.pipeline().addLast("aggregator", new HttpObjectAggregator(65536));
                        //用于大数据的分区传输
                        ch.pipeline().addLast("http-chunked", new ChunkedWriteHandler());
                        //配置handle
                        ch.pipeline().addLast("handler", webSocketGlobalIntercept);
                    }
                });
        future = serverBootstrap.bind(port).sync();
        log.info("[Netty websocket server] startup success, port: {}", port);
    }

    @PreDestroy
    public void stop() {
        try {
            if (future != null) {
                future.channel().close().sync();
            }
            if (bossGroup != null) {
                bossGroup.shutdownGracefully().sync();
            }
            if (workerGroup != null) {
                workerGroup.shutdownGracefully().sync();
            }
            log.info("[Netty websocket server] shutdown success, port: {}", port);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("[Netty websocket server] shutdown exception: {}, port: {}, ", e, port);
        }
    }
}

6、拦截类

java 复制代码
@ChannelHandler.Sharable
@Component
@Slf4j
public class WebSocketGlobalIntercept extends SimpleChannelInboundHandler<Object> {

    /**
     * Channel每次收到消息将会回调该方法,包括连接请求、帧消息
     *
     * @param ctx   channel上下文,也就是session
     * @param frame
     * @throws Exception
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object frame) throws Exception {
        // 根据帧类型分发处理
        if (frame instanceof TextWebSocketFrame) {
            handleTextFrame(ctx, (TextWebSocketFrame) frame);
        } else if (frame instanceof BinaryWebSocketFrame) {
            handleBinaryFrame(ctx, (BinaryWebSocketFrame) frame);
        } else if (frame instanceof PingWebSocketFrame) {
            handlePingFrame(ctx, (PingWebSocketFrame) frame);
        } else if (frame instanceof CloseWebSocketFrame) {
            handleCloseFrame(ctx, (CloseWebSocketFrame) frame);
        } else if (frame instanceof FullHttpMessage) {
            //请求需要携带token问号拼接
            FullHttpRequest request = (FullHttpRequest) frame;
            String uri = request.uri();
            UriComponentUtil uriComponentUtil = UriComponentUtil.fromUri(uri);
            String path = uriComponentUtil.getPath();
            //校验请求路径是否合法
            if (!WebSocketRouter.containsPath(path)) {
                sendHttpResponse(ctx, request,
                        new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND));
                return;
            }
            //判断query参数是否合法
            boolean tokenStatus = uriComponentUtil.hasQueryParam("token");
            if (!tokenStatus) {
                //没有携带token
                sendHttpResponse(ctx, request,
                        new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED));
                return;
            }
            //​​手动创建握手工厂​​(WebSocketServerHandshakerFactory),配置 WebSocket 地址、子协议和扩展。
            WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory("ws://localhost:6010", null, false);
            //创建一个 ​​WebSocket 握手处理器(检查客户端请求是否合法。)
            WebSocketServerHandshaker webSocketServerHandshaker = factory.newHandshaker(request);
            if (webSocketServerHandshaker == null) {
                // 握手失败,返回 HTTP 错误
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                //完成 WebSocket 握手,将 ​​HTTP 连接升级为 WebSocket 连接​​,连接升级后,后续通信将使用 ​​WebSocket 帧​​(TextWebSocketFrame、BinaryWebSocketFrame 等),而不是 HTTP。
                webSocketServerHandshaker.handshake(ctx.channel(), request);
                WebSocketHandler handler = WebSocketRouter.getHandler(path);
                //构建session对象
                WebSocketSession session = WebSocketSession.of(ctx, path, uriComponentUtil.getQueryParams());
                ctx.channel().attr(AttributeKey.<WebSocketSession>valueOf("session")).set(session);
                //握手成功后将会事件通知onOpen
                handler.onOpen(session);
            }
        }
    }

    /**
     * 异常捕获
     *
     * @param ctx   channel上下文对象
     * @param cause 异常对象
     * @throws Exception
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        WebSocketSession session = ctx.channel().attr(AttributeKey.<WebSocketSession>valueOf("session")).get();
        if (session != null) {
            WebSocketHandler handler = WebSocketRouter.getHandler(session.getPath());
            if (handler != null) {
                handler.onError(session, cause);
            }
        }
        super.exceptionCaught(ctx, cause);
    }

    /**
     * 响应http请求
     *
     * @param ctx      channel上下文对象
     * @param request  请求对象
     * @param response 响应对象
     */
    private void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest request, DefaultFullHttpResponse response) {
        if (response.status().code() != 200) {
            ByteBuf buf = Unpooled.copiedBuffer(response.status().toString(), CharsetUtil.UTF_8);
            response.content().writeBytes(buf);
            buf.release();
        }
        ChannelFuture future = ctx.channel().writeAndFlush(response);
        if (response.status().code() != 200) {
            future.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * 处理文本消息
     *
     * @param ctx   channel上下文
     * @param frame 帧
     */
    private void handleTextFrame(ChannelHandlerContext ctx, TextWebSocketFrame frame) {
        message(ctx, frame);
    }

    /**
     * 处理二进制消息
     *
     * @param ctx   channel上下文
     * @param frame 帧
     */
    private void handleBinaryFrame(ChannelHandlerContext ctx, BinaryWebSocketFrame frame) {
        message(ctx, frame);
    }

    /**
     * 公共消息处理
     *
     * @param ctx   channel上下文
     * @param frame websocket帧
     */
    private void message(ChannelHandlerContext ctx, WebSocketFrame frame) {
        WebSocketSession session = ctx.channel().attr(AttributeKey.<WebSocketSession>valueOf("session")).get();
        if (session != null) {
            WebSocketHandler handler = WebSocketRouter.getHandler(session.getPath());
            if (handler != null) {
                handler.onMessage(session, frame);
            }
        }
    }

    /**
     * 处理客户端ping帧消息
     *
     * @param ctx   channel上下文
     * @param frame 帧
     */
    private void handlePingFrame(ChannelHandlerContext ctx, PingWebSocketFrame frame) {
        WebSocketSession session = ctx.channel().attr(AttributeKey.<WebSocketSession>valueOf("session")).get();
        if (session != null) {
            session.sendPong(new PongWebSocketFrame(frame.content().retain()));
        }
    }

    /**
     * 处理客户端关闭请求帧
     *
     * @param ctx   channel上下文
     * @param frame 帧
     */
    private void handleCloseFrame(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
        ctx.close();
    }

    /**
     * 读完消息需要释放内存
     *
     * @param ctx channel上下文
     * @throws Exception
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        ctx.flush();
    }

    /**
     * 活跃状态,连接成功
     *
     * @param ctx channel上下文
     * @throws Exception
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        log.debug("Client joins the connection {}: ", ctx.channel().toString());
        super.channelActive(ctx);
    }

    /**
     * 断开连接(只要客户端断开连接将会通知该方法,比如:主动断开触发close方法、断网等)
     *
     * @param ctx channel上下文
     * @throws Exception
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        WebSocketSession session = ctx.channel().attr(AttributeKey.<WebSocketSession>valueOf("session")).get();
        if (session != null) {
            WebSocketHandler handler = WebSocketRouter.getHandler(session.getPath());
            if (handler != null) {
                handler.onClose(session);
            }
        }
        super.channelInactive(ctx);
    }
}

7、Handler实现

java 复制代码
@WSHandler(url = "/chat")
@Slf4j
public class ChatAssistantHandler implements WebSocketHandler {

    @Override
    public void onOpen(WebSocketSession session) {
        log.info("触发onOpen");
    }

    @Override
    public void onMessage(WebSocketSession session, WebSocketFrame message) {
        log.info("触发onMessage");
    }

    @Override
    public void onClose(WebSocketSession session) {
        log.info("触发onClose");
        session.close("关闭会话");
    }

    @Override
    public void onError(WebSocketSession session, Throwable error) {
        log.info("触发onError");
        session.close("关闭会话"+error.getMessage());
    }
}

ws://localhost:6010/chat 测试一下,这就不发出来了,可以自己去玩。

8、UriComponentUtil

java 复制代码
public final class UriComponentUtil {

    private final String path;
    private final Map<String, List<String>> queryParams;

    // 私有构造器,强制使用工厂方法
    private UriComponentUtil(String path, Map<String, List<String>> queryParams) {
        this.path = path;
        this.queryParams = queryParams;
    }

    /**
     * 从URI字符串创建UriComponents实例
     */
    public static UriComponentUtil fromUri(String uri) {
        int queryStart = uri.indexOf('?');
        String path = queryStart == -1 ? uri : uri.substring(0, queryStart);

        Map<String, List<String>> queryParams = new HashMap<>();
        if (queryStart != -1 && queryStart < uri.length() - 1) {
            parseQueryParams(uri.substring(queryStart + 1), queryParams);
        }

        return new UriComponentUtil(path, queryParams);
    }

    /**
     * 解析查询参数字符串
     */
    private static void parseQueryParams(String queryString, Map<String, List<String>> output) {
        String[] pairs = queryString.split("&");
        for (String pair : pairs) {
            int eq = pair.indexOf('=');
            String key = eq == -1 ? pair : pair.substring(0, eq);
            String value = eq == -1 ? "" : pair.substring(eq + 1);
            output.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
        }
    }

    /**
     * 获取路径部分
     */
    public String getPath() {
        return path;
    }

    /**
     * 获取查询参数(单值形式,只返回每个参数的第一个值)
     */
    public Map<String, Object> getQueryParams() {
        Map<String, Object> singleValueParams = new HashMap<>();
        queryParams.forEach((key, values) -> {
            if (!values.isEmpty()) {
                singleValueParams.put(key, values.get(0));
            }
        });
        return singleValueParams;
    }

    /**
     * 获取所有查询参数(多值形式)
     */
    public Map<String, List<String>> getAllQueryParams() {
        return queryParams;
    }

    /**
     * 获取指定参数的值(第一个值)
     */
    public String getQueryParam(String name) {
        List<String> values = queryParams.get(name);
        return (values != null && !values.isEmpty()) ? values.get(0) : null;
    }

    /**
     * 检查是否包含指定参数
     */
    public boolean hasQueryParam(String name) {
        return queryParams.containsKey(name);
    }

}

9、配置文件

yaml 复制代码
spring:
  main:
    #禁止启动web容器
    web-application-type: none
netty:
  port: 6010

各位同学CV的爽不爽,给三连同学们!