若依框架集成WebSocket带用户信息认证

一、WebSocket 基础知识

我们平时前后台请求用的最多的就是 HTTP/1.1协议,它有一个缺陷, 通信只能由客户端发起,如果想要不断获取服务器信息就要不断轮询发出请求,那么如果我们需要服务器状态变化的时候能够主动通知客户端就需要用到WebSocket了, WebSocket是一种网络传输协议,同样也位于 OSI 模型的应用层,建立在传输层协议TCP之上。主要特点是 全双工 通信允许数据在两个方向上同时传输,它在能力上相当于两个单工通信方式的结合 例如指 A→B 的同时 B→A ,是瞬时同步的 二进制帧 采用了二进制帧结构,语法、语义与 HTTP 完全不兼容
相比 http/2,WebSocket 更侧重于"实时通信",而 HTTP/2 更侧重于提高传输效率,所以两者的帧结构也有很大的区别 不像 HTTP/2 那样定义流,也就不存在多路复用、优先级等特性 自身就是全双工,也不需要服务器推送 协议名 引入 ws 和 wss 分别代表明文和密文的 websocket 协议,且默认端口使用 80 或 443,几乎与 http 一致
如果只是服务单纯的向客户端推送消息,不涉及到客户端发送消息到服务端,也可以使用Spring WebFlux技术实现,直接建立在当前http连接上,本质上是保持一个http长连接,适合 简单的服务器数据推送的场景,使用服务器推送事件,更轻量更便捷
ps:关于OSI七层模型可以看我的另一篇文章https://stronger.blog.csdn.net/article/details/127725957

二、若依框架集成WebSocket

在若依代码仓库gitee上的Issues,多次有人提到这个事,例如

java 复制代码
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * websocket 配置
 * 
 * @author ruoyi
 */
@Configuration
public class WebSocketConfig
{
    @Bean
    public ServerEndpointExporter serverEndpointExporter()
    {
        return new ServerEndpointExporter();
    }
}

工具类 SemaphoreUtils

java 复制代码
import java.util.concurrent.Semaphore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * 信号量相关处理
 * 
 * @author ruoyi
 */
public class SemaphoreUtils{
    /**
     * SemaphoreUtils 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class);

    /**
     * 获取信号量
     * 
     * @param semaphore
     * @return
     */
    public static boolean tryAcquire(Semaphore semaphore)
    {
        boolean flag = false;
        try
        {
            flag = semaphore.tryAcquire();
        }
        catch (Exception e)
        {
            LOGGER.error("获取信号量异常", e);
        }
        return flag;
    }

    /**
     * 释放信号量
     * 
     * @param semaphore
     */
    public static void release(Semaphore semaphore)
    {
        try
        {
            semaphore.release();
        }
        catch (Exception e)
        {
            LOGGER.error("释放信号量异常", e);
        }
    }
}

服务端类WebSocketServer

java 复制代码
import java.util.concurrent.Semaphore;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import com.lxh.demo.util.SemaphoreUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

/**
 * websocket 消息处理
 * 
 * @author ruoyi
 */
@Component
@ServerEndpoint("/websocket/message")
public class WebSocketServer
{
    /**
     * WebSocketServer 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);

    /**
     * 默认最多允许同时在线人数100
     */
    public static int socketMaxOnlineCount = 100;

