Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(二)

具体设计详见:《Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(一)》

完整代码示例:《Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(二)》(即本文)

配置与测试部署:《Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(三)》

文章目录

- Spring Boot 3 SSE 完整代码示例
  - 1\. 事件模型 (SseEvent.java)
  - 2\. 连接元数据 (ConnectionMetadata.java)
  - 3\. 连接管理器 (SseConnectionManager.java)
  - 4\. SSE Controller (SseController.java)
  - 5\. 消息服务 (SseMessageService.java)

Spring Boot 3 SSE 完整代码示例

本文档包含完整的可运行代码示例。


1. 事件模型 (SseEvent.java)

代码 (Java):
package com.enterprise.sse.model;

import com.fasterxml.jackson.annotation.JsonFormat;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;

/**
 * Message model for a single SSE event.
 *
 * <p>The first four fields map onto the SSE wire fields ({@code id}, {@code event},
 * {@code data}, {@code retry}). The remaining fields carry a creation timestamp, routing
 * hints (target user / target group) and a priority that defaults to 5.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class SseEvent {

    /** Unique event id. */
    private String id;
    /** Event type name. */
    private String event;
    /** Payload of the event. */
    private Object data;
    /** Client reconnect delay in milliseconds. */
    private Long retry;

    /** Creation time, rendered as {@code yyyy-MM-dd HH:mm:ss} by Jackson. */
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    private LocalDateTime timestamp;

    /** User the event is addressed to, if any. */
    private String targetUserId;
    /** Group the event is addressed to, if any. */
    private String targetGroupId;

    /** Priority in the range 0-10; the builder defaults it to 5. */
    @Builder.Default
    private Integer priority = 5;

    /**
     * Builds a {@code notification} event for one user.
     *
     * @param userId  recipient user id
     * @param message notification text, used as the payload
     * @return a new event stamped with a fresh id and the current time
     */
    public static SseEvent notification(String userId, String message) {
        SseEvent.SseEventBuilder b = SseEvent.builder();
        b.id(generateId());
        b.event("notification");
        b.timestamp(LocalDateTime.now());
        b.data(message);
        b.targetUserId(userId);
        return b.build();
    }

    /**
     * Builds a {@code progress} event for one user.
     *
     * @param userId     recipient user id
     * @param percentage progress value; exactly 100 is reported as {@code completed},
     *                   anything else as {@code processing}
     * @return a new event whose payload is a {@link ProgressData}
     */
    public static SseEvent progress(String userId, int percentage) {
        String status;
        if (percentage == 100) {
            status = "completed";
        } else {
            status = "processing";
        }
        SseEvent.SseEventBuilder b = SseEvent.builder();
        b.id(generateId());
        b.event("progress");
        b.timestamp(LocalDateTime.now());
        b.data(new ProgressData(percentage, status));
        b.targetUserId(userId);
        return b.build();
    }

    /**
     * Builds a {@code system} broadcast event with the highest priority (10) and no target.
     *
     * @param message broadcast text, used as the payload
     * @return a new event stamped with a fresh id and the current time
     */
    public static SseEvent systemBroadcast(String message) {
        SseEvent.SseEventBuilder b = SseEvent.builder();
        b.id(generateId());
        b.event("system");
        b.timestamp(LocalDateTime.now());
        b.data(message);
        b.priority(10);
        return b.build();
    }

    /** Returns {@code <epoch millis>-<hex of the bits of a random double>}. */
    private static String generateId() {
        long now = System.currentTimeMillis();
        String suffix = Long.toHexString(Double.doubleToLongBits(Math.random()));
        return now + "-" + suffix;
    }

    /** Payload of a progress event: percentage plus a textual status. */
    @Data
    @AllArgsConstructor
    public static class ProgressData {
        private int percentage;
        private String status;
    }
}

2. 连接元数据 (ConnectionMetadata.java)

代码 (Java):
package com.enterprise.sse.model;

import lombok.Builder;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.Set;

/**
 * Metadata describing one SSE connection.
 *
 * <p>An instance is shared between the threads that send on the connection and the thread
 * that runs timeout cleanup. For that reason the activity fields are {@code volatile} and the
 * counter increment is synchronized.
 */
@Data
@Builder
public class ConnectionMetadata {

    private String userId;
    private String connectionId;
    private LocalDateTime connectedAt;
    // volatile: written by every send/heartbeat, read by the cleanup task on another thread
    private volatile LocalDateTime lastActiveAt;
    private String clientIp;
    private String userAgent;
    // volatile + synchronized increment: '++' on a long is a non-atomic read-modify-write
    private volatile long messageCount;
    private ConnectionStatus status;
    private Set<String> groups;

    /** Lifecycle state of a connection. */
    public enum ConnectionStatus {
        ACTIVE, IDLE, TIMEOUT, ERROR
    }

    /** Records activity on the connection by setting {@code lastActiveAt} to now. */
    public void updateLastActive() {
        this.lastActiveAt = LocalDateTime.now();
    }

    /** Increments the message counter; safe to call from concurrent sender threads. */
    public synchronized void incrementMessageCount() {
        this.messageCount++;
    }

    /**
     * Tells whether the connection has been idle for longer than {@code timeoutSeconds}.
     *
     * <p>Idle time is measured from {@code lastActiveAt}, falling back to {@code connectedAt}
     * when no activity has been recorded. When neither timestamp is set, the connection is
     * reported as not timed out instead of throwing; a throw here would abort a cleanup pass
     * over all connections.
     *
     * @param timeoutSeconds allowed idle time in seconds
     * @return {@code true} if the idle time exceeds the timeout
     */
    public boolean isTimeout(int timeoutSeconds) {
        LocalDateTime reference = (lastActiveAt != null) ? lastActiveAt : connectedAt;
        if (reference == null) {
            return false;
        }
        return reference.plusSeconds(timeoutSeconds).isBefore(LocalDateTime.now());
    }
}

3. 连接管理器 (SseConnectionManager.java)

代码 (Java):
package com.enterprise.sse.manager;

import com.enterprise.sse.model.ConnectionMetadata;
import com.enterprise.sse.model.SseEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * SSE connection manager - the core component.
 *
 * <p>Each physical SSE connection is tracked under its own generated {@code connectionId}.
 * A user may hold up to {@code sse.connection.max-per-user} connections at once (for example,
 * several browser tabs), and messages addressed to a user go to all of them. Emitter callbacks
 * only clean up their own connection, so a stale connection that closes late can never remove
 * a user's newer connection. Group membership is tracked per user and is dropped when the
 * user's last connection closes.
 *
 * <p>All state lives in concurrent collections in this JVM. It is touched from request
 * threads, async senders and emitter callbacks.
 */
@Slf4j
@Component
public class SseConnectionManager {

    // ---- Connection storage ----

    /** connectionId -> emitter of that connection. */
    private final ConcurrentHashMap<String, SseEmitter> connections = new ConcurrentHashMap<>();
    /** connectionId -> owning userId. */
    private final ConcurrentHashMap<String, String> connectionToUser = new ConcurrentHashMap<>();
    /** userId -> ids of that user's live connections; the entry is dropped once empty. */
    private final ConcurrentHashMap<String, Set<String>> userConnections = new ConcurrentHashMap<>();
    /** groupId -> member userIds; the entry is dropped once empty. */
    private final ConcurrentHashMap<String, Set<String>> groupConnections = new ConcurrentHashMap<>();
    /** connectionId -> metadata of that connection. */
    private final ConcurrentHashMap<String, ConnectionMetadata> metadata = new ConcurrentHashMap<>();

    /** Live connections in this JVM; also used to reserve slots against the global limit. */
    private final AtomicLong connectionCount = new AtomicLong(0);
    private final ObjectMapper objectMapper;

    // ---- Prometheus metrics ----
    private final Counter connectionCreatedCounter;
    private final Counter connectionClosedCounter;
    private final Counter messageSentCounter;
    private final Counter messageFailedCounter;

    /** Emitter timeout, also used as the idle limit in cleanupTimeoutConnections(), in seconds. */
    @Value("${sse.connection.timeout-seconds:3600}")
    private int connectionTimeout;

    /** Maximum number of simultaneous connections per user. */
    @Value("${sse.connection.max-per-user:3}")
    private int maxConnectionsPerUser;

    /** Maximum number of simultaneous connections held by this manager. */
    @Value("${sse.connection.global-limit:10000}")
    private int globalConnectionLimit;

    public SseConnectionManager(ObjectMapper objectMapper, MeterRegistry meterRegistry) {
        this.objectMapper = objectMapper;

        this.connectionCreatedCounter = Counter.builder("sse.connections.created")
                .description("Total SSE connections created")
                .register(meterRegistry);

        this.connectionClosedCounter = Counter.builder("sse.connections.closed")
                .description("Total SSE connections closed")
                .register(meterRegistry);

        this.messageSentCounter = Counter.builder("sse.messages.sent")
                .description("Total messages sent")
                .register(meterRegistry);

        this.messageFailedCounter = Counter.builder("sse.messages.failed")
                .description("Total failed messages")
                .register(meterRegistry);
    }

    /**
     * Creates and registers a new SSE connection for a user.
     *
     * @param userId    owner of the connection
     * @param clientIp  client address (informational only)
     * @param userAgent client user agent (informational only)
     * @return the emitter the controller must return
     * @throws IllegalStateException if the global or the per-user connection limit is reached
     */
    public SseEmitter createConnection(String userId, String clientIp, String userAgent) {
        // Reserve a slot first so that concurrent requests cannot overshoot the global limit.
        if (connectionCount.incrementAndGet() > globalConnectionLimit) {
            connectionCount.decrementAndGet();
            throw new IllegalStateException("服务器连接数已满,请稍后重试");
        }

        String connectionId = UUID.randomUUID().toString();
        try {
            // Check and register under the per-key lock so parallel connects cannot exceed the limit.
            // If the lambda throws, ConcurrentHashMap leaves the mapping unchanged.
            userConnections.compute(userId, (key, ids) -> {
                Set<String> current = (ids != null) ? ids : ConcurrentHashMap.newKeySet();
                if (current.size() >= maxConnectionsPerUser) {
                    throw new IllegalStateException("您的连接数已达上限");
                }
                current.add(connectionId);
                return current;
            });
        } catch (IllegalStateException e) {
            connectionCount.decrementAndGet();
            throw e;
        }

        SseEmitter emitter = new SseEmitter(connectionTimeout * 1000L);
        connections.put(connectionId, emitter);
        connectionToUser.put(connectionId, userId);

        // Seed the metadata with groups the user already belongs to (e.g. joined from another tab).
        Set<String> groups = ConcurrentHashMap.newKeySet();
        groupConnections.forEach((groupId, members) -> {
            if (members.contains(userId)) {
                groups.add(groupId);
            }
        });

        LocalDateTime now = LocalDateTime.now();
        ConnectionMetadata meta = ConnectionMetadata.builder()
                .userId(userId)
                .connectionId(connectionId)
                .connectedAt(now)
                .lastActiveAt(now)
                .clientIp(clientIp)
                .userAgent(userAgent)
                .messageCount(0)
                .status(ConnectionMetadata.ConnectionStatus.ACTIVE)
                .groups(groups)
                .build();
        metadata.put(connectionId, meta);

        // Callbacks clean up exactly this connection, never a newer connection of the same user.
        emitter.onCompletion(() -> handleConnectionClose(userId, connectionId, "completed"));
        emitter.onTimeout(() -> handleConnectionClose(userId, connectionId, "timeout"));
        emitter.onError(throwable -> handleConnectionError(userId, connectionId, throwable));

        connectionCreatedCounter.increment();

        log.info("创建 SSE 连接: userId={}, connectionId={}, clientIp={}",
                userId, connectionId, clientIp);

        sendWelcomeMessage(userId, connectionId);
        return emitter;
    }

    /**
     * Sends an event to every live connection of a user.
     *
     * @return {@code true} if at least one connection received the event
     */
    public boolean sendToUser(String userId, SseEvent event) {
        if (!userConnections.containsKey(userId)) {
            log.debug("用户 {} 未连接", userId);
            return false;
        }
        String jsonData = serializeData(event);
        return jsonData != null && deliverToUser(userId, event, jsonData);
    }

    /**
     * Sends an event to every member of a group.
     *
     * @return number of members that received the event on at least one connection
     */
    public int sendToGroup(String groupId, SseEvent event) {
        Set<String> users = groupConnections.get(groupId);
        if (users == null || users.isEmpty()) {
            return 0;
        }
        String jsonData = serializeData(event);
        if (jsonData == null) {
            return 0;
        }

        // Snapshot: failed sends may shrink the live set while we iterate.
        List<String> members = new ArrayList<>(users);
        int successCount = 0;
        for (String userId : members) {
            if (deliverToUser(userId, event, jsonData)) {
                successCount++;
            }
        }

        log.info("向分组 {} 推送,成功: {}/{}", groupId, successCount, members.size());
        return successCount;
    }

    /**
     * Sends an event to every connected user.
     *
     * @return number of users that received the event on at least one connection
     */
    public int broadcast(SseEvent event) {
        String jsonData = serializeData(event);
        if (jsonData == null) {
            return 0;
        }

        List<String> users = new ArrayList<>(userConnections.keySet());
        int successCount = 0;
        for (String userId : users) {
            if (deliverToUser(userId, event, jsonData)) {
                successCount++;
            }
        }

        log.info("广播消息,成功: {}/{}", successCount, users.size());
        return successCount;
    }

    /** Adds a user to a group; the group is created on first join. */
    public void addToGroup(String userId, String groupId) {
        // compute() keeps the add atomic with the empty-group pruning in removeMember().
        groupConnections.compute(groupId, (key, members) -> {
            Set<String> set = (members != null) ? members : ConcurrentHashMap.newKeySet();
            set.add(userId);
            return set;
        });

        metadataOf(userId).forEach(meta -> meta.getGroups().add(groupId));

        log.debug("用户 {} 加入分组 {}", userId, groupId);
    }

    /** Removes a user from a group; the group is dropped when it becomes empty. */
    public void removeFromGroup(String userId, String groupId) {
        removeMember(groupId, userId);
        metadataOf(userId).forEach(meta -> meta.getGroups().remove(groupId));
    }

    /**
     * Sends a {@code keepalive} comment to each of the user's connections.
     * Connections whose send fails are removed.
     */
    public void sendHeartbeat(String userId) {
        Set<String> ids = userConnections.get(userId);
        if (ids == null) {
            return;
        }
        for (String connectionId : new ArrayList<>(ids)) {
            SseEmitter emitter = connections.get(connectionId);
            if (emitter == null) {
                continue;
            }
            try {
                emitter.send(SseEmitter.event().comment("keepalive"));
                recordActivity(connectionId, false);
            } catch (IOException | IllegalStateException e) {
                // IllegalStateException: the emitter has already completed
                log.warn("发送心跳失败: userId={}", userId);
                removeConnectionById(connectionId);
            }
        }
    }

    /**
     * Removes connections that have been idle for longer than {@code connectionTimeout} seconds.
     *
     * @return number of connections removed
     */
    public int cleanupTimeoutConnections() {
        List<String> expired = new ArrayList<>();
        metadata.forEach((connectionId, meta) -> {
            if (meta.isTimeout(connectionTimeout)) {
                expired.add(connectionId);
            }
        });

        int removed = 0;
        for (String connectionId : expired) {
            if (removeConnectionById(connectionId)) {
                removed++;
            }
        }

        if (removed > 0) {
            log.info("清理超时连接: {} 个", removed);
        }
        return removed;
    }

    /** Closes and removes every live connection of a user. */
    public void removeConnection(String userId) {
        Set<String> ids = userConnections.get(userId);
        if (ids == null) {
            return;
        }
        for (String connectionId : new ArrayList<>(ids)) {
            removeConnectionById(connectionId);
        }
    }

    /** Closes and removes all connections. */
    public void closeAll() {
        log.info("关闭所有连接,总数: {}", connections.size());
        for (String connectionId : new ArrayList<>(connections.keySet())) {
            removeConnectionById(connectionId);
        }
    }

    /** Returns a snapshot of the connection statistics. */
    public ConnectionStats getStats() {
        return new ConnectionStats(
                connectionCount.get(),
                userConnections.size(),
                groupConnections.size(),
                getTotalMessagesSent()
        );
    }

    /** Returns the ids of users with at least one live connection. */
    public List<String> getOnlineUsers() {
        return new ArrayList<>(userConnections.keySet());
    }

    // ========== Private helpers ==========

    /**
     * Serializes the event payload to JSON once, so it can be reused for every target connection.
     * Returns {@code null} and counts a failed message when the payload cannot be serialized.
     * A bad payload must not tear down healthy connections.
     */
    private String serializeData(SseEvent event) {
        try {
            return objectMapper.writeValueAsString(event.getData());
        } catch (IOException e) {
            log.error("序列化消息失败: eventId={}", event.getId(), e);
            messageFailedCounter.increment();
            return null;
        }
    }

    /** Sends to all connections of a user; returns true if any send succeeded. */
    private boolean deliverToUser(String userId, SseEvent event, String jsonData) {
        Set<String> ids = userConnections.get(userId);
        if (ids == null) {
            return false;
        }
        boolean delivered = false;
        // Non-short-circuit OR so every connection gets the event even after a success.
        for (String connectionId : new ArrayList<>(ids)) {
            delivered |= sendToConnection(connectionId, event, jsonData);
        }
        return delivered;
    }

    /** Sends one event on one connection; removes the connection if the send fails. */
    private boolean sendToConnection(String connectionId, SseEvent event, String jsonData) {
        SseEmitter emitter = connections.get(connectionId);
        if (emitter == null) {
            return false; // closed concurrently
        }

        // A fresh builder per send: SseEventBuilder is single-use.
        SseEmitter.SseEventBuilder builder = SseEmitter.event();
        if (event.getId() != null) {
            builder.id(event.getId());
        }
        if (event.getEvent() != null) {
            builder.name(event.getEvent());
        }
        if (event.getRetry() != null) {
            builder.reconnectTime(event.getRetry());
        }
        builder.data(jsonData);

        try {
            emitter.send(builder);
        } catch (IOException | IllegalStateException e) {
            // IllegalStateException: the emitter has already completed
            log.error("发送消息失败: userId={}, connectionId={}",
                    connectionToUser.get(connectionId), connectionId, e);
            messageFailedCounter.increment();
            removeConnectionById(connectionId);
            return false;
        }

        recordActivity(connectionId, true);
        messageSentCounter.increment();
        return true;
    }

    private void sendWelcomeMessage(String userId, String connectionId) {
        try {
            SseEvent welcome = SseEvent.builder()
                    .id(UUID.randomUUID().toString())
                    .event("connected")
                    .data(Map.of(
                            "message", "连接成功",
                            "userId", userId,
                            "timestamp", LocalDateTime.now()
                    ))
                    .build();

            String jsonData = serializeData(welcome);
            if (jsonData != null) {
                sendToConnection(connectionId, welcome, jsonData);
            }
        } catch (Exception e) {
            log.warn("发送欢迎消息失败: userId={}", userId);
        }
    }

    private void handleConnectionClose(String userId, String connectionId, String reason) {
        log.info("连接关闭: userId={}, reason={}", userId, reason);
        removeConnectionById(connectionId);
    }

    private void handleConnectionError(String userId, String connectionId, Throwable throwable) {
        log.error("连接错误: userId={}, connectionId={}", userId, connectionId, throwable);
        removeConnectionById(connectionId);
    }

    /**
     * Removes one connection and every index entry that points at it, then completes its emitter.
     *
     * @return {@code true} if this call removed the connection, {@code false} if it was already gone
     */
    private boolean removeConnectionById(String connectionId) {
        // Only one caller wins remove(), so the counters move exactly once per connection even
        // when callbacks and explicit removal race.
        SseEmitter emitter = connections.remove(connectionId);
        if (emitter == null) {
            return false;
        }

        String userId = connectionToUser.remove(connectionId);
        metadata.remove(connectionId);
        if (userId != null) {
            Set<String> remaining = userConnections.computeIfPresent(userId, (key, ids) -> {
                ids.remove(connectionId);
                return ids.isEmpty() ? null : ids;
            });
            if (remaining == null) {
                // The user's last connection is gone: drop the user from every group.
                for (String groupId : new ArrayList<>(groupConnections.keySet())) {
                    removeMember(groupId, userId);
                }
            }
        }

        connectionCount.decrementAndGet();
        connectionClosedCounter.increment();

        try {
            emitter.complete();
        } catch (Exception e) {
            log.warn("关闭连接异常: userId={}", userId);
        }

        log.info("移除连接: userId={}, connectionId={}", userId, connectionId);
        return true;
    }

    /** Removes a user from a group and drops the group atomically once it is empty. */
    private void removeMember(String groupId, String userId) {
        groupConnections.computeIfPresent(groupId, (key, members) -> {
            members.remove(userId);
            return members.isEmpty() ? null : members;
        });
    }

    /** Returns the metadata of all live connections of a user (empty if none). */
    private List<ConnectionMetadata> metadataOf(String userId) {
        Set<String> ids = userConnections.get(userId);
        if (ids == null) {
            return List.of();
        }
        List<ConnectionMetadata> result = new ArrayList<>();
        for (String connectionId : ids) {
            ConnectionMetadata meta = metadata.get(connectionId);
            if (meta != null) {
                result.add(meta);
            }
        }
        return result;
    }

    /** Updates the last-active time and, for real messages, the message counter. */
    private void recordActivity(String connectionId, boolean countAsMessage) {
        ConnectionMetadata meta = metadata.get(connectionId);
        if (meta != null) {
            meta.updateLastActive();
            if (countAsMessage) {
                meta.incrementMessageCount();
            }
        }
    }

    /** Sums message counters of the currently live connections (closed ones are not included). */
    private long getTotalMessagesSent() {
        return metadata.values().stream()
                .mapToLong(ConnectionMetadata::getMessageCount)
                .sum();
    }

    /**
     * Snapshot of connection statistics.
     *
     * @param totalConnections  number of live SSE connections
     * @param activeConnections number of distinct users with at least one live connection
     * @param groups            number of non-empty groups
     * @param totalMessagesSent messages sent over the currently live connections
     */
    public record ConnectionStats(
            long totalConnections,
            int activeConnections,
            int groups,
            long totalMessagesSent
    ) {}
}

4. SSE Controller (SseController.java)

代码 (Java):
package com.enterprise.sse.controller;

import com.enterprise.sse.manager.SseConnectionManager;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.Map;

/**
 * REST controller for SSE connections and group membership.
 */
@Slf4j
@RestController
@RequestMapping("/api/sse")
@RequiredArgsConstructor
public class SseController {

    private final SseConnectionManager connectionManager;

    /**
     * Opens an SSE stream.
     *
     * <p>When the request is authenticated, the principal name is always used as the user id.
     * The optional {@code userId} parameter is honoured only for unauthenticated requests.
     * Honouring it for logged-in users would let any user subscribe to another user's private
     * event stream.
     *
     * @param userId         optional user id, used only without authentication
     * @param request        current request, used for client IP and User-Agent
     * @param authentication current authentication, may be {@code null}
     * @return the SSE emitter
     * @throws IllegalArgumentException if no user id can be determined
     */
    @GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter connect(
            @RequestParam(required = false) String userId,
            HttpServletRequest request,
            Authentication authentication) {

        if (authentication != null && authentication.isAuthenticated()
                && authentication.getName() != null) {
            String principal = authentication.getName();
            if (userId != null && !userId.equals(principal)) {
                log.warn("忽略与认证身份不一致的 userId 参数: param={}, principal={}", userId, principal);
            }
            userId = principal;
        }

        if (userId == null || userId.isBlank()) {
            throw new IllegalArgumentException("用户ID不能为空");
        }

        String clientIp = getClientIp(request);
        String userAgent = request.getHeader("User-Agent");

        log.info("收到连接请求: userId={}, ip={}", userId, clientIp);

        return connectionManager.createConnection(userId, clientIp, userAgent);
    }

    /** Adds the current user to a group. */
    @PostMapping("/group/{groupId}/join")
    public Map<String, Object> joinGroup(
            @PathVariable String groupId,
            Authentication authentication) {

        String userId = currentUserId(authentication);
        connectionManager.addToGroup(userId, groupId);

        return Map.of("success", true, "message", "成功加入分组: " + groupId);
    }

    /** Removes the current user from a group. */
    @PostMapping("/group/{groupId}/leave")
    public Map<String, Object> leaveGroup(
            @PathVariable String groupId,
            Authentication authentication) {

        String userId = currentUserId(authentication);
        connectionManager.removeFromGroup(userId, groupId);

        return Map.of("success", true, "message", "成功离开分组: " + groupId);
    }

    /** Returns connection statistics. */
    @GetMapping("/stats")
    public SseConnectionManager.ConnectionStats getStats() {
        return connectionManager.getStats();
    }

    /** Returns the online users and their count. */
    @GetMapping("/online-users")
    public Map<String, Object> getOnlineUsers() {
        return Map.of(
                "users", connectionManager.getOnlineUsers(),
                "count", connectionManager.getStats().activeConnections()
        );
    }

    /** Closes all connections of the current user. */
    @DeleteMapping("/disconnect")
    public Map<String, Object> disconnect(Authentication authentication) {
        String userId = currentUserId(authentication);
        connectionManager.removeConnection(userId);

        return Map.of("success", true, "message", "连接已断开");
    }

    /**
     * Returns the authenticated user's name.
     *
     * @throws IllegalArgumentException when there is no authentication, instead of an NPE
     */
    private String currentUserId(Authentication authentication) {
        if (authentication == null || authentication.getName() == null
                || authentication.getName().isBlank()) {
            throw new IllegalArgumentException("用户ID不能为空");
        }
        return authentication.getName();
    }

    /**
     * Best-effort client IP: the first hop of X-Forwarded-For, then X-Real-IP, then the socket
     * peer address. The headers are client-controlled unless a trusted proxy overwrites them,
     * so use the result for logging only, never for authorization.
     */
    private String getClientIp(HttpServletRequest request) {
        String ip = firstHop(request.getHeader("X-Forwarded-For"));
        if (!isUsableIp(ip)) {
            ip = firstHop(request.getHeader("X-Real-IP"));
        }
        if (!isUsableIp(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /** Returns the first element of a comma-separated header value (trimmed), or null. */
    private static String firstHop(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        int comma = headerValue.indexOf(',');
        return (comma >= 0 ? headerValue.substring(0, comma) : headerValue).trim();
    }

    private static boolean isUsableIp(String ip) {
        return ip != null && !ip.isBlank() && !"unknown".equalsIgnoreCase(ip);
    }
}

5. 消息服务 (SseMessageService.java)

代码 (Java):
package com.enterprise.sse.service;

import com.enterprise.sse.manager.SseConnectionManager;
import com.enterprise.sse.model.SseEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

/**
 * SSE message service.
 *
 * <p>Delivers events to connections on this node via {@link SseConnectionManager}. When
 * {@code sse.redis.enabled} is true, it also publishes each event as JSON to the Redis channel
 * {@code sse.redis.channel}, presumably for subscribers on other nodes (the subscriber is not
 * shown here).
 *
 * <p>NOTE(review): events are published to Redis even after local delivery. Confirm that the
 * subscriber skips messages that originated on this node; otherwise local users may receive
 * them twice.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class SseMessageService {

    private final SseConnectionManager connectionManager;
    private final StringRedisTemplate redisTemplate;
    private final ObjectMapper objectMapper;
    private final Executor taskExecutor;

    @Value("${sse.redis.channel:sse:message}")
    private String redisChannel;

    @Value("${sse.redis.enabled:true}")
    private boolean redisEnabled;

    /**
     * Sends an event to one user (unicast).
     *
     * <p>The caller's event is not modified. A copy carrying the target user is sent instead.
     * Mutating the shared instance raced when the same event was sent to several users
     * concurrently ({@link #sendToUserAsync}, {@link #sendBatch}), and could publish it to the
     * cluster with another user's {@code targetUserId}.
     *
     * @return {@code true} if a connection on this node received the event
     */
    public boolean sendToUser(String userId, SseEvent event) {
        SseEvent targeted = copyOf(event);
        targeted.setTargetUserId(userId);

        boolean localSuccess = connectionManager.sendToUser(userId, targeted);

        if (redisEnabled) {
            publishToCluster(targeted);
        }

        return localSuccess;
    }

    /** Runs {@link #sendToUser} on {@code taskExecutor}. */
    public CompletableFuture<Boolean> sendToUserAsync(String userId, SseEvent event) {
        return CompletableFuture.supplyAsync(
                () -> sendToUser(userId, event),
                taskExecutor
        );
    }

    /**
     * Sends an event to a group (multicast). The caller's event is not modified.
     *
     * @return number of local group members that received the event
     */
    public int sendToGroup(String groupId, SseEvent event) {
        SseEvent targeted = copyOf(event);
        targeted.setTargetGroupId(groupId);

        int localCount = connectionManager.sendToGroup(groupId, targeted);

        if (redisEnabled) {
            publishToCluster(targeted);
        }

        return localCount;
    }

    /**
     * Broadcasts an event to all users.
     *
     * @return number of local users that received the event
     */
    public int broadcast(SseEvent event) {
        int localCount = connectionManager.broadcast(event);

        if (redisEnabled) {
            publishToCluster(event);
        }

        return localCount;
    }

    /** Sends a {@code notification} event with a title/content payload. */
    public void sendNotification(String userId, String title, String content) {
        SseEvent event = SseEvent.builder()
                .id(generateEventId())
                .event("notification")
                .data(new NotificationData(title, content, System.currentTimeMillis()))
                .targetUserId(userId)
                .build();

        sendToUser(userId, event);
    }

    /** Sends a {@code progress} event for a task. */
    public void sendProgress(String userId, String taskId, int percentage, String message) {
        SseEvent event = SseEvent.builder()
                .id(generateEventId())
                .event("progress")
                .data(new ProgressData(taskId, percentage, message))
                .targetUserId(userId)
                .build();

        sendToUser(userId, event);
    }

    /**
     * Sends the same event to several users.
     *
     * @return number of users that received it on this node; 0 for a null or empty list
     */
    public int sendBatch(List<String> userIds, SseEvent event) {
        if (userIds == null || userIds.isEmpty()) {
            return 0;
        }

        int successCount = 0;
        for (String userId : userIds) {
            if (sendToUser(userId, event)) {
                successCount++;
            }
        }

        log.info("批量发送: 目标={}, 成功={}", userIds.size(), successCount);
        return successCount;
    }

    /** Publishes the event as JSON to the Redis channel; failures are logged, not rethrown. */
    private void publishToCluster(SseEvent event) {
        try {
            String message = objectMapper.writeValueAsString(event);
            redisTemplate.convertAndSend(redisChannel, message);
        } catch (Exception e) {
            log.error("发布到Redis失败", e);
        }
    }

    /** Shallow copy of an event (the payload object itself is shared, not cloned). */
    private static SseEvent copyOf(SseEvent source) {
        return SseEvent.builder()
                .id(source.getId())
                .event(source.getEvent())
                .data(source.getData())
                .retry(source.getRetry())
                .timestamp(source.getTimestamp())
                .targetUserId(source.getTargetUserId())
                .targetGroupId(source.getTargetGroupId())
                .priority(source.getPriority())
                .build();
    }

    /** Returns {@code <epoch millis>-<hex of the bits of a random double>}. */
    private String generateEventId() {
        return System.currentTimeMillis() + "-" +
               Long.toHexString(Double.doubleToLongBits(Math.random()));
    }

    // ---- Payload models ----
    public record NotificationData(String title, String content, long timestamp) {}
    public record ProgressData(String taskId, int percentage, String message) {}
}

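以上为完整代码示例。在项目中使用前,需要先引入 Lombok、Spring Security、Micrometer 和 Spring Data Redis 等依赖,并参考第(三)篇完成配置;集群模式下还需要实现 Redis 消息订阅方,才能把其他节点发布的事件推送给本节点的连接。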

相关推荐
chilavert3182 小时前
技术演进中的开发沉思-349:高效并发(下)
java·jvm
远方16092 小时前
112-Oracle database 26ai下载和安装环境准备
大数据·数据库·sql·oracle·database
2401_838472512 小时前
Python多线程与多进程:如何选择?(GIL全局解释器锁详解)
jvm·数据库·python
光影少年2 小时前
非关系数据库和关系型数据库都有哪些?
数据库·数据库开发·非关系型数据库
2301_822363602 小时前
Python单元测试(unittest)实战指南
jvm·数据库·python
麦兜*2 小时前
深入解析分布式数据库TiDB核心架构:基于Raft一致性协议与HTAP混合负载实现金融级高可用与实时分析的工程实践
数据库·分布式·tidb
shejizuopin2 小时前
基于SSM的高校旧书交易系统的设计与实现(任务书)
java·mysql·毕业设计·论文·任务书·基于ssm的·高校旧书交易系统的设计与实现
好好研究2 小时前
SpringBoot使用外置Tomcat
spring boot·后端·tomcat
m0_561359672 小时前
Python面向对象编程(OOP)终极指南
jvm·数据库·python