源码系列 之 ThreadLocal

简介

ThreadLocal的作用是做数据隔离 ,存储的变量只属于当前线程,相当于当前线程的局部变量 ,多线程环境下,不会被别的线程访问与修改。常用于存储线程私有成员变量上下文 ,和用于同一线程,不同层级方法间传参 等。JDK 1.8 中的 ThreadLocal 共741行代码,其中包含3个成员变量,13个成员方法和两个内部类。 我们先来看下核心原理,再来详细看下源码。

问题

我们可以带着问题去学习这部分内容,希望学习完后,能回答这些问题:

  1. ThreadLocal能不能代替Synchronized?和Synchronized的区别是什么?
  2. Thread、ThreadLocal、ThreadLocalMap的关系是怎么样的?
  3. 存储在jvm的堆还是栈中?
  4. ThreadLocal会导致内存泄漏吗,为什么?
  5. ThreadLocalMap为什么用Entry数组而不是Entry对象?
  6. ThreadLocal里的对象一定是线程安全的吗?
  7. ThreadLocalMap只用单纯的数组存值吗?如果出现哈希冲突怎么存值?

核心原理

这个要从java.lang.Thread类说起,每个Thread对象中都拥有一个ThreadLocalMap(ThreadLocal的内部类)的成员变量。ThreadLocalMap内部又拥有一个Entry数组,每个Entry是一个键值对,key是ThreadLocal本身,value是ThreadLocal的泛型值。 (这里Thread类虽然有ThreadLocalMap成员变量,但没有get(),set(),remove()等增删改查的方法,其实就是通过ThreadLocal来操作的。)

例如我们每一次请求,就是一个线程,然后一个线程里就有且只有一个ThreadLocalMap,然后我们的业务里可能new了好几个ThreadLocal对象,存了几个ThreadLocal的值,这些就存在Entry数组中,然后,我们根据当前线程当前ThreadLocal 就能找到唯一的<T>value值。再简单点讲:每个线程里都有一个成员变量,本质是一个数组,里面存的就是你想线程私有化的几个对象
结构图如下:

核心源码如下:

// java.lang.Thread类里持有ThreadLocalMap的引用
public class Thread implements Runnable {
	... ...
    ThreadLocal.ThreadLocalMap threadLocals = null;
    ... ...
}

// java.lang.ThreadLocal有内部静态类ThreadLocalMap,主要提供给Thread类使用
public class ThreadLocal<T> {
	
	... ...
	
    static class ThreadLocalMap {
    	// 存线程私有变量
        private Entry[] table;
        
        // ThreadLocalMap内部有Entry类,Entry的key是ThreadLocal本身,value是泛型值
        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
    }

	... ...
}

ThreadLocal的成员变量

这里介绍一下ThreadLocal的成员变量。

  • private final int threadLocalHashCode = nextHashCode()
    自定义的哈希值,主要是调用nextHashCode()方法获取。
  • private static AtomicInteger nextHashCode = new AtomicInteger()
    下一个要给出的哈希码。自动更新。从0开始。
  • private static final int HASH_INCREMENT = 0x61c88647;
    连续生成的哈希码之间的差异,此为一个魔数。

ThreadLocal的成员方法

