1. 背景
在传统的 Java 应用中,ThreadLocal 常用于在同一线程中传递上下文信息(如请求ID、用户信息等)。
然而,随着 Java 虚拟线程(Virtual Thread) 的引入,线程数量可以非常大(成千上万),ThreadLocal 在这种场景下存在几个问题:
-
内存泄漏风险:线程长期存在时,ThreadLocal 变量容易被残留引用占用。
-
上下文传递复杂:虚拟线程切换可能导致 ThreadLocal 值不一致,尤其在使用异步或挂起操作时。
为了解决这个问题,Java 提供了 ScopedValue,用于在虚拟线程中安全、轻量地传递上下文。
2. ScopedValue 特点
-
轻量级:与 ThreadLocal 不同,它不会在每个线程上创建额外的存储空间。
-
线程安全:值是不可变的,只能在创建的作用域内访问。
-
自动传递:在虚拟线程中创建作用域时,内部逻辑可以自动将上下文传递给挂起和恢复操作。
-
适合虚拟线程:与 ThreadLocal 相比,ScopedValue 更适合大量短生命周期线程的场景。
3. 使用方式
1、全局开启使用虚拟线程(yaml配置)
bash
spring:
main:
# 保证 JVM 在全是虚拟线程情况下不会提前退出
keep-alive: true
# 全局虚拟线程开关(推荐方式)
threads:
virtual:
# 启用虚拟线程,覆盖 TaskExecutor、@Async、@Scheduled、Web Server
enabled: true
2、虚拟线程上下文传递参数
java
import lombok.Builder;
/**
* 虚拟线程上下文传递参数
*
* @param traceId 链路ID(分布式微服务传递追踪)
* @param userId 用户ID
* @param tenantId 租户ID
*/
@Builder
public record RequestContext(
String traceId,
String userId,
String tenantId) {
}
3、ScopedValue工具类
java
import lombok.NoArgsConstructor;
/**
* ScopedValue工具类
*/
@NoArgsConstructor
public final class ContextKeys {
// 链路ID
public static final String TRACE_ID = "traceId";
/**
* WEB请求上下文传递
*/
public static final ScopedValue<RequestContext> REQUEST_CONTEXT = ScopedValue.newInstance();
}
4、获取上下文业务参数
java
import lombok.NoArgsConstructor;
import org.slf4j.MDC;
import java.util.Optional;
import java.util.concurrent.Callable;
/**
* 获取上下文业务参数
*/
@NoArgsConstructor
public final class RequestContextHolder {
/**
* 获取完整上下文
*/
public static Optional<RequestContext> getOptional() {
return ContextKeys.REQUEST_CONTEXT.isBound() ? Optional.of(ContextKeys.REQUEST_CONTEXT.get()) : Optional.empty();
}
/**
* 获取 traceId
*/
public static String getTraceId() {
return getOptional().map(RequestContext::traceId).orElse(null);
}
/**
* 获取 userId
*/
public static String getUserId() {
return getOptional().map(RequestContext::userId).orElse(null);
}
/**
* 获取 tenantId
*/
public static String getTenantId() {
return getOptional().map(RequestContext::tenantId).orElse(null);
}
/**
* 绑定上下文并运行 Runnable
*/
public static void with(RequestContext ctx, Runnable task) {
ScopedValue.where(ContextKeys.REQUEST_CONTEXT, ctx).run(() -> {
try {
// MDC桥接
MDC.put(ContextKeys.TRACE_ID, ctx.traceId());
task.run();
} finally {
MDC.clear();
}
});
}
/**
* 绑定上下文并运行 Callable
*/
public static <T> T with(RequestContext ctx, Callable<T> task) throws Exception {
return ScopedValue.where(ContextKeys.REQUEST_CONTEXT, ctx).call(() -> {
try {
// MDC桥接
MDC.put(ContextKeys.TRACE_ID, ctx.traceId());
return task.call();
} finally {
MDC.clear();
}
});
}
}
5、HTTP请求上下文初始化
java
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jetbrains.annotations.NotNull;
import org.slf4j.MDC;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.Optional;
import java.util.UUID;
/**
* HTTP请求上下文初始化
*/
@Component
@Order(Ordered.HIGHEST_PRECEDENCE)
public class ContextInitFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull FilterChain filterChain) throws ServletException, IOException {
RequestContext ctx = buildContext(request);
// ScopedValue 绑定上下文
try {
ScopedValue.where(ContextKeys.REQUEST_CONTEXT, ctx).run(() -> {
try {
// MDC桥接
MDC.put(ContextKeys.TRACE_ID, ctx.traceId());
filterChain.doFilter(request, response);
} catch (IOException | ServletException e) {
throw new RuntimeException(e);
} finally {
MDC.clear();
}
});
} catch (RuntimeException e) {
// 拆包,保持Servlet语义
if (e.getCause() instanceof IOException io) throw io;
if (e.getCause() instanceof ServletException se) throw se;
throw e;
}
}
/**
* 构建请求上下文
*/
private RequestContext buildContext(HttpServletRequest request) {
// 链路ID
String traceId = Optional.ofNullable(request.getHeader("X-Trace-Id")).orElse(UUID.randomUUID().toString());
// 用户ID
String userId = request.getHeader("X-User-Id");
// 租户ID
String tenantId = request.getHeader("X-Tenant-Id");
return RequestContext.builder().traceId(traceId).userId(userId).tenantId(tenantId).build();
}
}
6、ScopedValue和StructuredTaskScope使用方式
java
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@RestController
public class DemoController {
@GetMapping("/test")
public String test() {
// 获取参数
String tenantId = RequestContextHolder.getTenantId();
String userId = RequestContextHolder.getUserId();
// 新虚拟线程执行
RequestContext ctx = RequestContextHolder.getOptional().orElseThrow();
ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
executor.submit(() -> {
RequestContextHolder.with(ctx, () -> {
// 异步任务中依然可以获取 traceId / userId
System.out.println("traceId=" + RequestContextHolder.getTraceId());
});
});
return "tenantId=" + tenantId + ", userId=" + userId;
}
/**
* StructuredTaskScope实现:同步写法 + 并发执行 + 自动失败传播
* 非常类似WebFlux/Reactor的:Mono.zip(callA(), callB()).map(tuple -> combine(tuple.getT1(), tuple.getT2()));
*/
@Transactional
public void service() {
try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
var a = scope.fork(this::taskA);
var b = scope.fork(this::taskB);
// 等待所有
scope.join();
// 有失败就抛
scope.throwIfFailed();
return combine(a.get(), b.get());
}
}
}