超好用的线程传递组件!从源码来学习TransmittableThreadLocal

1.初识TransmittableThreadLocal

这个组件是用来干嘛的?

我们可能一些场景,主线程设置一些参数,想要异步提高一个任务给线程池,这个参数传递我们一般会怎么做?

  • 肯定有同学想到了:可以对任务使用构造函数,往任务里面加参数就行了!
  • 但是如何参数很多呢?

上样例!

java 复制代码
// 假设任务需要多个参数
class ComplexTask implements Runnable {
    private final String param1;
    private final int param2;
    private final boolean param3;
    // 可能还有更多参数

    public ComplexTask(String param1, int param2, boolean param3) {
        this.param1 = param1;
        this.param2 = param2;
        this.param3 = param3;
    }

    @Override
    public void run() {
        // 任务逻辑
        System.out.println("hello");
    }
}

显而易见代码的维护性很差侵入性很高,这不是我们想要的

TransmittableThreadLocal就是来解决这个问题的

利用TransmittableThreadLocal重新维护后的代码

java 复制代码
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TtlExample {
    private static final TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();

    public static void main(String[] args) {
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        ttl.set("parameter");
        Runnable task = () -> {
            String value = ttl.get();
            // 任务逻辑使用 value
        };
        Runnable ttlTask = TtlRunnable.get(task);
        executorService.submit(ttlTask);
        executorService.shutdown();
    }
}

使用 TransmittableThreadLocal 可以避免在构造函数中传递大量参数,减少代码的侵入性。参数的设置和获取都在 TransmittableThreadLocal 中完成,代码结构更加清晰,便于维护和扩展。

2.TransmittableThreadLocal原理解析

2.1 继承关系

TransmittableThreadLocal 继承自 InheritableThreadLocal,而 InheritableThreadLocal 是 Java 标准库中 ThreadLocal 的子类,它允许子线程继承父线程中设置的 ThreadLocal 值。TransmittableThreadLocal 在此基础上进行了扩展,以支持在更复杂的线程池场景下的数据传递。

2.2 底层结构

  • TransmittableThreadLocal:作为核心类,负责存储和管理线程局部变量的值,同时提供了数据捕获和恢复的方法。
  • TtlRunnableTtlCallable :用于包装普通的 RunnableCallable 任务,在任务执行前后进行数据的捕获、传递和恢复操作。
  • TransmittableThreadLocal.Transmitter:封装了数据的捕获、传递和恢复的核心逻辑。

2.3 核心流程

1. 数据捕获

当提交一个任务到线程池时,TransmittableThreadLocal 会在提交任务的线程中捕获当前所有 TransmittableThreadLocal 实例的值。具体实现是通过 TransmittableThreadLocal.Transmitter.capture() 方法完成的,该方法会遍历所有注册的 TransmittableThreadLocal 实例,并将它们的值存储在一个 Map 中。

ini 复制代码
// 捕获当前线程中所有 TransmittableThreadLocal 的值
Map<TransmittableThreadLocal<?>, Object> captured = TransmittableThreadLocal.Transmitter.capture();

2. 任务包装

为了确保在子线程中能够正确恢复捕获的数据,TransmittableThreadLocal 提供了 TtlRunnableTtlCallable 类来包装普通的 RunnableCallable 任务。这些包装类会在任务执行前后进行数据的恢复和清理操作。

java 复制代码
// 包装普通的 Runnable 任务
Runnable task = () -> {
  // 任务逻辑
};
Runnable ttlTask = TtlRunnable.get(task);

3. 数据传递

当子线程开始执行包装后的任务时,TtlRunnableTtlCallable 会调用 TransmittableThreadLocal.Transmitter.replay(captured) 方法,将捕获的数据恢复到子线程的 TransmittableThreadLocal 中。在恢复数据之前,会先记录子线程中原来的 TransmittableThreadLocal 值,以便后续清理。

java 复制代码
// 恢复捕获的数据到子线程
Object backup = TransmittableThreadLocal.Transmitter.replay(captured);