    private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session) throws Exception{
        boolean semaphoreFlag = false;
        // 尝试获取信号量
        semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
        if (!semaphoreFlag)
        {
            // 未获取到信号量
            LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
            WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
            session.close();
        }
        else
        {
            // 添加用户
            WebSocketUsers.put(session.getId(), session);
            LOGGER.info("\n 建立连接 - {}", session);
            LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size());
            WebSocketUsers.sendMessageToUserByText(session, "连接成功");
        }
    }

    /**
     * 连接关闭时处理
     */
    @OnClose
    public void onClose(Session session)
    {
        LOGGER.info("\n 关闭连接 - {}", session);
        // 移除用户
        WebSocketUsers.remove(session.getId());
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 抛出异常时处理
     */
    @OnError
    public void onError(Session session, Throwable exception) throws Exception
    {
        if (session.isOpen())
        {
            // 关闭连接
            session.close();
        }
        String sessionId = session.getId();
        LOGGER.info("\n 连接异常 - {}", sessionId);
        LOGGER.info("\n 异常信息 - {}", exception);
        // 移出用户
        WebSocketUsers.remove(sessionId);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 服务器接收到客户端消息时调用的方法
     */
    @OnMessage
    public void onMessage(String message, Session session)
    {
        String msg = message.replace("你", "我").replace("吗", "");
        WebSocketUsers.sendMessageToUserByText(session, msg);
    }
}

WebSocketUsers工具类

java 复制代码
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import javax.websocket.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * websocket 客户端用户集
 * 
 * @author ruoyi
 */
public class WebSocketUsers
{
    /**
     * WebSocketUsers 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class);

    /**
     * 用户集
     */
    private static Map<String, Session> USERS = new ConcurrentHashMap<String, Session>();

    /**
     * 存储用户
     *
     * @param key 唯一键
     * @param session 用户信息
     */
    public static void put(String key, Session session)
    {
        USERS.put(key, session);
    }

    /**
     * 移除用户
     *
     * @param session 用户信息
     *
     * @return 移除结果
     */
    public static boolean remove(Session session)
    {
        String key = null;
        boolean flag = USERS.containsValue(session);
        if (flag)
        {
            Set<Map.Entry<String, Session>> entries = USERS.entrySet();
            for (Map.Entry<String, Session> entry : entries)
            {
                Session value = entry.getValue();
                if (value.equals(session))
                {
                    key = entry.getKey();
                    break;
                }
            }
        }
        else
        {
            return true;
        }
        return remove(key);
    }

    /**
     * 移出用户
     *
     * @param key 键
     */
    public static boolean remove(String key)
    {
        LOGGER.info("\n 正在移出用户 - {}", key);
        Session remove = USERS.remove(key);
        if (remove != null)
        {
            boolean containsValue = USERS.containsValue(remove);
            LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功");
            return containsValue;
        }
        else
        {
            return true;
        }
    }

    /**
     * 获取在线用户列表
     *
     * @return 返回用户集合
     */
    public static Map<String, Session> getUsers()
    {
        return USERS;
    }

    /**
     * 群发消息文本消息
     *
     * @param message 消息内容
     */
    public static void sendMessageToUsersByText(String message)
    {
        Collection<Session> values = USERS.values();
        for (Session value : values)
        {
            sendMessageToUserByText(value, message);
        }
    }

    /**
     * 发送文本消息
     *
     * @param session 缓存
     * @param message 消息内容
     */
    public static void sendMessageToUserByText(Session session, String message)
    {
        if (session != null)
        {
            try
            {
                session.getBasicRemote().sendText(message);
            }
            catch (IOException e)
            {
                LOGGER.error("\n[发送消息异常]", e);
            }
        }
        else
        {
            LOGGER.info("\n[你已离线]");
        }
    }
}

Html 页面代码

html 复制代码
<!DOCTYPE html>
<html lang="zh" xmlns:th="http://www.thymeleaf.org">
<head>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <title>测试界面</title>
</head>

<body>

<div>
    <input type="text" style="width: 20%" value="ws://127.0.0.1/websocket/message" id="url">
	<button id="btn_join">连接</button>
	<button id="btn_exit">断开</button>
</div>
<br/>
<textarea id="message" cols="100" rows="9"></textarea> <button id="btn_send">发送消息</button>
<br/>
<br/>
<textarea id="text_content" readonly="readonly" cols="100" rows="9"></textarea>返回内容
<br/>
<br/>
<script th:src="@{/js/jquery.min.js}" ></script>
<script type="text/javascript">
    $(document).ready(function(){
        var ws = null;
        // 连接
        $('#btn_join').click(function() {
        	var url = $("#url").val();
            ws = new WebSocket(url);
            ws.onopen = function(event) {
                $('#text_content').append('已经打开连接!' + '\n');
            }
            ws.onmessage = function(event) {
                $('#text_content').append(event.data + '\n');
            }
            ws.onclose = function(event) {
                $('#text_content').append('已经关闭连接!' + '\n');
            }
        });
        // 发送消息
        $('#btn_send').click(function() {
            var message = $('#message').val();
            if (ws) {
                ws.send(message);
            } else {
                alert("未连接到服务器");
            }
        });
        //断开
        $('#btn_exit').click(function() {
            if (ws) {
                ws.close();
                ws = null;
            }
        });
    })
</script>
</body>
</html>

成功运行后,页面如下

注意此时没有走用户认证,那么就要对路径放行,因为若依框架用的是SpringSecurity,所以找到文件SecurityConfig.java ,进行路径放行

三、用户认证问题

虽然按着上述步骤我们完成了浏览器(客户端)和Java(服务端)的WebSocket通信,但是我们不能限定哪些用户可以连接我们的服务端获取数据,服务端也不知道应该具体给哪些用户发送消息,在我们框架之前交互我们是通过浏览器传递toke 值来实现用户身份确认的,那么我们的WebSocket可不可以也这样呢?

很不幸的是 ws连接是无法像http一样完全自主定义请求头的,给token认证带来了不便,我们大致可以通过以下集中方式完成用户认证

1、将 token 明文携带在 url 中,例如ws://localhost:8080/weggo/websocket/message?Authorization=Bearer+token

2、通过websocket下的子协议来实现,Stomp这个协议来实现,前端采用SocketJs框架来实现对应定制请求头。实现携带authorization=Bearer +token 的需求,这样就可以正常建立连接

3、利用子协议数组,将 token 携带在 protocols 里,var ws = new WebSocket(url, ["token"]);

这样后端在 onOpen 事件中,就可以从 server 中读取 Sec-WebSocket-Protocol 属性来进行 token 的获取,具体可以参考WebScoket构造函数官方文档

TypeScript 复制代码
var aWebSocket = new WebSocket(url [, protocols]);
url
要连接的URL;这应该是WebSocket服务器将响应的URL。
protocols 可选
一个协议字符串或者一个包含协议字符串的数组。这些字符串用于指定子协议,这样单个服务器可以实现多个WebSocket子协议
(例如,您可能希望一台服务器能够根据指定的协议(protocol)处理不同类型的交互)。如果不指定协议字符串,则假定为空字符串。

protocols对应的就是发起ws连接时, 携带在请求头中的Sec-WebSocket-Protocol属性, 服务端可以获取到此属性的值用于通信逻辑(即通信子协议,当然用来进行token认证也是完全没问题的),前端人员在请求头上携带sec-websocket-protocol=Bearer +token后台在请求到达oauth2之前进行拦截,然后将在请求头上添加Authorization=Bearer +token(key首字母大写),然后在响应头(respone)上添加sec-websocket-protocol=Bearer +token(不添加会报错)

方法3部分代码示例

TypeScript 复制代码
//前端
var aWebSocket = new WebSocket(url ['用户token']);

//后端
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
     //这里就是我们所提交的token
     String submitedToken=session.getHandshakeHeaders().get("sec-websocket-protocol").get(0);

     //根据token取得登录用户信息(业务逻辑根据你自己的来处理)
}

