小白也能看懂!怎样从子线程获取到父线程的ThreadLocal值

名词

下面说的父线程并非Thread的parent,而是调用Thread.start()的线程,或者调用线程池execute的线程。

背景

1.Threadlocal常用于处理当前线程的变量存储,但是子线程无法直接用父线程的存储内容。

2.jdk提供了父子线程Threadlocal变量的获取,即InheritableThreadLocal

3.在线程池中的线程,使用InheritableThreadLocal只能继承(在创建线程时的父线程)的Threadlocal变量,在后续复用线程时,并不能获取到执行线程时的当前父线程Threadlocal变量。

在一些业务场景下,使用线程池级别的Threadlocal是必要的,比如全链路的分析,那么该怎么解决问题3呢。

Threadlocal原理分析

基本用法

直接上代码,Threadlocal的基本用法

java 复制代码
public class ThreadLocalTest {
    static ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) {
        threadLocal.set("hello");
        // 调用方法执行业务逻辑,不传参数来获取hello
        func();
    }

    public static void func() {
        System.out.println(threadLocal.get());
    }
}

可以看到与Threadlocal有关的代码只有构造器,set,get方法,那么就从这里入手。

构造器

构造器没什么说的,简简单单,仅仅执行创建实例,没有其他操作。

set方法

先看set方法

逻辑简单

1.获取当前线程

2.从getMap方法获取一个对应的ThreadLocalMap

3.map存在即赋值value;不存在即通过createMap创建一个map,再赋值value

ThreadLocalMap

接下来要看一下这个Map到底是什么。

可以看到是一个ThreadLocalMap,并且是在thread.threadLocals中存储的。

ThreadLocalMap不多说,是一个弱引用的Map结构,储存一个ThreadLocal实例与value的映射关系。

那么下一个目标就是thread.threadLocals

这就是ThreadLocal在Thread中的样子,本身并不储存value,是通过ThreadLocalMap储存的。 ThreadLocal仅仅是提供一个hashCode作为Map的key

由于我们可以定义多个ThreadLocal实例代表不同的变量,来在线程中作为(不同的业务需要)来使用,所以使用ThreadLocalMap进行多个ThreadLocal实例的存储。

get方法

知道了ThreadLocal的存储结构,接下来看见怎么读取。

与set类似

1.获取当前线程

2.获取ThreadLocalMap

3.从Map中获取value,是用当前ThreadLocal实例作为key。

4.如果不存在,设置初始值。这里不分析。

那么ThreadLocal的存储原理就很明朗了:每一个线程维护一个map,存储ThreadLocal的映射关系。无论有多少层的方法调用,都可以直接读取当前线程的threadlocals来获取目标值。

InheritableThreadLocal 原理

基本用法

java 复制代码
public class ThreadLocalTest {
    static ThreadLocal<String> threadLocal = new ThreadLocal<>();
    static ThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        threadLocal.set("hello");
        inheritableThreadLocal.set("world");
        // 调用方法执行业务逻辑,不传参数来获取hello
        func();
    }

    public static void func() {
        new Thread(() -> {
            System.out.println("子线程threadLocal:" + threadLocal.get());
            System.out.println("子线程inheritableThreadLocal:" + inheritableThreadLocal.get());
        }).start();
        System.out.println("main: " +threadLocal.get());
    }
}

结果: 子线程threadLocal:null

子线程inheritableThreadLocal:world

main: hello

InheritableThreadLocal很简单,仅仅重写了几个方法,只是赋值给了Thread.inheritableThreadLocals

通过追踪,发现了赋值的两个位置:

1.图里的createMap方法。

2.Thread的构造器中:

java 复制代码
private Thread(ThreadGroup g, Runnable target, String name,
               long stackSize, AccessControlContext acc,
               boolean inheritThreadLocals) {
 //   ...
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    /* Stash the specified stack size in case the VM cares */
    this.stackSize = stackSize;

    /* Set thread ID */
    this.tid = nextThreadID();
}

可以发现子线程是把父线程的这个属性copy一份赋值给自己。当然如果父线程没有进行set操作,那么就不会有inheritableThreadLocals,子线程也不用复制了。

其他过程与ThreadLocal一样,核心就是不再从thread.threadLocals读取,而是从父线程的拷贝inheritableThreadLocals读取,不多说。

线程池中的子线程继承父线程ThreadLocal方案

java 复制代码
public class ThreadLocalTest {
    static ThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

    static ExecutorService threadPool = Executors.newFixedThreadPool(2);