4. 任务执行

在数据恢复完成后,子线程会执行包装后的任务逻辑。

5. 数据清理

任务执行完成后,TtlRunnableTtlCallable 会调用 TransmittableThreadLocal.Transmitter.restore(backup) 方法,将子线程的 TransmittableThreadLocal 恢复到执行任务之前的状态,避免数据污染。

java 复制代码
// 恢复子线程的 TransmittableThreadLocal 到原始状态
TransmittableThreadLocal.Transmitter.restore(backup);

2.4 时序图

这是底层逻辑的演变时序图,下面我将从源码角度剖析中间设计的精妙之处

3. 源码解析

3.1 Transmittercapture方法解析

java 复制代码
/**
 * Capture all {@link Transmittable}.
 *
 * @return the captured values
 */
@NonNull
public Capture capture() {
    final HashMap<Transmittable<Object, Object>, Object> transmit2Value = newHashMap(registeredTransmittableSet.size());
    for (Transmittable<Object, Object> transmittable : registeredTransmittableSet) {
        try {
            transmit2Value.put(transmittable, transmittable.capture());
        } catch (Throwable t) {
            propagateIfFatal(t);
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when capture for transmittable " + transmittable +
                        "(class " + transmittable.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }
    return new Snapshot(transmit2Value, null);
}

方法比较简单,只是捕获所有已注册的 Transmittable 对象的当前状态,并将这些状态存储在一个 Snapshot 对象中返回。这个过程会变历 registeredTransmittableSet 集合中的每个 Transmittable 对象,调用其 capture 方法获取捕获的数据,并将其存储在一个 HashMap 中。

3.2 TtlRunnableget方法解析

java 复制代码
/**
 * Factory method, wrap input {@link Runnable} to {@link TtlRunnable}.
 *
 * @param runnable                         input {@link Runnable}. if input is {@code null}, return {@code null}.
 * @param releaseTtlValueReferenceAfterRun release TTL value reference after run, avoid memory leak even if {@link TtlRunnable} is referred.
 * @param idempotent                       is idempotent mode or not. if {@code true}, just return input {@link Runnable} when it's {@link TtlRunnable},
 *                                         otherwise throw {@link IllegalStateException}.
 *                                         <B><I>Caution</I></B>: {@code true} will cover up bugs! <b>DO NOT</b> set, only when you know why.
 * @return Wrapped {@link Runnable}
 * @throws IllegalStateException when input is {@link TtlRunnable} already and not idempotent.
 */
@Nullable
@Contract(value = "null, _, _ -> null; !null, _, _ -> !null", pure = true)
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (runnable == null) return null;

    if (runnable instanceof TtlEnhanced) {
        // avoid redundant decoration, and ensure idempotency
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

这个方法用来封装简单的任务对象而返回一个TtlRunnable任务对象,如果掉用默认的get方法后置的两个参数都是false,返回的是构造函数封装的对象,releaseTtlValueReferenceAfterRun作为成员变量封装在TtlRunnable中,它的作用是控制在 Runnable任务执行完毕后,是否释放 TransmittableThreadLocal值的引用。

  • 避免内存泄漏:TransmittableThreadLocal 用于在线程间传递上下文信息,当 TtlRunnable 实例被长期引用时,如果不释放 TransmittableThreadLocal 值的引用,可能会导致内存泄漏。通过将 releaseTtlValueReferenceAfterRun 设置为 true,可以在任务执行完成后及时释放这些引用,避免内存泄漏。

3.3 Transmitterreplay方法解析

less 复制代码
/**
 * Replay the captured values from {@link #capture()},
 * and return the backup values before replay.
 *
 * @param captured captured values {@link #capture()}
 * @return the backup values before replay
 * @see #capture()
 */
@NonNull
public Backup replay(@NonNull Capture captured) {
    final Object data = callback.beforeReplay();

    final Snapshot capturedSnapshot = (Snapshot) captured;
    final HashMap<Transmittable<Object, Object>, Object> transmit2Value = newHashMap(capturedSnapshot.transmit2Value.size());
    for (Map.Entry<Transmittable<Object, Object>, Object> entry : capturedSnapshot.transmit2Value.entrySet()) {
        Transmittable<Object, Object> transmittable = entry.getKey();
        try {
            Object transmitCaptured = entry.getValue();
            transmit2Value.put(transmittable, transmittable.replay(transmitCaptured));
        } catch (Throwable t) {
            propagateIfFatal(t);
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when replay for transmittable " + transmittable +
                        "(class " + transmittable.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }

    final Object afterData = callback.afterReplay(data);
    return new Snapshot(transmit2Value, afterData);
}

其实和get方法很相似,但是包含了一些数据处理操作,callback对象在捕获之前做数据的预处理,遍历结束后会再做一次后置处理,作用是在replay之前把数据记录下来,哪些是子线程的数据,后续会对这些冗余的数据批处理清除,防止内存溢出。

3.4 Transmitterrestore方法解析

typescript 复制代码
/**
 * Restore the backup values from {@link #replay(Capture)}/{@link #clear()}.
 *
 * @param backup the backup values from {@link #replay(Capture)}/{@link #clear()}
 * @see #replay(Capture)
 * @see #clear()
 */
public void restore(@NonNull Backup backup) {
    final Snapshot snapshot = (Snapshot) backup;
    final Object data = callback.beforeRestore(snapshot.data);

    for (Map.Entry<Transmittable<Object, Object>, Object> entry : snapshot.transmit2Value.entrySet()) {
        Transmittable<Object, Object> transmittable = entry.getKey();
        try {
            Object transmitBackup = entry.getValue();
            transmittable.restore(transmitBackup);
        } catch (Throwable t) {
            propagateIfFatal(t);
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when restore for transmittable " + transmittable +
                        "(class " + transmittable.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }

    callback.afterRestore(data);
}

replay方法备份的数据在这里就派上用场了,对transmittable对象遍历过程会调用自身的restore把备份数据恢复到上下文中,将子线程TransmittableThreadLocal恢复到执行任务之前的状态,避免数据污染

3.5 TtlRunnablerun 方法

java 复制代码
/**
 * wrap method {@link Runnable#run()}.
 */
@Override
public void run() {
    final Capture captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    final Backup backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}

get方法中设置的默认false参数在这里就会有影响,如果设置releaseTtlValueReferenceAfterRuntrue的话会尝试把capturedRef引用设置成null,对引用进行清除操作避免长期引用占据导致的内存泄露

相关推荐
uhakadotcom34 分钟前
React Query:简化数据获取和状态管理的利器
后端·面试·github
小菜不菜。44 分钟前
本地部署 DeepSeek:从 Ollama 配置到 Spring Boot 集成
java·spring boot·后端
007php0071 小时前
企微审批中MySQL字段TEXT类型被截断的排查与修复实践
大数据·开发语言·数据库·后端·mysql·重构·golang
倚栏听风雨1 小时前
类型引起的索引失效-mybatis TypeHandler 完美解决
后端
猿毕设1 小时前
【FL0090】基于SSM和微信小程序的球馆预约系统
java·spring boot·后端·python·微信小程序·小程序
山间点烟雨2 小时前
3. 前后端实现压缩包文件下载
前端·后端·压缩包
m0_748238632 小时前
【SpringBoot3】Spring Boot 3.0 集成 Mybatis Plus
spring boot·后端·mybatis
aircrushin2 小时前
【PromptCoder + Cursor】利用AI智能编辑器快速实现设计稿
前端·后端·html
m0_748234343 小时前
搭建Golang gRPC环境:protoc、protoc-gen-go 和 protoc-gen-go-grpc 工具安装教程
开发语言·后端·golang
Oo_Amy_oO3 小时前
使用SDKMAN!安装springboot
spring boot·后端·sdkman