Springboot+WebSocket实现消息推送

WebSocket是一种在单个TCP连接上进行全双工通信的协议。WebSocket通信协议于2011年被IETF定为标准RFC 6455,并由RFC7936补充规范。WebSocketAPI也被W3C定为标准。

复制代码
WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手,两者之间就直接可以创建持久性的连接,并进行双向数据传输。
创建定时任务,实现定时向前端推送相关消息。
创建存放ws推送的参数缓存Map,定时任务获取参数,获取数据后推送。

引入依赖

java 复制代码
<dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

开启WebSocket支持的配置类

java 复制代码
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * 功能描述:
 * 开启websocket支持
 */
@Configuration
public class WebSocketConfig {

    // 使用boot内置tomcat时需要注入此bean
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}

WebSocketServer服务端

java 复制代码
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.Component;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 功能描述:
 * WebSocketServer服务端
 */
// @ServerEndpoint 注解是一个类层次的注解,它的功能主要是将目前的类定义成一个websocket服务器端。注解的值将被用于监听用户连接的终端访问URL地址
// encoders = WebSocketCustomEncoding.class 是为了使用ws自己的推送Object消息对象(sendObject())时进行解码,通过Encoder 自定义规则(转换为JSON字符串)
@ServerEndpoint(value = "/websocket/{userId}",encoders = WebSocketCustomEncoding.class)
@Component
public class WebSocket {
    private final static Logger logger = LogManager.getLogger(WebSocket.class);

    /**
     * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
     */

    private static int onlineCount = 0;

    /**
     * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
     */
    public static ConcurrentHashMap<String, WebSocket> webSocketMap = new ConcurrentHashMap<>();

    /***
     * 功能描述:
     * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象的参数体
     */
    public static ConcurrentHashMap<String, PushParams> webSocketParamsMap = new ConcurrentHashMap<>();

    /**
     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
     */

    private Session session;
    private String userId;


    /**
     * 连接建立成功调用的方法
     * onOpen 和 onClose 方法分别被@OnOpen和@OnClose 所注解。他们定义了当一个新用户连接和断开的时候所调用的方法。
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("userId") String userId) {
        this.session = session;
        this.userId = userId;
        //加入map
        webSocketMap.put(userId, this);
        addOnlineCount();           //在线数加1
        logger.info("用户{}连接成功,当前在线人数为{}", userId, getOnlineCount());
        try {
            sendMessage(String.valueOf(this.session.getQueryString()));
        } catch (IOException e) {
            logger.error("IO异常");
        }
    }


    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        //从map中删除
        webSocketMap.remove(userId);
        subOnlineCount();           //在线数减1
        logger.info("用户{}关闭连接!当前在线人数为{}", userId, getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     * onMessage 方法被@OnMessage所注解。这个注解定义了当服务器接收到客户端发送的消息时所调用的方法。
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        logger.info("来自客户端用户:{} 消息:{}",userId, message);

        //群发消息
        /*for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }*/
    }

    /**
     * 发生错误时调用
     *
     * @OnError
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("用户错误:" + this.userId + ",原因:" + error.getMessage());
        error.printStackTrace();
    }

    /**
     * 向客户端发送消息
     */
    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
        //this.session.getAsyncRemote().sendText(message);
    }

    /**
     * 向客户端发送消息
     */
    public void sendMessage(Object message) throws IOException, EncodeException {
        this.session.getBasicRemote().sendObject(message);
        //this.session.getAsyncRemote().sendText(message);
    }

    /**
     * 通过userId向客户端发送消息
     */
    public void sendMessageByUserId(String userId, String message) throws IOException {
        logger.info("服务端发送消息到{},消息:{}",userId,message);

     if(StringUtils.isNotBlank(userId)&&webSocketMap.containsKey(userId)){
         webSocketMap.get(userId).sendMessage(message);
     }else{
         logger.error("用户{}不在线",userId);
     }

    }

    /**
     * 通过userId向客户端发送消息
     */
    public void sendMessageByUserId(String userId, Object message) throws IOException, EncodeException {
        logger.info("服务端发送消息到{},消息:{}",userId,message);
        if(StringUtils.isNotBlank(userId)&&webSocketMap.containsKey(userId)){
            webSocketMap.get(userId).sendMessage(message);
        }else{
            logger.error("用户{}不在线",userId);
        }
    }

    /**
     * 通过userId更新缓存的参数
     */
    public void changeParamsByUserId(String userId, PushParams pushParams) throws IOException, EncodeException {
        logger.info("ws用户{}请求参数更新,参数:{}",userId,pushParams.toString());
        webSocketParamsMap.put(userId,pushParams);
    }

    /**
     * 群发自定义消息
     */
    public static void sendInfo(String message) throws IOException {
        for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                continue;
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocket.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocket.onlineCount--;
    }

}

Encoder 自定义规则(转换为JSON字符串)

java 复制代码
import com.alibaba.fastjson.JSON;
import javax.websocket.EncodeException;
import javax.websocket.Encoder;
import javax.websocket.EndpointConfig;

/**
 * 在 websocket 中直接发送 obj 会有问题 - No encoder specified for object of class
 * 需要对 obj 创建解码类,实现 websocket 中的 Encoder.Text<>
 * */