这里介绍一下ThreadLocal的成员方法。

  • public ThreadLocal()
    构造器。

      public ThreadLocal() {
      }
    
  • private static int nextHashCode()
    返回下一个哈希码。

      private static int nextHashCode() {
          return nextHashCode.getAndAdd(HASH_INCREMENT);
      }
    
  • protected T initialValue()
    用以初始化值,这个方法将在线程第一次使用get方法访问变量时被调用,如果调用get()前使用了set()设置值,则不会被调用,这里只是简单的初始化成了null,如果需要初始化成其它值,则必须将ThreadLocal子类化,并重写此方法。

    // 由子类提供实现。
    // protected
    protected T initialValue() {
    return null;
    }

  • public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier)
    可根据提供的函数,生成初始值。

    public static ThreadLocal withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
    }

  • public T get()
    返回该当前线程对应的线程局部变量值。

    public T get() {
    // 获取当前线程,这里的currentThread()是个native方法
    Thread t = Thread.currentThread();
    // 获取当前线程对应的ThreadLocalMap对象
    ThreadLocalMap map = getMap(t);
    // 若获取到了。则获取此ThreadLocalMap下的entry对象,若entry也获取到了,那么直接获取entry对应的value返回即可
    if (map != null) {
    // 获取此ThreadLocalMap下的entry对象,把当前ThreadLocal当参数传进去
    ThreadLocalMap.Entry e = map.getEntry(this);
    // 若entry也获取到了
    if (e != null) {
    @SuppressWarnings("unchecked")
    // 直接获取entry对应的value返回
    T result = (T)e.value;
    return result;
    }
    }
    // 若没获取到ThreadLocalMap或没获取到Entry,则设置初始值
    // 初始值方法是延迟加载
    return setInitialValue();
    }

  • boolean isPresent()
    如果当前线程中ThreadLocalMap不为空,且该局部变量也不为空,则返回true,否则返回false。

      boolean isPresent() {
          Thread t = Thread.currentThread();
          ThreadLocalMap map = getMap(t);
          return map != null && map.getEntry(this) != null;
      }
    
  • private T setInitialValue()
    设置初始值,调用initialValue。

    // 设置初始值
    private T setInitialValue() {
    // 调用初始值方法,由子类提供。
    T value = initialValue();
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取map
    ThreadLocalMap map = getMap(t);
    // 获取到了
    if (map != null)
    // set
    map.set(this, value);
    else
    // 没获取到。创建map并赋值
    createMap(t, value);
    // 返回初始值。
    return value;
    }

  • public void set(T value)
    设置当前线程的线程局部变量的值。

    public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程对应的ThreadLocalMap实例
    ThreadLocalMap map = getMap(t);
    // 若当前线程有对应的ThreadLocalMap实例,则将当前ThreadLocal对象作为key,value做为值存到ThreadLocalMap的entry里。
    if (map != null)
    map.set(this, value);
    else
    // 若当前线程没有对应的ThreadLocalMap实例,则创建ThreadLocalMap,并将此线程与之绑定
    createMap(t, value);
    }

  • public void remove()
    删除当前线程局部变量的值,目的是为了减少内存占用,和防止内存泄漏。

    public void remove() {
    // 获取当前线程的ThreadLocalMap对象,并将其移除。
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
    m.remove(this);
    }

  • ThreadLocalMap getMap(Thread t)
    获取ThreadLocalMap,在你调用ThreadLocal.get()方法的时候就会调用这个方法,它的返回是当前线程里的threadLocals的引用。

    ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
    }

  • void createMap(Thread t, T firstValue)
    创建ThreadLocalMap,ThreadLocal底层其实就是一个map来维护的。

    void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    // ThreadLocalMap构造器。
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // new了一个ThreadLocalMap的内部类Entry,且将key和value传入。
    // key是ThreadLocal对象。
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
    }

  • static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap)
    用工厂方法创建继承的线程局部变量的映射。

      static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
          return new ThreadLocalMap(parentMap);
      }
    
  • T childValue(T parentValue)
    childValue()是用来在ThreadLocal子类中定义实现的,在这里定义,主要是提供给createInheritedMap工厂方法调用。

    T childValue(T parentValue) {
    throw new UnsupportedOperationException();
    }

ThreadLocal的内部类

