目录

超好用的线程传递组件!从源码来学习TransmittableThreadLocal

1.初识TransmittableThreadLocal

这个组件是用来干嘛的?

我们可能一些场景,主线程设置一些参数,想要异步提高一个任务给线程池,这个参数传递我们一般会怎么做?

  • 肯定有同学想到了:可以对任务使用构造函数,往任务里面加参数就行了!
  • 但是如何参数很多呢?

上样例!

java 复制代码
// 假设任务需要多个参数
class ComplexTask implements Runnable {
    private final String param1;
    private final int param2;
    private final boolean param3;
    // 可能还有更多参数

    public ComplexTask(String param1, int param2, boolean param3) {
        this.param1 = param1;
        this.param2 = param2;
        this.param3 = param3;
    }

    @Override
    public void run() {
        // 任务逻辑
        System.out.println("hello");
    }
}

显而易见代码的维护性很差侵入性很高,这不是我们想要的

TransmittableThreadLocal就是来解决这个问题的

利用TransmittableThreadLocal重新维护后的代码

java 复制代码
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TtlExample {
    private static final TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();

    public static void main(String[] args) {
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        ttl.set("parameter");
        Runnable task = () -> {
            String value = ttl.get();
            // 任务逻辑使用 value
        };
        Runnable ttlTask = TtlRunnable.get(task);
        executorService.submit(ttlTask);
        executorService.shutdown();
    }
}

使用 TransmittableThreadLocal 可以避免在构造函数中传递大量参数,减少代码的侵入性。参数的设置和获取都在 TransmittableThreadLocal 中完成,代码结构更加清晰,便于维护和扩展。

2.TransmittableThreadLocal原理解析

2.1 继承关系

TransmittableThreadLocal 继承自 InheritableThreadLocal,而 InheritableThreadLocal 是 Java 标准库中 ThreadLocal 的子类,它允许子线程继承父线程中设置的 ThreadLocal 值。TransmittableThreadLocal 在此基础上进行了扩展,以支持在更复杂的线程池场景下的数据传递。

2.2 底层结构

  • TransmittableThreadLocal:作为核心类,负责存储和管理线程局部变量的值,同时提供了数据捕获和恢复的方法。
  • TtlRunnableTtlCallable :用于包装普通的 RunnableCallable 任务,在任务执行前后进行数据的捕获、传递和恢复操作。
  • TransmittableThreadLocal.Transmitter:封装了数据的捕获、传递和恢复的核心逻辑。

2.3 核心流程

1. 数据捕获

当提交一个任务到线程池时,TransmittableThreadLocal 会在提交任务的线程中捕获当前所有 TransmittableThreadLocal 实例的值。具体实现是通过 TransmittableThreadLocal.Transmitter.capture() 方法完成的,该方法会遍历所有注册的 TransmittableThreadLocal 实例,并将它们的值存储在一个 Map 中。

ini 复制代码
// 捕获当前线程中所有 TransmittableThreadLocal 的值
Map<TransmittableThreadLocal<?>, Object> captured = TransmittableThreadLocal.Transmitter.capture();

2. 任务包装

为了确保在子线程中能够正确恢复捕获的数据,TransmittableThreadLocal 提供了 TtlRunnableTtlCallable 类来包装普通的 RunnableCallable 任务。这些包装类会在任务执行前后进行数据的恢复和清理操作。

java 复制代码
// 包装普通的 Runnable 任务
Runnable task = () -> {
  // 任务逻辑
};
Runnable ttlTask = TtlRunnable.get(task);

3. 数据传递

当子线程开始执行包装后的任务时,TtlRunnableTtlCallable 会调用 TransmittableThreadLocal.Transmitter.replay(captured) 方法,将捕获的数据恢复到子线程的 TransmittableThreadLocal 中。在恢复数据之前,会先记录子线程中原来的 TransmittableThreadLocal 值,以便后续清理。

