浅谈InheritableThreadLocal---线程可继承的小书包

在前文中我们讲过ThreadLocal,相当于是每个线程有一个小书包,线程之间的小书包是隔离的,只存放了属于当前线程自己的变量,因此不会发生数据安全的问题。

(前文博客浅谈ThreadLocal----每个线程一个小书包 https://www.cnblogs.com/jilodream/p/19118986)

但是有时任务太繁重时,父线程希望new出新的子线程来为自己的业务提供帮助,同时希望子线程在处理时,也能用到自己保存在ThreadLocal中的变量。

现在有三种办法:
(1)直接把父线程的ThreadLocalMap传递给子线程,让子线程直接拿去用
如:

复制代码
sonThread.threadLocals=parent.threadLocals

(2)父线程在创建子线程时将子线程需要用到的ThreadLocal数据,专门指定。

如:

复制代码
 Thread sonThread=new Thread(new Runnable() {
            @Override
            public void run() {
                threadLocal1.set(xxx1); //手动指定
                threadLocal2.set(xxx2); //手动指定
                threadLocal3.set(xxx3); //手动指定
            }
        });

(3)将父线程中threadLocals 中的数据,打包整理后,统一传递给子线程。

第一种显然不行,如果两者指向了相同的Map 会导致缓存数据被共享,父子线程存在并发读写,又会导致安全问题。
方法二虽然灵活,但是指定起来过于繁琐,每个子线程都要单独设置一遍。
java选用的是方法三,但是实现起来稍有不同:
java 定义了一个InheritableThreadLocal类。通过这个类来实现线程可继承的ThreadLocal
定义中的:
Inheritable [ɪn'herɪtəbl]
adj. 可继承的,会遗传的
InheritableThreadLocal指可以继承/遗传的ThreadLocal。接下来看源码:

在Thread类中,定义了以下属性:

复制代码
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null; 

类似于thread.threadLocals,这是给InheritableThreadLocal 来存储可以继承的属性的Map。

InheritableThreadLocal类的源码,它继承了ThreadLocal:

复制代码
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * Creates an inheritable thread local variable.
     */
    public InheritableThreadLocal() {}

    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child
     * thread is created.  This method is called from within the parent
     * thread before the child is started.
     * <p>
     * This method merely returns its input argument, and should be overridden
     * if a different behavior is desired.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * Get the map associated with a ThreadLocal.
     *
     * @param t the current thread
     */
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    /**
     * Create the map associated with a ThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the table.
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

通过源码我们可以发现,它的用法和实现与ThreadLocal基本相同,但是它重写了父类(ThreadLocal)的几个方法:

(1)
首先是getMap(),重写的目的是当需要操作map对象时,(防盗连接:本文首发自http://www.cnblogs.com/jilodream/ )请使用inheritableThreadLocals而不是再是父类中的t.threadLocals写法,这样所有使用的map就都切换到了
ThreadLocal.ThreadLocalMap inheritableThreadLocals这个属性上来了。

(2)
其次是createMap() 方法,该方法的主要目的是初始化Map对象
这里强调了初始化的是 t.inheritableThreadLocals 属性。为啥没直接用getMap() 方法来获取 t.inheritableThreadLocals呢?这是由于 t.inheritableThreadLocals 还没有初始化,返回是null,你调用getMap()没有意义。
(3)
最后是 childValue() 方法,它是指当发生继承动作时,父类中的存储的变量转化为子类对象的转化转换。这里直接抽象成方法了,方便大家重写自己的InheritableThreadLocal类时,可以直接重写该方法。
如我们想前边加一个标签,代表是子线程专用,亦或者进行翻译转换等等,总之方便你自己实现转换。

来看下继承动作具体是怎么操作的:
继承主要发生在主线程创建子线程程时,我们来看下Thread的构造方法源码,经过一堆跳转最后会跳转进这个方法:

复制代码
    private Thread(ThreadGroup g, Runnable target, String name,
                   long stackSize, AccessControlContext acc,
                   boolean inheritThreadLocals) {
        //....
        
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        /* Stash the specified stack size in case the VM cares */
        this.stackSize = stackSize;

        /* Set thread ID */
        this.tid = nextThreadID();
    }
    

方法参数中的boolean inheritThreadLocals表示是否要继承/遗传。手动创建的线程,这里都会是true,代表要继承/遗传。

parent.inheritableThreadLocals != null 判断父线程的inheritableThreadLocals 属性是否为null (是否初始化了),如果不为null(父线程有需要遗传的本地变量),才发生遗传动作。
条件都满足则发生遗传,(防盗连接:本文首发自http://www.cnblogs.com/jilodream/ )入参为父线程的遗传本地变量,也就是属性inheritableThreadLocals。

之后进入ThreadLocalMap的构造方法,开始构造子线程的inheritableThreadLocals 属性:

复制代码
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (Entry e : parentTable) {
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

方法中就是遍历父线程的Map生成子线程自己的Map.Entry 了。
这里注意,在获取Entry的key 时,通过entry.get() 拿到父线程的Map 的key的弱引用,强制转化为ThreadLocal类型即可。
在获取value 值时,调用的是key的childValue()方法,也就是InheritableThreadLocal.childValue()中重写的方法,将父线程的value值转为子线程的value时。
这样子线程中:
1、map初始化好了;
2、entry元素也通过父线程都转移到子线程中了;
3、获取的遗传map,使用的都是遗传map了。(通过getMap()重写)
map的get set remove 等核心逻辑都直接使用父类的逻辑。(因为InheritableThreadLocal继承自ThreadLocal,并且没有重写这段逻辑)
总体上了来说,ThreadLocal,InheritableThreadLocal的实现都非常的优雅,不但很好的利用了对象的继承,保证用户在使用时无感知的发生了继承。其次用户在自定义自己的InheritableThreadLocal class时,(防盗连接:本文首发自http://www.cnblogs.com/jilodream/ )也只需要完成如何转化即可。非常值得我们自己在设计类结构和关系时参考这里的细节设计。

类关系整体的结构如下:

除此之外,我们也可以通过源码发现InheritableThreadLocal的一个特性,就是属性的遗传来自于父InheritableThreadLocal的全量属性,是不能根据某个线程自定义的。

并且在遗传时,子线程是在被new出实例时,就已经获得了此刻全部的属性,如果父线程后续调整了InheritableThreadLocal的范围,子线程是感知不到的。
来看这样一个例子即可:

复制代码
    public static void main(String[] args) {
        ThreadLocal<String> tl1 = new InheritableThreadLocal<>();
        ThreadLocal<String> tl2 = new InheritableThreadLocal<>();
        ThreadLocal<String> tl3 = new InheritableThreadLocal<>();
        tl1.set("t1");
        tl2.set("t2");
        Thread sonThread = new Thread(new Runnable() {
            @Override
            public void run() {
                String s1 = tl1.get();
                String s2 = tl2.get();
                String s3 = tl3.get();
                System.out.println("out:" + s1);
                System.out.println("out:" + s2);
                System.out.println("out:" + s3);
            }
        });
        tl2.set("newTl2Value");
        sonThread.start();
    }

输出如下,所有的值都是在new 子线程时就已经发生了,而并不是在子线程start之前:

复制代码
out:t1
out:t2
out:null