这里介绍一下ThreadLocal的内部类。

  • static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {}
    ThreadLocal的扩展,从指定的提供者获取初始值。

    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

          private final Supplier<? extends T> supplier;
          
          SuppliedThreadLocal(Supplier<? extends T> supplier) {
              this.supplier = Objects.requireNonNull(supplier);
          }
    
          @Override
          protected T initialValue() {
              return supplier.get();
          }
      }
    
  • static class ThreadLocalMap {}
    ThreadLocalMap是一个定制的散列映射,只适合维护线程本地值。ThreadLocalMap用类似HashMap的方式,存储ThreadLocal和他对应泛型的值,只不过这里只单纯的用了数组没有用到链表。没有用链表,怎么解决哈希冲突问题呢?其实很简单,依次向下一个索引查找,把值存在下一个为null的位置。

      static class ThreadLocalMap {
    
          /**
           * ThreadLocalMap 里数组里具体存的值
           */
          static class Entry extends WeakReference<ThreadLocal<?>> {
              /**
              * 与当前ThreadLocal 对应的值
              */
              Object value;
    
              Entry(ThreadLocal<?> k, Object v) {
                  super(k);
                  value = v;
              }
          }
    
          /**
           * 初始容量必须是2的幂次数,当前默认为16
           */
          private static final int INITIAL_CAPACITY = 16;
    
          /**
           * ThreadLocalMap 里的Entry[] 数组,长度必须为2的幂次数
           */
          private Entry[] table;
    
          /**
           * 当前 Entry[] table 的长度
           */
          private int size = 0;
    
          /**
           * 阈值,超过后需扩容
           */
          private int threshold; // Default to 0
    
          /**
           * 设置阈值为长度的 三分之二
           */
          private void setThreshold(int len) {
              threshold = len * 2 / 3;
          }
    
          /**
           * I对len取模
           */
          private static int nextIndex(int i, int len) {
              return ((i + 1 < len) ? i + 1 : 0);
          }
    
          /**
           * 获取前一个索引值
           */
          private static int prevIndex(int i, int len) {
              return ((i - 1 >= 0) ? i - 1 : len - 1);
          }
    
          /**
           * 构造方法,懒加载的
           */
          ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
          	// 根据初始容量,初始化表
              table = new Entry[INITIAL_CAPACITY];
              // 获取当前哈希值后,对len进行取模,确定索引位置
              int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
              table[i] = new Entry(firstKey, firstValue);
              // 初始化长度,和阈值
              size = 1;
              setThreshold(INITIAL_CAPACITY);
          }
    
          /**
           * 从给定父映射创建新映射         
           * */
          private ThreadLocalMap(ThreadLocalMap parentMap) {
          	// 初始化参数
              Entry[] parentTable = parentMap.table;
              int len = parentTable.length;
              setThreshold(len);
              table = new Entry[len];
    
              for (Entry e : parentTable) {
                  if (e != null) {
                      @SuppressWarnings("unchecked")
                      ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                      if (key != null) {
                          Object value = key.childValue(e.value);
                          Entry c = new Entry(key, value);
                          int h = key.threadLocalHashCode & (len - 1);
                          while (table[h] != null)
                              h = nextIndex(h, len);
                          table[h] = c;
                          size++;
                      }
                  }
              }
          }
    
          /**
           * 查找与key相关联的条目
           */
          private Entry getEntry(ThreadLocal<?> key) {
          	// 确定索引值
              int i = key.threadLocalHashCode & (table.length - 1);
              Entry e = table[i];
              if (e != null && e.get() == key)
                  return e;
              else
                  return getEntryAfterMiss(key, i, e);
          }
    
          /**
           * 根据key值找不到Entry时,用以下方法找,当前i值找不到,就到i+1处找
           */
          private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
              Entry[] tab = table;
              int len = tab.length;
    
              while (e != null) {
                  ThreadLocal<?> k = e.get();
                  if (k == key)
                      return e;
                  if (k == null)
                      expungeStaleEntry(i);
                  else
                  	// 当前索引 i 处找不到,就到索引 i + 1 处查找,这也是ThreadLocalMap解决哈希冲突的方法,即,当前有值,则顺位往下一个索引存
                      i = nextIndex(i, len);
                  e = tab[i];
              }
              return null;
          }
    
          /**
           * 根据ThreadLocal,存对应value值
           */
          private void set(ThreadLocal<?> key, Object value) {
    
              Entry[] tab = table;
              int len = tab.length;
              // 用key的哈希值,对len取模,以计算存储的索引位置
              int i = key.threadLocalHashCode & (len-1);
    
      		// 这几行代码就很有意思了,前面说ThreadLocalMap是没有链表的,那么怎么解决哈希冲突问题呢
      		// 就是,依次往下一位索引存,直到有空位为止
              for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
                  ThreadLocal<?> k = e.get();
    
      			// 如果key相等,则更新值
                  if (k == key) {
                      e.value = value;
                      return;
                  }
    
                  if (k == null) {
                      replaceStaleEntry(key, value, i);
                      return;
                  }
              }
    
              tab[i] = new Entry(key, value);
              int sz = ++size;
              if (!cleanSomeSlots(i, sz) && sz >= threshold)
                  rehash();
          }
    
          /**
           * 删除指定key的节点
           */
          private void remove(ThreadLocal<?> key) {
              Entry[] tab = table;
              int len = tab.length;
              int i = key.threadLocalHashCode & (len-1);
              for (Entry e = tab[i];
                   e != null;
                   e = tab[i = nextIndex(i, len)]) {
                  if (e.get() == key) {
                      e.clear();
                      expungeStaleEntry(i);
                      return;
                  }
              }
          }
    
          /**
           * 替换已经不再被使用的旧值,(key为null的值       
           */
          private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                         int staleSlot) {
              Entry[] tab = table;
              int len = tab.length;
              Entry e;
    
              int slotToExpunge = staleSlot;
              for (int i = prevIndex(staleSlot, len);
                   (e = tab[i]) != null;
                   i = prevIndex(i, len))
                  if (e.get() == null)
                      slotToExpunge = i;
    
              for (int i = nextIndex(staleSlot, len);
                   (e = tab[i]) != null;
                   i = nextIndex(i, len)) {
                  ThreadLocal<?> k = e.get();
    
    
                  if (k == key) {
                      e.value = value;
    
                      tab[i] = tab[staleSlot];
                      tab[staleSlot] = e;
    
                      if (slotToExpunge == staleSlot)
                          slotToExpunge = i;
                      cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                      return;
                  }
    
                  if (k == null && slotToExpunge == staleSlot)
                      slotToExpunge = i;
              }
    
              tab[staleSlot].value = null;
              tab[staleSlot] = new Entry(key, value);
    
              if (slotToExpunge != staleSlot)
                  cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
          }
    
          /**
           * staleSlot和下个空槽之间的所有空槽都将被检查和清除
           */
          private int expungeStaleEntry(int staleSlot) {
              Entry[] tab = table;
              int len = tab.length;
    
      		// 清楚当前槽
              tab[staleSlot].value = null;
              tab[staleSlot] = null;
              size--;
    
              Entry e;
              int i;
              for (i = nextIndex(staleSlot, len);
                   (e = tab[i]) != null;
                   i = nextIndex(i, len)) {
                  ThreadLocal<?> k = e.get();
                  // 如果发现key为空,则清除value值
                  if (k == null) {
                      e.value = null;
                      tab[i] = null;
                      size--;
                  } else {
                      int h = k.threadLocalHashCode & (len - 1);
                      if (h != i) {
                          tab[i] = null;
    
                          while (tab[h] != null)
                              h = nextIndex(h, len);
                          tab[h] = e;
                      }
                  }
              }
              return i;
          }
    
          /**
           * 启发式地扫描一些单位,寻找陈旧的条目
           */
          private boolean cleanSomeSlots(int i, int n) {
              boolean removed = false;
              Entry[] tab = table;
              int len = tab.length;
              do {
                  i = nextIndex(i, len);
                  Entry e = tab[i];
                  if (e != null && e.get() == null) {
                      n = len;
                      removed = true;
                      i = expungeStaleEntry(i);
                  }
              } while ( (n >>>= 1) != 0);
              return removed;
          }
    
          /**
           * 清除旧条目后,长度依然3/4threshold,则扩容
           */
          private void rehash() {
              expungeStaleEntries();
    
              if (size >= threshold - threshold / 4)
                  resize();
          }
    
          /**
           * 将原来的容量,扩大为两倍
           */
          private void resize() {
              Entry[] oldTab = table;
              int oldLen = oldTab.length;
              int newLen = oldLen * 2;
              Entry[] newTab = new Entry[newLen];
              int count = 0;
    
              for (Entry e : oldTab) {
                  if (e != null) {
                      ThreadLocal<?> k = e.get();
                      // 清除旧表无用值
                      if (k == null) {
                          e.value = null; // Help the GC
                      } else {
                      	// 找到合适位置并存储
                          int h = k.threadLocalHashCode & (newLen - 1);
                          while (newTab[h] != null)
                              h = nextIndex(h, newLen);
                          newTab[h] = e;
                          count++;
                      }
                  }
              }
      		
      		// 更新成员变量
              setThreshold(newLen);
              size = count;
              table = newTab;
          }
    
          /**
           * 删除表中所有陈旧的条目
           */
          private void expungeStaleEntries() {
              Entry[] tab = table;
              int len = tab.length;
              for (int j = 0; j < len; j++) {
                  Entry e = tab[j];
                  if (e != null && e.get() == null)
                      expungeStaleEntry(j);
              }
          }
      }
    

