websocket加鉴权 @ServerEndpoint方式

java 复制代码
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
 * Servlet filter that, for every request, stores the caller's remote host in the
 * HTTP session and adds a fixed set of CORS headers to the response before
 * handing the request on to the rest of the filter chain.
 */
@Component
public class SimpleCORSFilter implements Filter {

    /** CORS header name/value pairs, written to the response in this order. */
    private static final String[][] CORS_HEADERS = {
        {"Access-Control-Allow-Origin", "*"},
        {"Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE"},
        {"Access-Control-Max-Age", "3600"},
        {"Access-Control-Allow-Headers", "Origin, No-Cache, X-Requested-With, If-Modified-Since, Pragma, Last-Modified, Cache-Control, Expires, Content-Type, X-E4M-With,userId,token,timestamp"},
    };

    /** Nothing to initialise. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // no-op
    }

    /**
     * Records the remote host, decorates the response with CORS headers and
     * continues the chain.
     *
     * @param request  incoming request; cast to {@link HttpServletRequest}
     * @param response outgoing response; cast to {@link HttpServletResponse}
     * @param chain    remaining filters / target servlet
     * @throws IOException      propagated from the chain
     * @throws ServletException propagated from the chain
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        rememberRemoteHost((HttpServletRequest) request);
        writeCorsHeaders((HttpServletResponse) response);
        chain.doFilter(request, response);
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
        // no-op
    }

    /** Saves {@code getRemoteHost()} under the session attribute {@code ipAddr}. */
    private static void rememberRemoteHost(HttpServletRequest httpRequest) {
        // getSession() creates a session when the request does not have one yet.
        httpRequest.getSession().setAttribute("ipAddr", httpRequest.getRemoteHost());
    }

    /** Sets every entry of {@link #CORS_HEADERS} on the response. */
    private static void writeCorsHeaders(HttpServletResponse httpResponse) {
        for (String[] header : CORS_HEADERS) {
            httpResponse.setHeader(header[0], header[1]);
        }
    }
}
java 复制代码
import com.neo.websocket.WebSocketConfig;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;


/**
 * WebSocket endpoint mounted at {@code /wsTask/{userId}/{tId}}.
 *
 * <p>Every open connection is tracked in two static, concurrent collections:
 * the set of endpoint instances and a map from user id to that user's most
 * recently opened {@link Session}. On open, the injected {@code AsynSendMsg}
 * collaborator is asked to start sending to the new session.
 *
 * <p>The handshake is customised by {@code WebSocketConfig}, named as this
 * endpoint's configurator.
 */
@Component
@ServerEndpoint(value = "/wsTask/{userId}/{tId}", configurator = WebSocketConfig.class)  // URL shape: ws://<host>:<port>/wsTask/{userId}/{tId}
public class TaskWebSocket {
    // Static because endpoint instances are not created through Spring injection;
    // the setter below copies the Spring-managed bean into this shared slot.
    private static AsynSendMsg asyncSendMsg;

    /**
     * Receives the Spring-managed sender and publishes it to all endpoint instances.
     *
     * @param asyncSendMsg collaborator used by {@link #onOpen} to push messages
     */
    @Autowired
    public void setAsynSendMsg(AsynSendMsg asyncSendMsg) {
        TaskWebSocket.asyncSendMsg = asyncSendMsg;
    }

    public Logger log = LogManager.getLogger(getClass());

    // Session of the client bound to this endpoint instance; used to send data to it.
    private Session session;
    /**
     * User id taken from the {@code userId} path parameter.
     */
    private String userId;

    // Thread-safe set holding one endpoint object per client connection.
    // Per the original author's note: although @Component is a singleton by default,
    // an endpoint instance is still created for every WebSocket connection, which is
    // why the instances are collected in a static set.
    private static CopyOnWriteArraySet<TaskWebSocket> webSockets =new CopyOnWriteArraySet<>();
    // Online users: user id -> session of that user's latest connection.
    private static ConcurrentHashMap<String,Session> sessionPool = new ConcurrentHashMap<String,Session>();


    /** @return the session bound to this endpoint instance, or {@code null} before {@link #onOpen} */
    public Session getSession() {
        return session;
    }

    /** @return the user id bound to this endpoint instance, or {@code null} before {@link #onOpen} */
    public String getUserId() {
        return userId;
    }

    /** @return the live (mutable) set of all tracked endpoint instances */
    public static CopyOnWriteArraySet<TaskWebSocket> getWebSockets() {
        return webSockets;
    }

    /** @return the live (mutable) map from user id to session */
    public static ConcurrentHashMap<String, Session> getSessionPool() {
        return sessionPool;
    }