    public static void main(String[] args) throws InterruptedException {
        inheritableThreadLocal.set("hello");
        // 初始化线程池的线程,从Thread构造方法中进行copy
        for (int i = 0; i < 2; i++) {
            threadPool.execute(() -> {
                System.out.println(Thread.currentThread().getName() + ":" + inheritableThreadLocal.get());
            });
        }
        System.out.println(Thread.currentThread().getName() + ":" + inheritableThreadLocal.get());

        Thread.sleep(500);
        inheritableThreadLocal.set("world");
        // 复用线程池,不再拷贝,无法获取新值 question1
        for (int i = 0; i < 2; i++) {
            threadPool.execute(() -> {
                System.out.println(Thread.currentThread().getName() + ":" + inheritableThreadLocal.get());
            });
        }
        System.out.println(Thread.currentThread().getName() + ":" + inheritableThreadLocal.get());

    }
}

可以看到主线程第二次调用set("world")后,子线程并没有获取到。

当然我们可以这样做:

java 复制代码
inheritableThreadLocal.set("world");
String parentValue = inheritableThreadLocal.get();
// 复用线程池,不再拷贝,无法获取新值
for (int i = 0; i < 2; i++) {
    threadPool.execute(() -> {
        String oldValue = inheritableThreadLocal.get();
        try {
            inheritableThreadLocal.set(parentValue);
            System.out.println(Thread.currentThread().getName() + ":" + inheritableThreadLocal.get());
        } finally {
            // restore:一些场景下线程被托管后,要恢复原来的值继续运行
            inheritableThreadLocal.set(oldValue);
        }
    });
}

这样做有点麻烦,当ThreadLocal变量多起来时,维护就成了一个问题。而且有时候忘记这个操作,出现问题也不好排查。

有什么办法来把这个东西简化掉呢?

搜索question1,回头看这里

// 复用线程池,不再拷贝,无法获取新值 question1

重写Thread.start()

是不是在每一次线程运行(不是初始化)的时候,再进行一次copy就可以呢?

那就是把start的方法重写掉,讲道理是可以实现的。

但是这样做的话,也会有问题:

  1. 用到Thread的地方要改成重写后的Thread,在依赖库中我们没办法把构造Thread的代码全都替换成自己的,这样使用框架时就很麻烦。

2.Thread的parent属性,是在构造时赋值的Thread parent = currentThread();,想要在start中获取线程池提交运行时的父线程就有点难。

那么问题1是最致命的,无法做到完美的适配。

不妨把时机再下沉一下,通过threadLocal读写的时候再拷贝。

重写ThreadLocal

讲道理这样就可以解决上面的问题1,因为ThreadLocal的定义与存储是我们可控的,不受框架限制。

按照这个两个思路,我们需要在Thread中增加一个ThreadLocalMap类型(暂时叫pooledThreadLocals吧)(或者修改inheritableThreadLocals)的属性,用来在子线程中接收父线程的ThreadLocalMap拷贝。

先别急!!!Thread肯定是不能改了,ThreadLocal也只暴露了3个方法(get, set, remove),想重写createMap都不容易。

既然没办法直接增改属性,我们可以使用神器ThreadLocal

反正都是根据Thread来操作一个变量,存入Thread中与ThreadLocal变量中都一样。

现在需要一个全局的ThreadLocal来存ThreadLocalMap,用来为线程池中子线程存储父线程的map拷贝。

不过ThreadLocalMap这东西可不是随便用的,权限限制了。退一步,我们用WeakHashMap代替它。

java 复制代码
public static class PooledThreadLocal extends ThreadLocal<String> {
    static ThreadLocal<WeakHashMap<PooledThreadLocal, String>> holder = new ThreadLocal<>();
    @Override
    public String get() {
        return super.get();
    }

    @Override
    public void set(String value) {
        super.set(value);
    }
}

看到没,使用ThreadLocal来代替Thread的pooledThreadLocals属性,存的结构也类似,都是Map<ThreadLocal,Object>,这就离目标只有一步了:怎么copy。

老样子,先处理set方法。仿照ThreadLocal的set方法

ThreadLocal.java:

java 复制代码
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}
java 复制代码
@Override
public void set(String value) {
    // 获取当前线程的pooledThreadLocals变量
    WeakHashMap<PooledThreadLocal, String> pooledThreadLocals = holder.get();
    // 没有的话创建
    if (pooledThreadLocals == null) {
        pooledThreadLocals = new WeakHashMap<>();
        holder.set(pooledThreadLocals);
    }
    // set新value
    pooledThreadLocals.put(this, value);
}