问题解答

相信看了上面的源码,文章开头的部分问题,你已经有了自己的答案,接下来我们再对一下答案。
1. ThreadLocal能不能代替Synchronized?和Synchronized的区别是什么?

答:ThreadLocal肯定不能代替Synchronized,ThreadLocal只是让变量变成了完全私有化,别的线程是无法访问的。而Synchronized除了解决线程冲突外,更重要的是,可以使变量被所有线程访问和修改。
2.Thread、ThreadLocal、ThreadLocalMap的关系是怎么样的?

答:关系有点绕,Thread类中包括ThreadLocalMap成员变量,ThreadLocalMap是ThreadLocal的内部类,ThreadLocalMap有Entry数组,Entry实体是键值对,其中key即是ThreadLocal类型。
3 存储在jvm的堆还是栈中?

答:很会人会觉得变量变成线程私有了,就一定是存在虚拟机栈中的,其实不是,我们私有变量是存在Thread对象里的,而对象都是存在堆里的,所以ThreadLocal的实例和他的值都是存在堆上的。

4. ThreadLocal会导致内存泄漏吗,为什么?

答:这个要从两方面分析,即ThreadLocalMap.Entry的key和value值分别讲:key直接是交给了父类处理super(key),父类是个弱引用,所以key完全不存在内存泄漏问题,value是个强引用,如果线程终止了,也会被GC干掉,但有时线程是不会被终止的,比如线程池里的核心线程,此时引用链就变成了:Thread->ThreadLocalMap->Entry(key为null)->value,由于value和Thread还存在链路关系,还是可达的,所以不会被回收,这样越来越多的垃圾对象产生却无法回收,最终可能导致OOM,当然解决办法也简单,用完私有变量后使用remove()方法即可,它会删除所有value值。