    /**
     * Called when a connection has been established: registers the connection and
     * asks the async sender to start pushing to it.
     *
     * @param session the new session
     * @param userId  value of the {@code userId} path segment
     * @param tId     value of the {@code tId} path segment
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value="userId")String userId, @PathParam(value="tId")String tId) {
        try {
            this.session = session;
            this.userId = userId;
            webSockets.add(this);
            sessionPool.put(userId, session);
            log.info("【websocket消息】有新的连接,总数为:"+webSockets.size());
            asyncSendMsg.sendOneMessage(userId,tId,session);
        } catch (Exception e) {
            // Keep the callback from throwing, but never lose the failure silently.
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error("【websocket消息】连接建立处理异常,userId:" + userId, e);
        }
    }

    /**
     * Called when the connection is closed: unregisters this endpoint instance.
     */
    @OnClose
    public void onClose() {
        try {
            webSockets.remove(this);
            // Conditional removal: only drop the pool entry if it still points at THIS
            // connection's session, so closing an old connection does not evict a newer
            // connection that the same user opened in the meantime.
            if (this.userId != null && this.session != null) {
                sessionPool.remove(this.userId, this.session);
            }
            log.info("【websocket消息】连接断开,总数为:"+webSockets.size());
        } catch (Exception e) {
            log.error("【websocket消息】连接关闭处理异常,userId:" + this.userId, e);
        }
    }

    /**
     * Called when a text message arrives from the client; the message is only logged.
     *
     * @param message text received from the client
     */
    @OnMessage
    public void onMessage(String message) throws InterruptedException {
        log.info("【websocket消息】收到客户端消息:"+message);
    }

    /**
     * Called when an error occurs on the connection; logs it with its stack trace.
     *
     * @param session session on which the error occurred
     * @param error   the error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("用户错误,原因:"+error.getMessage(), error);
    }


    /**
     * Broadcasts a text message to every tracked connection whose session is open.
     *
     * @param message text to send
     */
    public void sendAllMessage(String message) {
        log.info("【websocket消息】广播消息:"+message);
        for(TaskWebSocket webSocket : webSockets) {
            try {
                Session target = webSocket.session;
                if(target != null && target.isOpen()) {
                    target.getAsyncRemote().sendText(message);
                }
            } catch (Exception e) {
                log.error("【websocket消息】广播消息发送失败", e);
            }
        }
    }

    /**
     * Sends {@code message} on {@code session} every 2 seconds for as long as the
     * session pooled under {@code userId} exists and is open. Blocks the calling
     * thread for that whole time.
     *
     * @param userId  user whose pooled session controls when the loop stops
     * @param message text to send on each iteration
     * @param session session the text is written to
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    public void sendOneMessage(String userId,String message,Session session) throws InterruptedException {
        while (true) {
            // Single lookup per iteration: a second get() could observe a concurrent
            // removal and dereference null.
            Session pooled = sessionPool.get(userId);
            if (pooled == null || !pooled.isOpen()) {
                return;
            }
            try {
                log.info("【websocket消息】 单点消息:" + message);
                session.getAsyncRemote().sendText(message);
            } catch (Exception e) {
                log.error("【websocket消息】单点消息发送失败,userId:" + userId, e);
            }
            Thread.sleep(2000);
        }
    }

    /**
     * Sends a text message once to each listed user that has an open pooled session.
     *
     * @param userIds ids of the recipients; {@code null} entries are skipped
     * @param message text to send
     */
    public void sendMoreMessage(String[] userIds, String message) {
        for(String userId:userIds) {
            if (userId == null) {
                // ConcurrentHashMap rejects null keys.
                continue;
            }
            Session session = sessionPool.get(userId);
            if (session != null&&session.isOpen()) {
                try {
                    log.info("【websocket消息】 单点消息:"+message);
                    session.getAsyncRemote().sendText(message);
                } catch (Exception e) {
                    log.error("【websocket消息】单点消息发送失败,userId:" + userId, e);
                }
            }
        }

    }

}
java 复制代码
import com.alibaba.fastjson.JSONObject;
import com.neo.dao.UserInfoDao;
import com.neo.utils.IpUtils;
import com.neo.utils.JwtUtils;
import com.neo.utils.SpringUtils;
import io.jsonwebtoken.ExpiredJwtException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

import javax.servlet.annotation.WebListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
import java.util.List;


// When packaging for deployment to Tomcat, the annotations below need to be commented out.
/**
 * Handshake configurator that authenticates WebSocket upgrade requests.
 *
 * <p>The client passes its JWT in the {@code Sec-WebSocket-Protocol} request header.
 * Every value of that header must be a token that passes {@code JwtUtils.verfy},
 * whose subject (a JSON object with {@code ip} and {@code fingerprint} fields)
 * matches the IP stored in the HTTP session under {@code ipAddr} and the request's
 * {@code User-Agent}, and which {@code UserInfoDao.countToken} still reports as known.
 * Any failure aborts the handshake by throwing a {@link RuntimeException}.
 */
@Component
@Configuration
@WebListener
public class WebSocketConfig extends ServerEndpointConfig.Configurator {
    public Logger logger = LoggerFactory.getLogger(getClass());

