具体设计详见:《Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(一)》
完整代码示例:《Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(二)》
配置与测试部署:《Spring Boot 3 整合 SSE (Server-Sent Events) 企业级最佳实践(三)》
文章目录
- Spring Boot 3 SSE 完整代码示例
  - 1\. 事件模型 (SseEvent.java)
  - 2\. 连接元数据 (ConnectionMetadata.java)
  - 3\. 连接管理器 (SseConnectionManager.java)
  - 4\. SSE Controller (SseController.java)
  - 5\. 消息服务 (SseMessageService.java)
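## Spring Boot 3 SSE 完整代码示例
本文档包含完整的可运行代码示例。示例代码依赖 Spring Web MVC、Spring Security、Spring Data Redis、Micrometer、Jackson 与 Lombok,请确保项目中已引入相应依赖。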
### 1. 事件模型 (SseEvent.java)
java
package com.enterprise.sse.model;
import com.fasterxml.jackson.annotation.JsonFormat;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
/**
 * Server-sent event message model.
 *
 * <p>Holds the wire-level fields of an SSE message ({@code id}, {@code event}, {@code data},
 * {@code retry}) together with routing hints ({@code targetUserId}, {@code targetGroupId})
 * and a priority. The static factories create the common event kinds with a generated id
 * and the current local time as timestamp.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class SseEvent {

    /** Unique event id. */
    private String id;

    /** Event type, e.g. "notification", "progress", "system". */
    private String event;

    /** Message payload. */
    private Object data;

    /** Client reconnect delay in milliseconds. */
    private Long retry;

    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    private LocalDateTime timestamp;

    /** Target user (unicast). */
    private String targetUserId;

    /** Target group (multicast). */
    private String targetGroupId;

    /** Priority; intended range 0-10 (not validated), defaults to 5. */
    @Builder.Default
    private Integer priority = 5;

    /**
     * Creates a notification event for a single user.
     *
     * @param userId  recipient
     * @param message notification text, used as the payload
     */
    public static SseEvent notification(String userId, String message) {
        SseEventBuilder draft = stamped("notification");
        draft.data(message);
        draft.targetUserId(userId);
        return draft.build();
    }

    /**
     * Creates a progress event for a single user.
     *
     * @param userId     recipient
     * @param percentage progress value; exactly 100 is reported as "completed",
     *                   anything else as "processing"
     */
    public static SseEvent progress(String userId, int percentage) {
        String status = (percentage == 100) ? "completed" : "processing";
        return stamped("progress")
                .data(new ProgressData(percentage, status))
                .targetUserId(userId)
                .build();
    }

    /**
     * Creates a system broadcast event with the highest priority (10) and no target.
     *
     * @param message broadcast text, used as the payload
     */
    public static SseEvent systemBroadcast(String message) {
        return stamped("system")
                .data(message)
                .priority(10)
                .build();
    }

    /** Starts a builder pre-filled with a fresh id, the given event type and the current time. */
    private static SseEventBuilder stamped(String type) {
        return SseEvent.builder()
                .id(generateId())
                .event(type)
                .timestamp(LocalDateTime.now());
    }

    /** Builds an id of the form "{epochMillis}-{hex of the bits of a random double}". */
    private static String generateId() {
        long millis = System.currentTimeMillis();
        long randomBits = Double.doubleToLongBits(Math.random());
        return millis + "-" + Long.toHexString(randomBits);
    }

    /** Payload of a progress event. */
    @Data
    @AllArgsConstructor
    public static class ProgressData {
        private int percentage;
        private String status;
    }
}
### 2. 连接元数据 (ConnectionMetadata.java)
java
package com.enterprise.sse.model;
import lombok.Builder;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.Set;
/**
 * Metadata of a single SSE connection.
 *
 * <p>An instance may be read and updated from several threads at once (request threads,
 * asynchronous senders, heartbeat and timeout-cleanup jobs). For that reason the fields that
 * change after creation are {@code volatile}, and the message counter is incremented under
 * the instance lock.
 */
@Data
@Builder
public class ConnectionMetadata {
    private String userId;
    private String connectionId;
    private LocalDateTime connectedAt;
    /** Time of the last successful activity; volatile so cleanup threads see sender updates. */
    private volatile LocalDateTime lastActiveAt;
    private String clientIp;
    private String userAgent;
    /** Messages delivered on this connection; volatile for visibility, incremented under lock. */
    private volatile long messageCount;
    private volatile ConnectionStatus status;
    /**
     * Ids of the groups this connection belongs to. NOTE(review): this set is mutated in place,
     * so callers should supply a thread-safe set if it is modified concurrently.
     */
    private Set<String> groups;

    /** Lifecycle state of a connection. */
    public enum ConnectionStatus {
        ACTIVE, IDLE, TIMEOUT, ERROR
    }

    /** Records "now" as the last activity time. */
    public void updateLastActive() {
        this.lastActiveAt = LocalDateTime.now();
    }

    /** Increments the delivered-message counter; synchronized because {@code ++} is not atomic. */
    public synchronized void incrementMessageCount() {
        this.messageCount++;
    }

    /**
     * Tells whether the connection has been idle for longer than {@code timeoutSeconds}.
     *
     * <p>Falls back to {@code connectedAt} when no activity has been recorded. If neither
     * timestamp is set, returns {@code false}. Previously this case threw a
     * NullPointerException, which would abort any loop checking many connections.
     * NOTE(review): {@link LocalDateTime#now()} uses the system default zone, so a wall-clock
     * shift (e.g. DST) also shifts the timeout; consider {@code Instant} if that matters.
     *
     * @param timeoutSeconds allowed idle time in seconds
     * @return {@code true} if the reference time plus the timeout lies in the past
     */
    public boolean isTimeout(int timeoutSeconds) {
        LocalDateTime reference = lastActiveAt; // read the volatile field once
        if (reference == null) {
            reference = connectedAt;
        }
        if (reference == null) {
            return false;
        }
        return reference.plusSeconds(timeoutSeconds).isBefore(LocalDateTime.now());
    }
}
### 3. 连接管理器 (SseConnectionManager.java)
java
package com.enterprise.sse.manager;
import com.enterprise.sse.model.ConnectionMetadata;
import com.enterprise.sse.model.SseEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
 * SSE connection manager - core component.
 *
 * <p>Registers, tracks and closes {@link SseEmitter} connections and fans events out to a user,
 * a group or everyone. Every connection gets its own id, so one user may hold up to
 * {@code sse.connection.max-per-user} connections at once (e.g. several browser tabs). Opening
 * one connection never replaces another, and closing one never tears down another.
 *
 * <p>Indexes (all concurrent maps):
 * <ul>
 *   <li>{@code emitters}: connectionId to emitter</li>
 *   <li>{@code metadata}: connectionId to {@link ConnectionMetadata} (holds the owning userId)</li>
 *   <li>{@code userConnections}: userId to the ids of that user's connections</li>
 *   <li>{@code groupConnections}: groupId to member userIds</li>
 * </ul>
 *
 * <p>Closing a connection is idempotent. Only the caller that removes the id from
 * {@code emitters} performs the cleanup, so emitter callbacks, failed sends, heartbeats, idle
 * cleanup and explicit disconnects may overlap without double-counting.
 */
@Slf4j
@Component
public class SseConnectionManager {
    // connectionId -> emitter
    private final ConcurrentHashMap<String, SseEmitter> emitters = new ConcurrentHashMap<>();
    // connectionId -> metadata (carries the owning userId)
    private final ConcurrentHashMap<String, ConnectionMetadata> metadata = new ConcurrentHashMap<>();
    // userId -> ids of that user's connections
    private final ConcurrentHashMap<String, Set<String>> userConnections = new ConcurrentHashMap<>();
    // groupId -> member userIds
    private final ConcurrentHashMap<String, Set<String>> groupConnections = new ConcurrentHashMap<>();
    // Live connections, including slots reserved by connections still being created
    private final AtomicLong connectionCount = new AtomicLong(0);
    // Messages sent since startup; unlike per-connection counters this survives disconnects
    private final AtomicLong totalMessagesSent = new AtomicLong(0);
    private final ObjectMapper objectMapper;

    // Prometheus metrics
    private final Counter connectionCreatedCounter;
    private final Counter connectionClosedCounter;
    private final Counter messageSentCounter;
    private final Counter messageFailedCounter;

    @Value("${sse.connection.timeout-seconds:3600}")
    private int connectionTimeout;

    @Value("${sse.connection.max-per-user:3}")
    private int maxConnectionsPerUser;

    @Value("${sse.connection.global-limit:10000}")
    private int globalConnectionLimit;

    public SseConnectionManager(ObjectMapper objectMapper, MeterRegistry meterRegistry) {
        this.objectMapper = objectMapper;
        this.connectionCreatedCounter = Counter.builder("sse.connections.created")
                .description("Total SSE connections created")
                .register(meterRegistry);
        this.connectionClosedCounter = Counter.builder("sse.connections.closed")
                .description("Total SSE connections closed")
                .register(meterRegistry);
        this.messageSentCounter = Counter.builder("sse.messages.sent")
                .description("Total messages sent")
                .register(meterRegistry);
        this.messageFailedCounter = Counter.builder("sse.messages.failed")
                .description("Total failed messages")
                .register(meterRegistry);
    }

    /**
     * Opens a new SSE connection for a user.
     *
     * @param userId    owning user
     * @param clientIp  client address, stored in the connection metadata
     * @param userAgent client User-Agent header, stored in the connection metadata
     * @return the emitter to hand back to Spring MVC
     * @throws IllegalStateException if the global or the per-user connection limit is reached
     */
    public SseEmitter createConnection(String userId, String clientIp, String userAgent) {
        // Reserve a global slot atomically (increment first, roll back on overflow)
        // instead of a racy check-then-act.
        if (connectionCount.incrementAndGet() > globalConnectionLimit) {
            connectionCount.decrementAndGet();
            throw new IllegalStateException("服务器连接数已满,请稍后重试");
        }
        String connectionId = UUID.randomUUID().toString();
        if (!reserveUserSlot(userId, connectionId)) {
            connectionCount.decrementAndGet();
            throw new IllegalStateException("您的连接数已达上限");
        }

        SseEmitter emitter = new SseEmitter(connectionTimeout * 1000L);
        LocalDateTime now = LocalDateTime.now();

        // Thread-safe set: group membership is changed from other request threads.
        Set<String> groups = ConcurrentHashMap.newKeySet();
        // Carry over groups the user already belongs to (e.g. joined from another connection).
        groupConnections.forEach((groupId, members) -> {
            if (members.contains(userId)) {
                groups.add(groupId);
            }
        });

        ConnectionMetadata meta = ConnectionMetadata.builder()
                .userId(userId)
                .connectionId(connectionId)
                .connectedAt(now)
                .lastActiveAt(now)
                .clientIp(clientIp)
                .userAgent(userAgent)
                .messageCount(0)
                .status(ConnectionMetadata.ConnectionStatus.ACTIVE)
                .groups(groups)
                .build();
        metadata.put(connectionId, meta);
        emitters.put(connectionId, emitter);

        // Callbacks address this connection only, never the user's other connections.
        emitter.onCompletion(() -> handleConnectionClose(userId, connectionId, "completed", false));
        emitter.onTimeout(() -> handleConnectionClose(userId, connectionId, "timeout", true));
        emitter.onError(throwable -> handleConnectionError(userId, connectionId, throwable));

        connectionCreatedCounter.increment();
        log.info("创建 SSE 连接: userId={}, connectionId={}, clientIp={}",
                userId, connectionId, clientIp);
        sendWelcomeMessage(userId, connectionId);
        return emitter;
    }

    /**
     * Sends an event to every connection of one user.
     *
     * @return {@code true} if at least one connection received the event
     */
    public boolean sendToUser(String userId, SseEvent event) {
        Set<String> connectionIds = userConnections.get(userId);
        if (connectionIds == null || connectionIds.isEmpty()) {
            log.debug("用户 {} 未连接", userId);
            return false;
        }
        String json = serializeData(event);
        return json != null && deliverToUser(userId, event, json);
    }

    /**
     * Sends an event to every member of a group.
     *
     * @return number of members that received the event on at least one connection
     */
    public int sendToGroup(String groupId, SseEvent event) {
        Set<String> users = groupConnections.get(groupId);
        if (users == null || users.isEmpty()) {
            return 0;
        }
        String json = serializeData(event); // serialize once for the whole fan-out
        if (json == null) {
            return 0;
        }
        int successCount = 0;
        for (String userId : users) {
            if (deliverToUser(userId, event, json)) {
                successCount++;
            }
        }
        log.info("向分组 {} 推送,成功: {}/{}", groupId, successCount, users.size());
        return successCount;
    }

    /**
     * Sends an event to every connected user.
     *
     * @return number of users that received the event on at least one connection
     */
    public int broadcast(SseEvent event) {
        String json = serializeData(event); // serialize once for the whole fan-out
        if (json == null) {
            return 0;
        }
        int successCount = 0;
        for (String userId : userConnections.keySet()) {
            if (deliverToUser(userId, event, json)) {
                successCount++;
            }
        }
        log.info("广播消息,成功: {}/{}", successCount, userConnections.size());
        return successCount;
    }

    /** Adds a user to a group (creating the group on demand) and tags the user's connections. */
    public void addToGroup(String userId, String groupId) {
        // compute() so that creating the set and adding to it cannot race with
        // removeFromGroup dropping an empty group.
        groupConnections.compute(groupId, (key, members) -> {
            Set<String> result = (members != null) ? members : ConcurrentHashMap.newKeySet();
            result.add(userId);
            return result;
        });
        for (ConnectionMetadata meta : metadataOf(userId)) {
            meta.getGroups().add(groupId);
        }
        log.debug("用户 {} 加入分组 {}", userId, groupId);
    }

    /** Removes a user from a group; the group is dropped once it becomes empty. */
    public void removeFromGroup(String userId, String groupId) {
        groupConnections.computeIfPresent(groupId, (key, members) -> {
            members.remove(userId);
            return members.isEmpty() ? null : members;
        });
        for (ConnectionMetadata meta : metadataOf(userId)) {
            meta.getGroups().remove(groupId);
        }
    }

    /** Sends a keep-alive comment to each connection of a user and closes the ones that fail. */
    public void sendHeartbeat(String userId) {
        Set<String> connectionIds = userConnections.get(userId);
        if (connectionIds == null) {
            return;
        }
        for (String connectionId : List.copyOf(connectionIds)) {
            SseEmitter emitter = emitters.get(connectionId);
            if (emitter == null) {
                continue;
            }
            try {
                emitter.send(SseEmitter.event().comment("keepalive"));
                touch(connectionId);
            } catch (IOException | IllegalStateException e) {
                log.warn("发送心跳失败: userId={}", userId);
                closeConnection(connectionId, "heartbeat-failed", true);
            }
        }
    }

    /**
     * Closes every connection whose last activity is older than the configured timeout.
     *
     * @return number of connections actually closed by this call
     */
    public int cleanupTimeoutConnections() {
        List<String> expired = new ArrayList<>();
        for (ConnectionMetadata meta : metadata.values()) {
            if (meta.isTimeout(connectionTimeout)) {
                expired.add(meta.getConnectionId());
            }
        }
        int removed = 0;
        for (String connectionId : expired) {
            if (closeConnection(connectionId, "idle-timeout", true)) {
                removed++;
            }
        }
        if (removed > 0) {
            log.info("清理超时连接: {} 个", removed);
        }
        return removed;
    }

    /** Closes all connections of a user. */
    public void removeConnection(String userId) {
        Set<String> connectionIds = userConnections.get(userId);
        if (connectionIds == null) {
            return;
        }
        // Snapshot: closeConnection removes ids from this set.
        for (String connectionId : List.copyOf(connectionIds)) {
            closeConnection(connectionId, "removed", true);
        }
    }

    /** Closes every connection managed by this instance. */
    public void closeAll() {
        log.info("关闭所有连接,总数: {}", emitters.size());
        for (String connectionId : List.copyOf(emitters.keySet())) {
            closeConnection(connectionId, "shutdown", true);
        }
    }

    /** Returns a snapshot of connection statistics; see {@link ConnectionStats}. */
    public ConnectionStats getStats() {
        return new ConnectionStats(
                connectionCount.get(),
                userConnections.size(),
                groupConnections.size(),
                totalMessagesSent.get()
        );
    }

    /** Returns the ids of users with at least one connection. */
    public List<String> getOnlineUsers() {
        return new ArrayList<>(userConnections.keySet());
    }

    // ========== private helpers ==========

    /** Atomically adds connectionId to the user's set unless the per-user limit is reached. */
    private boolean reserveUserSlot(String userId, String connectionId) {
        boolean[] reserved = {false};
        userConnections.compute(userId, (key, ids) -> {
            Set<String> current = (ids != null) ? ids : ConcurrentHashMap.newKeySet();
            if (current.size() < maxConnectionsPerUser) {
                current.add(connectionId);
                reserved[0] = true;
            }
            return current.isEmpty() ? null : current;
        });
        return reserved[0];
    }

    /** Sends pre-serialized data to every connection of a user; true if any send succeeded. */
    private boolean deliverToUser(String userId, SseEvent event, String json) {
        Set<String> connectionIds = userConnections.get(userId);
        if (connectionIds == null) {
            return false;
        }
        boolean delivered = false;
        // Snapshot: a failed send removes its id from this set.
        for (String connectionId : List.copyOf(connectionIds)) {
            if (sendToConnection(userId, connectionId, event, json)) {
                delivered = true;
            }
        }
        return delivered;
    }

    /** Sends one event to one connection; a failed send closes that connection only. */
    private boolean sendToConnection(String userId, String connectionId, SseEvent event, String json) {
        SseEmitter emitter = emitters.get(connectionId);
        if (emitter == null) {
            return false;
        }
        // A new builder per send; builders are not shared across emitters.
        SseEmitter.SseEventBuilder builder = SseEmitter.event();
        if (event.getId() != null) {
            builder.id(event.getId());
        }
        if (event.getEvent() != null) {
            builder.name(event.getEvent());
        }
        if (event.getRetry() != null) {
            builder.reconnectTime(event.getRetry());
        }
        builder.data(json);
        try {
            emitter.send(builder);
        } catch (IOException | IllegalStateException e) {
            // IOException: client went away; IllegalStateException: e.g. emitter already completed
            log.error("发送消息失败: userId={}", userId, e);
            messageFailedCounter.increment();
            closeConnection(connectionId, "send-failed", true);
            return false;
        }
        ConnectionMetadata meta = metadata.get(connectionId);
        if (meta != null) {
            meta.updateLastActive();
            meta.incrementMessageCount();
        }
        messageSentCounter.increment();
        totalMessagesSent.incrementAndGet();
        return true;
    }

    /**
     * Serializes the event payload to JSON.
     *
     * <p>A payload that cannot be serialized is a caller error, not a broken connection. It is
     * therefore logged and counted as failed without disconnecting anyone.
     *
     * @return the JSON text, or {@code null} if serialization failed
     */
    private String serializeData(SseEvent event) {
        try {
            return objectMapper.writeValueAsString(event.getData());
        } catch (IOException e) {
            log.error("消息序列化失败: eventId={}", event.getId(), e);
            messageFailedCounter.increment();
            return null;
        }
    }

    /** Sends the "connected" event to a freshly created connection; failures are only logged. */
    private void sendWelcomeMessage(String userId, String connectionId) {
        try {
            SseEvent welcome = SseEvent.builder()
                    .id(UUID.randomUUID().toString())
                    .event("connected")
                    .data(Map.of(
                            "message", "连接成功",
                            "userId", userId,
                            "timestamp", LocalDateTime.now()
                    ))
                    .build();
            String json = serializeData(welcome);
            if (json != null) {
                sendToConnection(userId, connectionId, welcome, json);
            }
        } catch (Exception e) {
            log.warn("发送欢迎消息失败: userId={}", userId);
        }
    }

    /** Emitter completion/timeout callback. */
    private void handleConnectionClose(String userId, String connectionId, String reason,
                                       boolean completeEmitter) {
        log.info("连接关闭: userId={}, reason={}", userId, reason);
        closeConnection(connectionId, reason, completeEmitter);
    }

    /** Emitter error callback; the emitter is not completed again after a container error. */
    private void handleConnectionError(String userId, String connectionId, Throwable throwable) {
        log.error("连接错误: userId={}, connectionId={}", userId, connectionId, throwable);
        closeConnection(connectionId, "error", false);
    }

    /**
     * Unregisters one connection and optionally completes its emitter.
     *
     * <p>Idempotent: only the caller that removes the id from {@code emitters} performs the
     * cleanup and updates the counters.
     *
     * @return {@code true} if this call closed the connection
     */
    private boolean closeConnection(String connectionId, String reason, boolean completeEmitter) {
        SseEmitter emitter = emitters.remove(connectionId);
        if (emitter == null) {
            return false;
        }
        ConnectionMetadata meta = metadata.remove(connectionId);
        String userId = (meta != null) ? meta.getUserId() : null;
        if (userId != null) {
            detachFromUser(userId, connectionId);
        }
        if (completeEmitter) {
            try {
                emitter.complete();
            } catch (Exception e) {
                log.warn("关闭连接异常: userId={}", userId);
            }
        }
        connectionCount.decrementAndGet();
        connectionClosedCounter.increment();
        log.info("移除连接: userId={}, connectionId={}, reason={}", userId, connectionId, reason);
        return true;
    }

    /**
     * Removes a connection id from its user.
     *
     * <p>When it was the user's last connection, the user is also removed from every group,
     * and groups that become empty are dropped.
     */
    private void detachFromUser(String userId, String connectionId) {
        Set<String> remaining = userConnections.computeIfPresent(userId, (key, ids) -> {
            ids.remove(connectionId);
            return ids.isEmpty() ? null : ids;
        });
        if (remaining != null) {
            return;
        }
        for (String groupId : groupConnections.keySet()) {
            groupConnections.computeIfPresent(groupId, (key, members) -> {
                members.remove(userId);
                return members.isEmpty() ? null : members;
            });
        }
    }

    /** Marks a connection as active now. */
    private void touch(String connectionId) {
        ConnectionMetadata meta = metadata.get(connectionId);
        if (meta != null) {
            meta.updateLastActive();
        }
    }

    /** Metadata of all current connections of a user (empty if the user is offline). */
    private List<ConnectionMetadata> metadataOf(String userId) {
        Set<String> connectionIds = userConnections.get(userId);
        if (connectionIds == null) {
            return List.of();
        }
        List<ConnectionMetadata> result = new ArrayList<>(connectionIds.size());
        for (String connectionId : connectionIds) {
            ConnectionMetadata meta = metadata.get(connectionId);
            if (meta != null) {
                result.add(meta);
            }
        }
        return result;
    }

    /**
     * Connection statistics.
     *
     * @param totalConnections  number of live connections
     * @param activeConnections number of distinct online users (users with at least one connection)
     * @param groups            number of non-empty groups
     * @param totalMessagesSent messages sent since startup
     */
    public record ConnectionStats(
            long totalConnections,
            int activeConnections,
            int groups,
            long totalMessagesSent
    ) {}
}
### 4. SSE Controller (SseController.java)
java
package com.enterprise.sse.controller;
import com.enterprise.sse.manager.SseConnectionManager;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Map;
/**
 * SSE connection controller.
 *
 * <p>Exposes endpoints under {@code /api/sse} for opening an event stream, joining and
 * leaving groups, disconnecting, and reading connection statistics. All work is delegated
 * to {@link SseConnectionManager}.
 */
@Slf4j
@RestController
@RequestMapping("/api/sse")
@RequiredArgsConstructor
public class SseController {

    private final SseConnectionManager connectionManager;

    /**
     * Opens an SSE stream.
     *
     * <p>For an authenticated request the stream always belongs to the authenticated user. A
     * non-blank {@code userId} parameter naming someone else is rejected, so a logged-in
     * client cannot subscribe to another user's messages. The parameter is only used as the
     * identity for unauthenticated requests. NOTE(review): if this endpoint allows anonymous
     * access, any client can still pick an arbitrary userId; confirm the security configuration.
     *
     * @param userId         optional user id, used only without authentication
     * @param request        source of the client IP and User-Agent
     * @param authentication current principal, may be {@code null}
     * @return the emitter created by the connection manager
     * @throws IllegalArgumentException if no user id is available or it conflicts with the principal
     */
    @GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter connect(
            @RequestParam(required = false) String userId,
            HttpServletRequest request,
            Authentication authentication) {
        String effectiveUserId = userId;
        if (authentication != null) {
            String principalName = authentication.getName();
            if (userId != null && !userId.isBlank() && !userId.equals(principalName)) {
                throw new IllegalArgumentException("用户ID与当前认证用户不一致");
            }
            effectiveUserId = principalName;
        }
        if (effectiveUserId == null || effectiveUserId.isBlank()) {
            throw new IllegalArgumentException("用户ID不能为空");
        }
        String clientIp = getClientIp(request);
        String userAgent = request.getHeader("User-Agent");
        log.info("收到连接请求: userId={}, ip={}", effectiveUserId, clientIp);
        return connectionManager.createConnection(effectiveUserId, clientIp, userAgent);
    }

    /**
     * Adds the authenticated user to a group.
     *
     * @throws IllegalArgumentException if the request is not authenticated
     */
    @PostMapping("/group/{groupId}/join")
    public Map<String, Object> joinGroup(
            @PathVariable String groupId,
            Authentication authentication) {
        String userId = requireUserId(authentication);
        connectionManager.addToGroup(userId, groupId);
        return Map.of("success", true, "message", "成功加入分组: " + groupId);
    }

    /**
     * Removes the authenticated user from a group.
     *
     * @throws IllegalArgumentException if the request is not authenticated
     */
    @PostMapping("/group/{groupId}/leave")
    public Map<String, Object> leaveGroup(
            @PathVariable String groupId,
            Authentication authentication) {
        String userId = requireUserId(authentication);
        connectionManager.removeFromGroup(userId, groupId);
        return Map.of("success", true, "message", "成功离开分组: " + groupId);
    }

    /**
     * Returns connection statistics.
     * NOTE(review): exposes system-wide data; confirm it is restricted to administrators.
     */
    @GetMapping("/stats")
    public SseConnectionManager.ConnectionStats getStats() {
        return connectionManager.getStats();
    }

    /**
     * Returns the online user ids and their count.
     * NOTE(review): exposes other users' ids; confirm it is restricted to administrators.
     */
    @GetMapping("/online-users")
    public Map<String, Object> getOnlineUsers() {
        // Derive the count from the same snapshot so "users" and "count" always agree.
        var users = connectionManager.getOnlineUsers();
        return Map.of(
                "users", users,
                "count", users.size()
        );
    }

    /**
     * Closes all SSE connections of the authenticated user.
     *
     * @throws IllegalArgumentException if the request is not authenticated
     */
    @DeleteMapping("/disconnect")
    public Map<String, Object> disconnect(Authentication authentication) {
        String userId = requireUserId(authentication);
        connectionManager.removeConnection(userId);
        return Map.of("success", true, "message", "连接已断开");
    }

    /** Returns the authenticated user's name, rejecting unauthenticated requests instead of NPE-ing. */
    private static String requireUserId(Authentication authentication) {
        String name = (authentication != null) ? authentication.getName() : null;
        if (name == null || name.isBlank()) {
            throw new IllegalArgumentException("用户ID不能为空");
        }
        return name;
    }

    /**
     * Determines the client IP address.
     *
     * <p>Uses the first entry of X-Forwarded-For, then X-Real-IP, then the socket address.
     * Empty and "unknown" values are skipped. NOTE(review): forwarding headers are
     * client-controlled and only trustworthy behind a proxy that overwrites them.
     */
    private String getClientIp(HttpServletRequest request) {
        String ip = firstKnownAddress(request.getHeader("X-Forwarded-For"));
        if (ip == null) {
            ip = firstKnownAddress(request.getHeader("X-Real-IP"));
        }
        if (ip == null) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /**
     * Returns the first comma-separated entry of a header value, or {@code null} if it is
     * missing, empty or "unknown". Uses indexOf rather than split: {@code ",".split(",")}
     * yields an empty array, so indexing its first element would throw.
     */
    private static String firstKnownAddress(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        int comma = headerValue.indexOf(',');
        String first = (comma >= 0 ? headerValue.substring(0, comma) : headerValue).strip();
        if (first.isEmpty() || "unknown".equalsIgnoreCase(first)) {
            return null;
        }
        return first;
    }
}
### 5. 消息服务 (SseMessageService.java)
java
package com.enterprise.sse.service;
import com.enterprise.sse.manager.SseConnectionManager;
import com.enterprise.sse.model.SseEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
/**
 * SSE message service.
 *
 * <p>Delivers events to the connections held by this node through
 * {@link SseConnectionManager}. When {@code sse.redis.enabled} is true (the default), it also
 * publishes the whole event as JSON to the Redis channel {@code sse.redis.channel} (default
 * {@code sse:message}). Presumably a subscriber on each node (not shown here) delivers those
 * messages to its own local connections.
 *
 * <p>NOTE(review): events are published even when local delivery succeeded. If this node's
 * own subscriber also delivers them, local users receive every message twice. Confirm that
 * the subscriber skips messages that originated on its own node.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class SseMessageService {
    private final SseConnectionManager connectionManager;
    private final StringRedisTemplate redisTemplate;
    private final ObjectMapper objectMapper;
    private final Executor taskExecutor;

    @Value("${sse.redis.channel:sse:message}")
    private String redisChannel;

    @Value("${sse.redis.enabled:true}")
    private boolean redisEnabled;

    /**
     * Sends an event to one user (unicast).
     *
     * <p>The caller's {@code event} is not modified. A copy with {@code targetUserId} set is
     * sent and published. Previously the shared instance was mutated, so concurrent
     * {@link #sendToUserAsync} calls reusing one event could publish it with another call's
     * target user.
     *
     * @return the result of local delivery via {@link SseConnectionManager#sendToUser};
     *         delivery through the cluster is not reflected
     */
    public boolean sendToUser(String userId, SseEvent event) {
        SseEvent targeted = withTargets(event, userId, event.getTargetGroupId());
        boolean localSuccess = connectionManager.sendToUser(userId, targeted);
        if (redisEnabled) {
            publishToCluster(targeted);
        }
        return localSuccess;
    }

    /** Runs {@link #sendToUser} on the injected executor. */
    public CompletableFuture<Boolean> sendToUserAsync(String userId, SseEvent event) {
        return CompletableFuture.supplyAsync(
                () -> sendToUser(userId, event),
                taskExecutor
        );
    }

    /**
     * Sends an event to a group (multicast). Like {@link #sendToUser}, the caller's event is
     * not modified; a copy with {@code targetGroupId} set is sent and published.
     *
     * @return the local delivery count from {@link SseConnectionManager#sendToGroup}
     */
    public int sendToGroup(String groupId, SseEvent event) {
        SseEvent targeted = withTargets(event, event.getTargetUserId(), groupId);
        int localCount = connectionManager.sendToGroup(groupId, targeted);
        if (redisEnabled) {
            publishToCluster(targeted);
        }
        return localCount;
    }

    /**
     * Broadcasts an event to all users.
     *
     * @return the local delivery count from {@link SseConnectionManager#broadcast}
     */
    public int broadcast(SseEvent event) {
        int localCount = connectionManager.broadcast(event);
        if (redisEnabled) {
            publishToCluster(event);
        }
        return localCount;
    }

    /** Sends a "notification" event with a title and content to one user. */
    public void sendNotification(String userId, String title, String content) {
        SseEvent event = SseEvent.builder()
                .id(generateEventId())
                .event("notification")
                .data(new NotificationData(title, content, System.currentTimeMillis()))
                .targetUserId(userId)
                .build();
        sendToUser(userId, event);
    }

    /** Sends a "progress" event for a task to one user. */
    public void sendProgress(String userId, String taskId, int percentage, String message) {
        SseEvent event = SseEvent.builder()
                .id(generateEventId())
                .event("progress")
                .data(new ProgressData(taskId, percentage, message))
                .targetUserId(userId)
                .build();
        sendToUser(userId, event);
    }

    /**
     * Sends the same event to each user in turn. One Redis publish happens per user when
     * Redis is enabled.
     *
     * @return number of users for which local delivery succeeded
     */
    public int sendBatch(List<String> userIds, SseEvent event) {
        int successCount = 0;
        for (String userId : userIds) {
            if (sendToUser(userId, event)) {
                successCount++;
            }
        }
        log.info("批量发送: 目标={}, 成功={}", userIds.size(), successCount);
        return successCount;
    }

    /** Publishes the event as JSON to the Redis channel; failures are logged, not rethrown. */
    private void publishToCluster(SseEvent event) {
        try {
            String message = objectMapper.writeValueAsString(event);
            redisTemplate.convertAndSend(redisChannel, message);
        } catch (Exception e) {
            log.error("发布到Redis失败", e);
        }
    }

    /** Copies every field of {@code source} and sets the given routing targets. */
    private static SseEvent withTargets(SseEvent source, String userId, String groupId) {
        return SseEvent.builder()
                .id(source.getId())
                .event(source.getEvent())
                .data(source.getData())
                .retry(source.getRetry())
                .timestamp(source.getTimestamp())
                .targetUserId(userId)
                .targetGroupId(groupId)
                .priority(source.getPriority())
                .build();
    }

    /** Builds an id of the form "{epochMillis}-{hex of the bits of a random double}". */
    private String generateEventId() {
        return System.currentTimeMillis() + "-" +
                Long.toHexString(Double.doubleToLongBits(Math.random()));
    }

    /** Payload of a notification event. */
    public record NotificationData(String title, String content, long timestamp) {}

    /** Payload of a task progress event. */
    public record ProgressData(String taskId, int percentage, String message) {}
}
以上为完整代码示例;投入生产环境前,请结合项目的认证配置与集群部署方式(例如 Redis 消息的跨节点去重)进行调整和测试。