另外,如果需要在第一次握手前的时候就取得token,只需要在header里面取得就可以啦

java 复制代码
@Override
public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception {
        System.out.println("准备握手");
        String submitedToken = serverHttpRequest.getHeaders().get("sec-websocket-protocol")
        return true;
}

因为我的项目是APP 移动端与服务端进行交互,所以后来选择了最简单实现的方案一

首先要解决的就是在拦截器获取url 的token 信息,原框架只从head里面获取,所以需要稍加改动

找到TokenService.java文件里的getToken方法,改成如下,这样就可以获取url 中的token 了又不影响原来的Http 请求

java 复制代码
 private String getToken(HttpServletRequest request)
    {
        String token = Optional.ofNullable(request.getHeader(header)).orElse(request.getParameter(header));
        if (StringUtils.isNotEmpty(token) && token.startsWith(Constants.TOKEN_PREFIX))
        {
            token = token.replace(Constants.TOKEN_PREFIX, "");
        }
        return token;
    }

接下来就是需要对我们的WebSocket类进行改造了,为了方便阅读,去除了WebSocketUsers类,添加了类变量webSocketSet来存储客户端对象

java 复制代码
import com.alibaba.fastjson2.JSON;
import com.tongchuang.common.utils.SecurityUtils;
import com.tongchuang.web.mqtt.domain.DeviceInfo;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
 * websocket 消息处理
 *
 * @author stronger
 */