    /** Exposes the exporter that registers {@code @ServerEndpoint} classes. */
    @Bean
    public ServerEndpointExporter serverEndpointExporter()
    {
        return new ServerEndpointExporter();
    }

    /**
     * Validates the token(s) carried in {@code Sec-WebSocket-Protocol} and echoes
     * that header back on the response.
     *
     * @param sec      endpoint configuration
     * @param request  the opening handshake request
     * @param response the handshake response being built
     * @throws RuntimeException if the header is missing or any token is rejected;
     *         the message names the specific reason
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
        // Read the request header that carries the token(s).
        List<String> tokens = request.getHeaders().get("Sec-WebSocket-Protocol");
        if (tokens == null || tokens.isEmpty()) {
            throw new RuntimeException("无token参数");
        }

        // When Sec-WebSocket-Protocol is present, the same value must be returned to
        // the client. Done only after the null check so null is never stored.
        response.getHeaders().put("Sec-WebSocket-Protocol", tokens);

        for (String token : tokens) {
            authenticate(token, request);
            super.modifyHandshake(sec, request, response);
        }
    }

    /**
     * Runs all checks for a single token and throws on the first failure.
     * Rejections are thrown outside the parsing/lookup try blocks so that their
     * specific messages are not replaced by the generic "token已失效".
     */
    private void authenticate(String token, HandshakeRequest request) {
        // Boolean.TRUE.equals avoids an unboxing NPE should verfy return null.
        if (!Boolean.TRUE.equals(JwtUtils.verfy(token))) {
            logger.info("token无效");
            throw new RuntimeException("token无效");
        }

        String addr = sessionIp(request);

        String ip;
        String fingerprintToken;
        try {
            String subject = JwtUtils.parseJwt(token).getSubject();
            JSONObject jsonObject = JSONObject.parseObject(subject);
            ip = jsonObject.getString("ip");
            fingerprintToken = jsonObject.getString("fingerprint");
        } catch (ExpiredJwtException e) {
            logger.info("token已过期");
            throw new RuntimeException("token已过期", e);
        } catch (Exception e) {
            logger.info("token已失效", e);
            throw new RuntimeException("token已失效", e);
        }
        if (ip == null || fingerprintToken == null) {
            // Subject lacks the fields this check relies on.
            logger.info("token已失效");
            throw new RuntimeException("token已失效");
        }

        String fingerprint = firstHeader(request, "User-Agent");
        if (!ip.equals(addr) || fingerprint == null || !fingerprintToken.contains(fingerprint)) {
            logger.info("网络环境改变,请重新登录!");
            throw new RuntimeException("网络环境改变,请重新登录!");
        }

        int matches;
        try {
            UserInfoDao userInfoDao = SpringUtils.getBean("userInfoDao", UserInfoDao.class);
            matches = userInfoDao.countToken(token);
        } catch (Exception e) {
            logger.error("token已失效", e);
            throw new RuntimeException("token已失效", e);
        }
        if (matches <= 0) {
            logger.info("token已失效,请重新登录!");
            throw new RuntimeException("token已失效,请重新登录!");
        }
    }

    /**
     * Returns the {@code ipAddr} session attribute as a string, or {@code ""} when
     * there is no HTTP session or the attribute is absent.
     */
    private static String sessionIp(HandshakeRequest request) {
        HttpSession session = (HttpSession) request.getHttpSession();
        if (session == null) {
            return "";
        }
        Object ipAddr = session.getAttribute("ipAddr");
        return ipAddr == null ? "" : ipAddr.toString();
    }

    /** Returns the first value of the named request header, or {@code null} if absent. */
    private static String firstHeader(HandshakeRequest request, String name) {
        List<String> values = request.getHeaders().get(name);
        return (values == null || values.isEmpty()) ? null : values.get(0);
    }
}
相关推荐
LUCIAZZZ7 分钟前
HikariCP数据库连接池原理解析
java·jvm·数据库·spring·springboot·线程池·连接池
FakeOccupational29 分钟前
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
网络·游戏
sky_ph31 分钟前
JAVA-GC浅析(二)G1(Garbage First)回收器
java·后端
IDRSolutions_CN1 小时前
PDF 转 HTML5 —— HTML5 填充图形不支持 Even-Odd 奇偶规则?(第二部分)
java·经验分享·pdf·软件工程·团队开发
hello早上好1 小时前
Spring不同类型的ApplicationContext的创建方式
java·后端·架构
HelloWord~2 小时前
SpringSecurity+vue通用权限系统2
java·vue.js
让我上个超影吧2 小时前
黑马点评【基于redis实现共享session登录】
java·redis
fei_sun3 小时前
【计算机网络】三报文握手建立TCP连接
网络·tcp/ip·计算机网络
BillKu3 小时前
Java + Spring Boot + Mybatis 插入数据后,获取自增 id 的方法
java·tomcat·mybatis
全栈凯哥3 小时前
Java详解LeetCode 热题 100(26):LeetCode 142. 环形链表 II(Linked List Cycle II)详解
java·算法·leetcode·链表