public class WebSocketCustomEncoding implements Encoder.Text<Object> {
    /**
     * The Encoder interface defines how developers can provide a way to convert their
     * custom objects into web socket messages. The Encoder interface contains
     * subinterfaces that allow encoding algorithms to encode custom objects to:
     * text, binary data, character stream and write to an output stream.
     *
     * Encoder 接口定义了如何提供一种方法将定制对象转换为 websocket 消息
     * 可自定义对象编码为文本、二进制数据、字符流、写入输出流
     *  Text、TextStream、Binary、BinaryStream
     * */

    @Override
    public void init(EndpointConfig endpointConfig) {

    }

    @Override
    public void destroy() {

    }

    @Override
    public String encode(Object o) throws EncodeException {
        return JSON.toJSONString(o);
    }
}

自定义消息推送的参数体

java 复制代码
/**
 * 功能描述:
 *
 * @description: ws推送的参数结构
 */
@Data
public class PushParams {

    /**
     * 功能描述:
     * 类型
     */
    private String type;

    /**
     * 功能描述:
     * 开始时间
     */
    private String startTime;

    /**
     * 功能描述:
     * 结束时间
     */
    private String stopTime;
}

根据用户ID更新ws推送的参数,或者使用onMessage修改缓存的结构体

java 复制代码
import com.company.project.common.websocket.PushParams;
import com.company.project.common.websocket.WebSocket;
import com.company.project.service.TestMongodbService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.websocket.EncodeException;
import java.io.IOException;

/**
 * 功能描述:
 * 建立WebSocket连接
 * @Author: LXD
 * @Date: 2022-12-01 09:55:00
 * @since: 1.0.0
 */
@RestController
@RequestMapping("/webSocketPush")
public class WebSocketController {
    @Autowired
    private WebSocket webSocket;
    @Autowired
    private TestMongodbService testMongodbService;

    @RequestMapping("/sentMessage")
    public void sentMessage(String userId,String message){
        try {
            webSocket.sendMessageByUserId(userId,message);
        } catch (IOException e) {
            e.printStackTrace();
        }

    }

    @RequestMapping("/sentObjectMessage")
    public void sentObjectMessage(String userId){
        try {
            webSocket.sendMessageByUserId(userId,testMongodbService.query());
        } catch (IOException e) {
            e.printStackTrace();
        } catch (EncodeException e) {
            e.printStackTrace();
        }

    }

    /***
     * 功能描述:
     * 根据用户ID更新ws推送的参数
     * @Param  userId: WS中的用户ID
     * @Param pushParams: 推送参数
     * @return: void
     * @since: 1.0.0
     */
    @RequestMapping("/changeWsParams")
    public void changeWsParams(String userId, PushParams pushParams){
        try {
            webSocket.changeParamsByUserId(userId,pushParams);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (EncodeException e) {
            e.printStackTrace();
        }

    }

}

创建定时推送的任务

java 复制代码
import com.company.project.service.TestMongodbService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.annotation.Scheduled;
import javax.websocket.EncodeException;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import static com.company.project.common.websocket.WebSocket.webSocketMap;
import static com.company.project.common.websocket.WebSocket.webSocketParamsMap;

/**
 * 功能描述:
 *
 * @description: ws定时推送
 */
@Configuration
@EnableScheduling
public class WebsocketSchedule {

    @Autowired
    private WebSocket webSocket;
    @Autowired
    private TestMongodbService testMongodbService;
	// 第一次延迟1秒后执行,之后按fixedRate的规则每5秒执行一次 fixedRateString 与 fixedRate 意思相同,只是使用字符串的形式。唯一不同的是支持占位符
    @Scheduled(initialDelay=1000, fixedRateString = "${ws.pushInterval}")
    public void pushData() throws EncodeException, IOException {
        ConcurrentHashMap<String, WebSocket> webSocketPushMap = webSocketMap;
        ConcurrentHashMap<String, PushParams> webSocketPushParamsMap = webSocketParamsMap;
        if(!webSocketPushMap.isEmpty()){
            for(String key : webSocketPushMap.keySet()){
                // 根据ws连接用户ID获取推送参数
                PushParams pushParams = webSocketPushParamsMap.get(key);
                webSocket.sendMessageByUserId(key,testMongodbService.query());
            }
        }

    }
}
相关推荐
FFF-X7 分钟前
Vue3 路由缓存实战:从基础到进阶的完整指南
vue.js·spring boot·缓存
励志成为架构师1 小时前
跟小白一起领悟Thread——如何开启一个线程(上)
java·后端
hankeyyh1 小时前
golang 易错点-slice copy
后端·go
考虑考虑1 小时前
Redis事务
redis·后端
Victor3561 小时前
Redis(6)Redis的单线程模型是如何工作的?
后端
Victor3561 小时前
Redis(7)Redis如何实现高效的内存管理?
后端
David爱编程3 小时前
进程 vs 线程到底差在哪?一文吃透操作系统视角与 Java 视角的关键差异
后端
smileNicky13 小时前
SpringBoot系列之从繁琐配置到一键启动之旅
java·spring boot·后端
David爱编程13 小时前
为什么必须学并发编程?一文带你看懂从单线程到多线程的演进史
java·后端
long31614 小时前
java 策略模式 demo
java·开发语言·后端·spring·设计模式