5. 为什么用Entry数组而不是Entry对象?

答:在同一个线程里,我们可能需要多个线程私有变量,所以需要数组。

6. ThreadLocal里的对象一定是线程安全的吗?

答:不一定,因为ThreadLocal.set()进去的对象,可能本身可能就是可供多个线程访问的,比如static对象,这样是没办法保存线程安全的。

7. ThreadLocalMap只用单纯的数组存值吗?如果出现哈希冲突怎么存值?

答:ThreadLocalMap是只用数组存Entry值,如果 i 位置出现哈希冲突,则存在 i + 1处,如果 i + 1 也不为空,则依次往下顺延,直到找到空位为止。

相关推荐
禁默31 分钟前
深入浅出:AWT的基本组件及其应用
java·开发语言·界面编程
Cachel wood37 分钟前
python round四舍五入和decimal库精确四舍五入
java·linux·前端·数据库·vue.js·python·前端框架
Code哈哈笑40 分钟前
【Java 学习】深度剖析Java多态:从向上转型到向下转型,解锁动态绑定的奥秘,让代码更优雅灵活
java·开发语言·学习
gb421528743 分钟前
springboot中Jackson库和jsonpath库的区别和联系。
java·spring boot·后端
程序猿进阶43 分钟前
深入解析 Spring WebFlux:原理与应用
java·开发语言·后端·spring·面试·架构·springboot
zfoo-framework1 小时前
【jenkins插件】
java
风_流沙1 小时前
java 对ElasticSearch数据库操作封装工具类(对你是否适用嘞)
java·数据库·elasticsearch
ProtonBase1 小时前
如何从 0 到 1 ,打造全新一代分布式数据架构
java·网络·数据库·数据仓库·分布式·云原生·架构
乐之者v2 小时前
leetCode43.字符串相乘
java·数据结构·算法