java 复制代码
// 恢复捕获的数据到子线程
Object backup = TransmittableThreadLocal.Transmitter.replay(captured);

4. 任务执行

在数据恢复完成后,子线程会执行包装后的任务逻辑。

5. 数据清理

任务执行完成后,TtlRunnableTtlCallable 会调用 TransmittableThreadLocal.Transmitter.restore(backup) 方法,将子线程的 TransmittableThreadLocal 恢复到执行任务之前的状态,避免数据污染。

java 复制代码
// 恢复子线程的 TransmittableThreadLocal 到原始状态
TransmittableThreadLocal.Transmitter.restore(backup);

2.4 时序图

这是底层逻辑的演变时序图,下面我将从源码角度剖析中间设计的精妙之处

3. 源码解析

3.1 Transmittercapture方法解析

java 复制代码
/**
 * Capture all {@link Transmittable}.
 *
 * @return the captured values
 */
@NonNull
public Capture capture() {
    final HashMap<Transmittable<Object, Object>, Object> transmit2Value = newHashMap(registeredTransmittableSet.size());
    for (Transmittable<Object, Object> transmittable : registeredTransmittableSet) {
        try {
            transmit2Value.put(transmittable, transmittable.capture());
        } catch (Throwable t) {
            propagateIfFatal(t);
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when capture for transmittable " + transmittable +
                        "(class " + transmittable.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }
    return new Snapshot(transmit2Value, null);
}

方法比较简单,只是捕获所有已注册的 Transmittable 对象的当前状态,并将这些状态存储在一个 Snapshot 对象中返回。这个过程会变历 registeredTransmittableSet 集合中的每个 Transmittable 对象,调用其 capture 方法获取捕获的数据,并将其存储在一个 HashMap 中。

3.2 TtlRunnableget方法解析

java 复制代码
/**
 * Factory method, wrap input {@link Runnable} to {@link TtlRunnable}.
 *
 * @param runnable                         input {@link Runnable}. if input is {@code null}, return {@code null}.
 * @param releaseTtlValueReferenceAfterRun release TTL value reference after run, avoid memory leak even if {@link TtlRunnable} is referred.
 * @param idempotent                       is idempotent mode or not. if {@code true}, just return input {@link Runnable} when it's {@link TtlRunnable},
 *                                         otherwise throw {@link IllegalStateException}.
 *                                         <B><I>Caution</I></B>: {@code true} will cover up bugs! <b>DO NOT</b> set, only when you know why.
 * @return Wrapped {@link Runnable}
 * @throws IllegalStateException when input is {@link TtlRunnable} already and not idempotent.
 */
@Nullable
@Contract(value = "null, _, _ -> null; !null, _, _ -> !null", pure = true)
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (runnable == null) return null;

    if (runnable instanceof TtlEnhanced) {
        // avoid redundant decoration, and ensure idempotency
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

这个方法用来封装简单的任务对象而返回一个TtlRunnable任务对象,如果掉用默认的get方法后置的两个参数都是false,返回的是构造函数封装的对象,releaseTtlValueReferenceAfterRun作为成员变量封装在TtlRunnable中,它的作用是控制在 Runnable任务执行完毕后,是否释放 TransmittableThreadLocal值的引用。

  • 避免内存泄漏:TransmittableThreadLocal 用于在线程间传递上下文信息,当 TtlRunnable 实例被长期引用时,如果不释放 TransmittableThreadLocal 值的引用,可能会导致内存泄漏。通过将 releaseTtlValueReferenceAfterRun 设置为 true,可以在任务执行完成后及时释放这些引用,避免内存泄漏。

3.3 Transmitterreplay方法解析

less 复制代码
/**
 * Replay the captured values from {@link #capture()},
 * and return the backup values before replay.
 *
 * @param captured captured values {@link #capture()}
 * @return the backup values before replay
 * @see #capture()
 */
@NonNull
public Backup replay(@NonNull Capture captured) {
    final Object data = callback.beforeReplay();

    final Snapshot capturedSnapshot = (Snapshot) captured;
    final HashMap<Transmittable<Object, Object>, Object> transmit2Value = newHashMap(capturedSnapshot.transmit2Value.size());
    for (Map.Entry<Transmittable<Object, Object>, Object> entry : capturedSnapshot.transmit2Value.entrySet()) {
        Transmittable<Object, Object> transmittable = entry.getKey();
        try {
            Object transmitCaptured = entry.getValue();
            transmit2Value.put(transmittable, transmittable.replay(transmitCaptured));
        } catch (Throwable t) {
            propagateIfFatal(t);
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when replay for transmittable " + transmittable +
                        "(class " + transmittable.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }

    final Object afterData = callback.afterReplay(data);
    return new Snapshot(transmit2Value, afterData);
}

其实和get方法很相似,但是包含了一些数据处理操作,callback对象在捕获之前做数据的预处理,遍历结束后会再做一次后置处理,作用是在replay之前把数据记录下来,哪些是子线程的数据,后续会对这些冗余的数据批处理清除,防止内存溢出。

3.4 Transmitterrestore方法解析

typescript 复制代码
/**
 * Restore the backup values from {@link #replay(Capture)}/{@link #clear()}.
 *
 * @param backup the backup values from {@link #replay(Capture)}/{@link #clear()}
 * @see #replay(Capture)
 * @see #clear()
 */
public void restore(@NonNull Backup backup) {
    final Snapshot snapshot = (Snapshot) backup;
    final Object data = callback.beforeRestore(snapshot.data);

    for (Map.Entry<Transmittable<Object, Object>, Object> entry : snapshot.transmit2Value.entrySet()) {
        Transmittable<Object, Object> transmittable = entry.getKey();
        try {
            Object transmitBackup = entry.getValue();
            transmittable.restore(transmitBackup);
        } catch (Throwable t) {
            propagateIfFatal(t);
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when restore for transmittable " + transmittable +
                        "(class " + transmittable.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }

    callback.afterRestore(data);
}

replay方法备份的数据在这里就派上用场了,对transmittable对象遍历过程会调用自身的restore把备份数据恢复到上下文中,将子线程TransmittableThreadLocal恢复到执行任务之前的状态,避免数据污染

3.5 TtlRunnablerun 方法

java 复制代码
/**
 * wrap method {@link Runnable#run()}.
 */
@Override
public void run() {
    final Capture captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    final Backup backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}

get方法中设置的默认false参数在这里就会有影响,如果设置releaseTtlValueReferenceAfterRuntrue的话会尝试把capturedRef引用设置成null,对引用进行清除操作避免长期引用占据导致的内存泄露

本文是转载文章,点击查看原文
如有侵权,请联系 xyy@jishuzhan.net 删除
相关推荐
why1518 小时前
腾讯(QQ浏览器)后端开发
开发语言·后端·golang
浪裡遊8 小时前
跨域问题(Cross-Origin Problem)
linux·前端·vue.js·后端·https·sprint
声声codeGrandMaster8 小时前
django之优化分页功能(利用参数共存及封装来实现)
数据库·后端·python·django
呼Lu噜8 小时前
WPF-遵循MVVM框架创建图表的显示【保姆级】
前端·后端·wpf
bing_1588 小时前
为什么选择 Spring Boot? 它是如何简化单个微服务的创建、配置和部署的?
spring boot·后端·微服务
学c真好玩9 小时前
Django创建的应用目录详细解释以及如何操作数据库自动创建表
后端·python·django
Asthenia04129 小时前
GenericObjectPool——重用你的对象
后端
Piper蛋窝9 小时前
Go 1.18 相比 Go 1.17 有哪些值得注意的改动?
后端
excel9 小时前
招幕技术人员
前端·javascript·后端
盖世英雄酱581369 小时前
什么是MCP
后端·程序员