看到没,一模一样的流程。再来处理get方法

kotlin 复制代码
@Override
public String get() {
    // 获取当前线程的pooledThreadLocals变量
    WeakHashMap<PooledThreadLocal, String> pooledThreadLocals = holder.get();
    // 没有的话创建
    if (pooledThreadLocals == null) {
        pooledThreadLocals = new WeakHashMap<>();
        holder.set(pooledThreadLocals);
    }
    return pooledThreadLocals.get(this);
}

最后,为了让子线程访问到父线程的pooledThreadLocals,需要使用InheritableThreadLocal

再测试一下

java 复制代码
static PooledThreadLocal threadLocal = new PooledThreadLocal();

public static void main(String[] args) throws InterruptedException {
    threadLocal.set("hello");
    // 初始化线程池的线程,从Thread构造方法中进行copy
    System.out.println("线程池新建线程,父线程main: hello");
    for (int i = 0; i < 2; i++) {
        threadPool.execute(() -> {
            System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
        });
    }
    System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());

    Thread.sleep(500);
    threadLocal.set("world");
    System.out.println("复用线程池重新设置world");
    for (int i = 0; i < 2; i++) {
        threadPool.execute(() -> {
            System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
        });
    }
    System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());

    Thread.sleep(500);
    System.out.println("启动新线程代替main线程");
    new Thread(() -> {
        threadLocal.set("welcome");
        System.out.println("复用线程池,使用新的父线程 welcome");
        Thread threadNew = Thread.currentThread();
        for (int i = 0; i < 4; i++) {
            threadPool.submit(() -> {
                System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
            });
        }
    }).start();


    System.out.println("main: " + threadLocal.get());
}

可以基本满足需要了(子线程获取父线程最新数据)。不过要注意,目前的holder.get()获取的都是同一个WeekHashMap,即父子共享了,子线程如果set值,也会影响父线程或其他子线程。

机智的同学可以发现,这种共享会出现污染问题(那不就是一个全局变量就能解决吗,用这个解决了啥)。

为了每一个线程独自维护一个WeekHashMap,我们这样做:

java 复制代码
static InheritableThreadLocal<WeakHashMap<PooledThreadLocal, String>> holder = new InheritableThreadLocal<>() {
    @Override
    protected WeakHashMap<PooledThreadLocal, String> initialValue() {
        return new WeakHashMap<>();
    }

    @Override
    protected WeakHashMap<PooledThreadLocal, String> childValue(WeakHashMap<PooledThreadLocal, String> parentValue) {
        return new WeakHashMap<PooledThreadLocal, String>(parentValue);
    }
};

通过childValue,让holder对于每一个线程创建一个独立的WeakHashMap的浅拷贝。

不过这个结果还是不对,只能在线程Thread创建时正确传递了值,效果同InheritableThreadLocal

因为childValue只能在线程初始化时,复制父线程的Map时调用ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

看来通过重写ThreadLocal.get/set方法的企图失败了。还是那个原因:

对父线程ThreadLocalMap的复制只能从Thread的构造方法中开始。

而且我们的目的应该是在Thread.start()方法执行时,对threadLocalMap进行重新赋值。在set/get方法中处理还是不能满足其他场景。

既然不能重写Thread.start()方法,那么接下来的思路剩下两个地方。自定义线程池;自定义Runnable。

线程池尝试处理

可不可以重写execute方法来获取ThreadLocal变量再赋值呢?

线程池也是调用Thread.start()方法来执行的,跟普通执行没区别,所以这个方案也不行。

自定义Runnable

那么可不可以在构造Runnable的时候,在其中存一下ThreadLocal变量呢,构造的过程是在父线程直接调用,肯定可以拿到父线程的ThreadLocal。

类似这样

java 复制代码
public static class PooledRunnable implements Runnable {
    private final WeakHashMap<PooledThreadLocal, String> runnablePooledThreadLocals;

    public PooledRunnable() {
        this.runnablePooledThreadLocals = new WeakHashMap<>(PooledThreadLocal.holder.get());
       
    }

    @Override
    public void run() {
        // 把ThreadLocal和runnablePooledThreadLocals关联起来
    }
}

直接这么干肯定不行,因为要想在子线程操作数据,只能从run改.

那么我们代理一下,传来一个Runnable做参数而不是直接实现,在代理的run中进行pooledThreadLocals的设置。