@Component
@ServerEndpoint("/websocket/message")
public class WebSocketServer {
    /*========================声明类变量,意在所有实例共享=================================================*/
    /**
     * WebSocketServer 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);

    /**
     * 默认最多允许同时在线人数100
     */
    public static int socketMaxOnlineCount = 100;

    private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);

    HashedWheelTimer timer = new HashedWheelTimer(1, TimeUnit.SECONDS, 8);
    /**
     * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
     */
    private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>();
    /**
     * 连接数
     */
    private static final AtomicInteger count = new AtomicInteger();

    /*========================声明实例变量,意在每个实例独享=======================================================*/
    /**
     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
     */
    private Session session;
    /**
     * 用户id
     */
    private String sid = "";

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session) throws Exception {
        // 尝试获取信号量
        boolean semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
        if (!semaphoreFlag) {
            // 未获取到信号量
            LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
            // 给当前Session 登录用户发送消息
            sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
            session.close();
        } else {
            // 返回此会话的经过身份验证的用户,如果此会话没有经过身份验证的用户,则返回null
            Authentication authentication = (Authentication) session.getUserPrincipal();
            SecurityUtils.setAuthentication(authentication);
            String username = SecurityUtils.getUsername();
            this.session = session;
            //如果存在就先删除一个,防止重复推送消息
            for (WebSocketServer webSocket : webSocketSet) {
                if (webSocket.sid.equals(username)) {
                    webSocketSet.remove(webSocket);
                    count.getAndDecrement();
                }
            }
            count.getAndIncrement();
            webSocketSet.add(this);
            this.sid = username;
            LOGGER.info("\n 当前人数 - {}", count);
            sendMessageToUserByText(session, "连接成功");
        }
    }

    /**
     * 连接关闭时处理
     */
    @OnClose
    public void onClose(Session session) {
        LOGGER.info("\n 关闭连接 - {}", session);
        // 移除用户
        webSocketSet.remove(session);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 抛出异常时处理
     */
    @OnError
    public void onError(Session session, Throwable exception) throws Exception {
        if (session.isOpen()) {
            // 关闭连接
            session.close();
        }
        String sessionId = session.getId();
        LOGGER.info("\n 连接异常 - {}", sessionId);
        LOGGER.info("\n 异常信息 - {}", exception);
        // 移出用户
        webSocketSet.remove(session);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 服务器接收到客户端消息时调用的方法
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        Authentication authentication = (Authentication) session.getUserPrincipal();
        LOGGER.info("收到来自" + sid + "的信息:" + message);
        // 实时更新
        this.refresh(sid, authentication);
        sendMessageToUserByText(session, "我收到了你的新消息哦");
    }

    /**
     * 刷新定时任务,发送信息
     */
    private void refresh(String userId, Authentication authentication) {
        this.start(5000L, task -> {
            // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
            if (WebSocketServer.isConn(userId)) {
                // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约
                SecurityUtils.setAuthentication(authentication);
                // 从数据库或者缓存中获取信息,构建自定义的Bean
                DeviceInfo deviceInfo = DeviceInfo.builder().Macaddress("de5a735951ee").Imei("351517175516665")
                        .Battery("99").Charge("0").Latitude("116.402649").Latitude("39.914859").Altitude("80")
                        .Method(SecurityUtils.getUsername()).build();
                // TODO判断数据是否有更新
                // 发送最新数据给前端
                WebSocketServer.sendInfo("JSON", deviceInfo, userId);
                // 设置返回值,判断是否需要继续执行
                return true;
            }
            return false;
        });
    }

    private void start(long delay, Function<Timeout, Boolean> function) {
        timer.newTimeout(t -> {
            // 获取返回值,判断是否执行
            Boolean result = function.apply(t);
            if (result) {
                timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS);
            }
        }, delay, TimeUnit.MILLISECONDS);
    }

    /**
     * 判断是否有链接
     *
     * @return
     */
    public static boolean isConn(String sid) {
        for (WebSocketServer item : webSocketSet) {
            if (item.sid.equals(sid)) {
                return true;
            }
        }
        return false;
    }

    /**
     * 群发自定义消息
     * 或者指定用户发送消息
     */
    public static void sendInfo(String type, Object data, @PathParam("sid") String sid) {
        // 遍历WebSocketServer对象集合,如果符合条件就推送
        for (WebSocketServer item : webSocketSet) {
            try {
                //这里可以设定只推送给这个sid的,为null则全部推送
                if (sid == null) {
                    item.sendMessage(type, data);
                } else if (item.sid.equals(sid)) {
                    item.sendMessage(type, data);
                }
            } catch (IOException ignored) {
            }
        }
    }

    /**
     * 实现服务器主动推送
     */
    private void sendMessage(String type, Object data) throws IOException {
        Map<String, Object> result = new HashMap<>();
        result.put("type", type);
        result.put("data", data);
        this.session.getAsyncRemote().sendText(JSON.toJSONString(result));
    }

    /**
     * 实现服务器主动推送-根据session
     */
    public static void sendMessageToUserByText(Session session, String message) {
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (IOException e) {
                LOGGER.error("\n[发送消息异常]", e);
            }
        } else {
            LOGGER.info("\n[你已离线]");
        }
    }
}
相关推荐
Q_19284999065 小时前
基于Spring Boot的电影售票系统
java·spring boot·后端
陈无左耳、6 小时前
Spring Boot应用开发实战:从入门到精通
spring boot
我要学编程(ಥ_ಥ)6 小时前
初始JavaEE篇 —— 网络原理---传输层协议:深入理解UDP/TCP
java·网络·tcp/ip·udp·java-ee
烟波人长安吖~6 小时前
【目标跟踪+人流计数+人流热图(Web界面)】基于YOLOV11+Vue+SpringBoot+Flask+MySQL
vue.js·pytorch·spring boot·深度学习·yolo·目标跟踪
百事可乐☆6 小时前
全局webSocket 单个页面进行监听并移除单页面监听
网络·websocket·网络协议
深圳启明云端科技6 小时前
WiFi、蓝牙共存,物联网无线通信技术,设备无线连接数据传输应用
网络·物联网·智能家居
dengjiayue7 小时前
OSI 网络 7 层模型
网络
hgdlip7 小时前
IP属地和所在地不一致什么意思?怎么换成另外一个地方的
服务器·网络协议·tcp/ip
cnsinda_sdc7 小时前
信创数据防泄漏中信创沙箱是什么样的安全方案
运维·网络·安全·源代码管理·源代码防泄密·源代码加密
忆源8 小时前
Linux高级--2.4.5 靠协议头保证传输的 MAC/IP/TCP/UDP---协议帧格式
网络协议·tcp/ip·udp