java 复制代码
public static class PooledRunnable implements Runnable {
    private final Runnable actualRunnable;
    private final WeakHashMap<PooledThreadLocal, String> copyFromParent;
    public PooledRunnable(Runnable runnable, WeakHashMap<PooledThreadLocal, String> copyFromParent) {
        actualRunnable = runnable;
        this.copyFromParent = copyFromParent;
    }
    public static PooledRunnable get(Runnable runnable) {
        // 获取当前线程所有的ThreadLocals的Map,复制一个
        WeakHashMap<PooledThreadLocal, String> copyFromParent = new WeakHashMap<>(PooledThreadLocal.holder.get());
        return new PooledRunnable(runnable, copyFromParent);
    }

    @Override
    public void run() {
        // 在子线程中运行
        WeakHashMap<PooledThreadLocal, String> oldMap = null;
        try {
            // 获取老值,当前线程的所有PooledThreadLocal实例的映射关系
            oldMap = PooledThreadLocal.holder.get();
            // 设置新值,从父线调用get程获取的父线程pooledThreadLocals属性
            PooledThreadLocal.holder.set(this.copyFromParent);
            actualRunnable.run();
        } finally {
            // 恢复老值,以便后续使用
            PooledThreadLocal.holder.set(oldMap);
        }
    }
}

这样简单处理一下,基本能满足要求。测试一下看看:

java 复制代码
threadLocal.set("hello");
// 初始化线程池的线程,从Thread构造方法中进行copy
System.out.println("线程池新建线程,父线程main: hello");
for (int i = 0; i < 2; i++) {
    threadPool.execute(PooledRunnable.get(() -> {
        System.out.println("test1:" + Thread.currentThread().getName() + ":" + threadLocal.get());
    }));
}
System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());

System.out.println("-------------------");
Thread.sleep(1000);
threadLocal.set("world");
System.out.println("复用线程池重新设置world");
// 复用线程池,不再拷贝,无法获取新值
for (int i = 0; i < 2; i++) {
    threadPool.execute(PooledRunnable.get(() -> {
        System.out.println("test2:" + Thread.currentThread().getName() + ":" + threadLocal.get());
        threadLocal.set("我是子线程改的");
    }));
}
System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
System.out.println("-------------------");

Thread.sleep(1000);
System.out.println("main:" + threadLocal.get());

可以看到,在子线程改的值并不影响其他线程,这个才是一个应该有的样子。

有了这个,那么就可以修改线程池的execute方法了,仅仅需要使用PooledRunnable.get包装一下,就可以实现上一小节没实现的功能。

注意:PooledRunnable.get方法还要处理一下,防止接收了PooledRunnable做参数,导致嵌套层级过多,导致其run中的代码多次执行。

总结

好了,这就是我的思考过程,初步实现了这样的功能,还不够完善。

其实核心就是:

1.模拟Thread.threadLocals,创建自己的ThreadLocalMap来存需要从父线程传递的ThreadLocal的映射。

2.既然Thread没有开放我们需要的api,那么使用ThreadLocal来代替Thread的直接成员属性, 即Thread.abc属性 equals 在Thread当前线程的ThreadLocal abc = new ThreadLocal()的映射值。两者表达的是同一个意思。

也许可以使用反射,但是代价太大,还影响了jdk的安全性,所以不这么做。

如果你听说过TransmittableThreadLocal,可以看到他们很相似,TransmittableThreadLocal还做了一些优化来支持更多的场景。如果你有这个需求,那么可以参考下面的链接,处理的很完善,方案也很多。

参考: alibaba/transmittable-thread-local

相关推荐
xlsw_2 小时前
java全栈day20--Web后端实战(Mybatis基础2)
java·开发语言·mybatis
神仙别闹3 小时前
基于java的改良版超级玛丽小游戏
java
黄油饼卷咖喱鸡就味增汤拌孜然羊肉炒饭3 小时前
SpringBoot如何实现缓存预热?
java·spring boot·spring·缓存·程序员
暮湫3 小时前
泛型(2)
java
超爱吃士力架4 小时前
邀请逻辑
java·linux·后端
南宫生4 小时前
力扣-图论-17【算法学习day.67】
java·学习·算法·leetcode·图论
转码的小石4 小时前
12/21java基础
java
李小白664 小时前
Spring MVC(上)
java·spring·mvc
GoodStudyAndDayDayUp4 小时前
IDEA能够从mapper跳转到xml的插件
xml·java·intellij-idea
装不满的克莱因瓶5 小时前
【Redis经典面试题六】Redis的持久化机制是怎样的?
java·数据库·redis·持久化·aof·rdb