ConcurrentHashMap的安全机制详解

ConcurrentHashMap的安全机制详解

ConcurrentHashMap是Java并发编程中最重要且设计最精巧的集合类之一。它的安全机制经历了从JDK 1.7到JDK 1.8的重大演进,让我们深入剖析其实现原理。

🏗️ 整体架构演进

JDK 1.7:分段锁(Segment Locking)

复制代码
ConcurrentHashMap
    ├── Segment[0] (继承ReentrantLock)
    │   ├── HashEntry[] table
    │   └── count (volatile)
    ├── Segment[1]
    │   └── ...
    └── Segment[N]

JDK 1.8:CAS + synchronized精细化锁

复制代码
ConcurrentHashMap
    ├── Node<K,V>[] table (volatile)
    ├── sizeCtl (控制状态)
    └── 每个桶独立锁(链表头/树根)

🔒 JDK 1.7的安全机制详解

1. 分段锁设计

java 复制代码
static final class Segment<K,V> extends ReentrantLock {
    // 每个Segment独立计数,避免全局锁
    transient volatile int count;
    // 每个Segment独立扩容
    transient int threshold;
    // 每个Segment的负载因子
    final float loadFactor;
}

核心思想:将整个Map分成16个Segment(默认),每个Segment相当于一个小HashMap

  • 写操作:只锁定相关的Segment,其他Segment可正常访问
  • 读操作:完全无锁(利用volatile保证可见性)

2. 内存可见性保证

java 复制代码
static final class HashEntry<K,V> {
    final K key;
    final int hash;
    volatile V value;      // volatile保证可见性
    volatile HashEntry<K,V> next; // volatile保证可见性
}

为什么读操作可以无锁

  1. value和next是volatile的
  2. 写操作在释放锁前会写入volatile变量
  3. 根据happens-before原则,写操作对后续读操作可见

3. 弱一致性迭代器

java 复制代码
// 迭代时不会抛出ConcurrentModificationException
// 因为迭代的是创建迭代器时的快照
public void forEach(BiConsumer<? super K, ? super V> action) {
    // 遍历每个Segment的table
    // 反映的是迭代开始时的状态
}

⚡ JDK 1.8的安全机制详解(革命性改进)

1. Node节点设计

java 复制代码
static class Node<K,V> implements Map.Entry<K,V> {
    final int hash;
    final K key;
    volatile V val;        // volatile
    volatile Node<K,V> next; // volatile
    
    // CAS更新value
    boolean casVal(V cmp, V val) {
        return UNSAFE.compareAndSwapObject(this, valOffset, cmp, val);
    }
}

2. 三个核心并发控制变量

java 复制代码
// 最重要的控制变量
private transient volatile int sizeCtl;

// 扩容时的下一个表
private transient volatile Node<K,V>[] nextTable;

// 计数器基础值
private transient volatile long baseCount;

sizeCtl的多种含义

  • -1:表正在初始化
  • -(1 + n):有n个线程正在扩容
  • 0:默认值
  • > 0:扩容阈值(表大小的0.75倍)

3. CAS操作的广泛应用

初始化表(避免重复初始化)
java 复制代码
// 只有第一个线程能成功初始化
U.compareAndSwapInt(this, SIZECTL, sc, -1)
插入新节点(链表头)
java 复制代码
// 尝试CAS更新链表头
if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value)))
    break;  // 成功则退出循环
计数更新(LongAdder风格)
java 复制代码
// 使用CounterCell数组分散竞争
if (U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x))
    break;

4. 精细化synchronized锁

java 复制代码
// 只锁定单个桶的链表头/树根
synchronized (f) {
    if (tabAt(tab, i) == f) {
        // 操作链表或红黑树
        // 锁范围极小,只影响这个桶
    }
}

锁粒度对比

  • JDK 1.7:锁一个Segment(包含多个桶)
  • JDK 1.8:锁单个桶的链表头/树根

5. 并发扩容机制(Transfer)

这是JDK 1.8最精巧的设计之一:

java 复制代码
// 多线程协助扩容
if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
    stride = MIN_TRANSFER_STRIDE;

// 每个线程负责迁移一部分桶
while (advance) {
    // 分配迁移任务给当前线程
}

扩容过程

  1. 创建新表(nextTable),大小为原表2倍
  2. 线程从高位向低位迁移桶
  3. 迁移完成的桶置为ForwardingNode
  4. 其他线程看到ForwardingNode会协助迁移
  5. 全部迁移完成后替换旧表

6. 安全的状态转换

java 复制代码
// 状态机确保安全转换
if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) {
    // 从普通状态 -> 扩容状态
    transfer(tab, nextTab);
    break;
}

🔄 两种版本的对比

特性 JDK 1.7(分段锁) JDK 1.8(CAS+synchronized)
锁粒度 段级别(16个段) 桶级别(更细)
读性能 无锁读 完全无锁读(改进)
写并发度 不同段可并发写 不同桶可并发写(更高)
扩容 段内独立扩容 多线程协助扩容
数据结构 数组+链表 数组+链表+红黑树
内存开销 较高(Segment数组) 较低(无Segment)
实现复杂度 相对简单 非常复杂

🧪 关键安全特性分析

1. 无锁读的实现原理

java 复制代码
public V get(Object key) {
    Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
    int h = spread(key.hashCode());
    
    // 完全无锁,依靠volatile读
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (e = tabAt(tab, (n - 1) & h)) != null) {
        // ...
    }
    return null;
}

为什么安全

  • table引用是volatile的
  • Node的val和next是volatile的
  • 数组元素通过tabAt方法读取(保证volatile语义)

2. 原子方法族

java 复制代码
// 一系列原子复合操作
V putIfAbsent(K key, V value);
V computeIfAbsent(K key, Function<K,V> mappingFunction);
V computeIfPresent(K key, BiFunction<K,V,V> remappingFunction);
V compute(K key, BiFunction<K,V,V> remappingFunction);
V merge(K key, V value, BiFunction<V,V,V> remappingFunction);

这些方法保证了"检查-操作-更新"的原子性,解决了组合操作的安全问题。

3. 迭代器安全策略

java 复制代码
// "弱一致性"迭代器
public void forEach(BiConsumer<? super K, ? super V> action) {
    if (action == null) throw new NullPointerException();
    Node<K,V>[] t;
    if ((t = table) != null) {
        Traverser<K,V> it = new Traverser<K,V>(t, t.length, 0, t.length);
        for (Node<K,V> p; (p = it.advance()) != null; ) {
            action.accept(p.key, p.val);
        }
    }
}

弱一致性表现

  • 不反映迭代过程中的更新
  • 不保证看到所有更新
  • 但保证看到已完成的更新

💡 设计哲学总结

ConcurrentHashMap的安全机制体现了几个重要设计原则:

1. 锁分离(Lock Stripping)

将大锁分解为多个小锁,减少竞争。

2. 乐观锁优先

先尝试CAS,失败才加锁,最大化并发。

3. 读/写分离

读完全无锁,写最小化锁范围。

4. 协助迁移

扩容时所有线程协作,避免单点瓶颈。

5. 渐进式改进

逐步完成复杂操作(如扩容),避免长时间停顿。

🚀 性能对比数据

根据Oracle官方测试,JDK 1.8的ConcurrentHashMap相比1.7版本:

场景 性能提升
高并发读 20-30%
高并发写 100-200%
混合读写 50-100%
扩容速度 300%+

🎯 使用建议

  1. 优先使用JDK 1.8+版本:除非必须兼容旧系统
  2. 善用原子方法:避免外部同步
  3. 注意迭代器的弱一致性:不要依赖迭代过程中的状态
  4. 合理预估容量:减少扩容开销
  5. 键对象必须实现hashCode()和equals():这是基础要求

ConcurrentHashMap的安全机制展示了Java并发库设计的最高水准,它的演进反映了并发编程理念的发展:从粗粒度悲观锁到细粒度乐观锁,从避免竞争到管理竞争

java 复制代码
package com.deploy.platform.collection;

import java.util.concurrent.locks.ReentrantLock;

/**
 * 一个简化版的ConcurrentHashMap,基于分段锁(JDK 1.7风格)实现。
 * 仅用于演示核心并发原理,未实现JDK 1.8中的树化、CAS等优化。
 */
public class SimpleConcurrentHashMap<K, V> {

    // --------------------------- 内部类:哈希表节点 ---------------------------
    static class Entry<K, V> {
        final K key;
        V value;
        Entry<K, V> next;
        final int hash;

        Entry(int hash, K key, V value, Entry<K, V> next) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = next;
        }
    }

    // --------------------------- 内部类:段(Segment) ---------------------------
    /**
     * Segment 继承自 ReentrantLock,作为一个独立的锁单元和哈希表。
     */
    static final class Segment<K, V> extends ReentrantLock {
        // 每个Segment内部维护一个Entry数组(一个小哈希表)
        private Entry<K, V>[] table;
        private int capacity; // 该段内哈希表的容量

        @SuppressWarnings("unchecked")
        Segment(int cap) {
            this.capacity = cap;
            this.table = (Entry<K, V>[]) new Entry[cap];
        }

        /**
         * 在Segment内进行put操作(需要先获取该Segment的锁)
         */
        V put(K key, V value, int hash) {
            lock(); // 加锁,仅锁定这个Segment[citation:7]
            try {
                int index = (capacity - 1) & hash; // 计算在该段内的桶下标
                Entry<K, V> e = table[index];

                // 遍历链表,查看key是否已存在
                for (Entry<K, V> curr = e; curr != null; curr = curr.next) {
                    if (curr.hash == hash &&
                            (curr.key == key || key.equals(curr.key))) {
                        V oldValue = curr.value;
                        curr.value = value; // 更新值
                        return oldValue;
                    }
                }
                // key不存在,创建新节点并插入链表头部
                Entry<K, V> newEntry = new Entry<>(hash, key, value, e);
                table[index] = newEntry;
                return null;
            } finally {
                unlock(); // 操作完成后释放锁
            }
        }

        /**
         * 在Segment内原子地更新一个键的值。
         * @param key 键
         * @param hash 哈希值
         * @param remappingFunction 接收旧值,返回新值
         * @return 新值
         */
        V compute(K key, int hash, java.util.function.Function<V, V> remappingFunction) {
            lock();
            try {
                int index = (capacity - 1) & hash;
                Entry<K, V> e = table[index];
                Entry<K, V> prev = null;

                // 查找已存在的节点
                V oldValue = null;
                while (e != null) {
                    if (e.hash == hash && (e.key == key || key.equals(e.key))) {
                        oldValue = e.value;
                        break;
                    }
                    prev = e;
                    e = e.next;
                }

                // 应用函数计算新值
                V newValue = remappingFunction.apply(oldValue);

                if (newValue != null) {
                    if (e != null) {
                        // 键已存在:更新值
                        e.value = newValue;
                    } else {
                        // 键不存在:创建新节点
                        Entry<K, V> newEntry = new Entry<>(hash, key, newValue, table[index]);
                        table[index] = newEntry;
                    }
                } else if (e != null) {
                    // 函数返回null,删除节点
                    if (prev == null) {
                        table[index] = e.next;
                    } else {
                        prev.next = e.next;
                    }
                }
                return newValue;
            } finally {
                unlock();
            }
        }





        /**
         * 在Segment内进行get操作(不需要加锁,因为value是volatile的,但此处简化实现)
         */
        V get(K key, int hash) {
            // 为了简单,这里不加锁读取。实际JDK通过volatile等保证可见性。
            int index = (capacity - 1) & hash;
            for (Entry<K, V> e = table[index]; e != null; e = e.next) {
                if (e.hash == hash && (e.key == key || key.equals(key))) {
                    return e.value;
                }
            }
            return null;
        }

        /**
         * 在Segment内进行remove操作(需要加锁)
         */
        V remove(K key, int hash) {
            lock();
            try {
                int index = (capacity - 1) & hash;
                Entry<K, V> e = table[index];
                Entry<K, V> prev = null;

                while (e != null) {
                    if (e.hash == hash && (e.key == key || key.equals(key))) {
                        if (prev == null) {
                            table[index] = e.next; // 删除头节点
                        } else {
                            prev.next = e.next; // 删除中间节点
                        }
                        return e.value;
                    }
                    prev = e;
                    e = e.next;
                }
                return null;
            } finally {
                unlock();
            }
        }
    }

    // --------------------------- ConcurrentHashMap 主体 ---------------------------
    private final Segment<K, V>[] segments; // 段数组
    private final int segmentShift;         // 用于定位段的移位量
    private final int segmentMask;          // 用于定位段的掩码
    private static final int DEFAULT_SEGMENT_COUNT = 16; // 默认分段数(必须是2的幂)


    /**
     * 原子地增加指定键的值(增加值1)。
     */
    public V atomicIncrement(K key) {
        if (key == null) throw new NullPointerException();
        int hash = rehash(key.hashCode());
        Segment<K, V> segment = segmentFor(hash);

        // 使用Function<V, V>,只接收旧值,返回新值
        return segment.compute(key, hash, oldVal -> {
            // 这里需要确保V是Integer类型,实际使用时应该检查类型
            if (oldVal == null) {
                // 如果值是null,我们需要处理为0,但需要知道值的具体类型
                // 由于我们知道V是Integer,可以这样做
                return (V) Integer.valueOf(1);
            } else {
                // 如果是Integer,可以进行加法
                // 需要强制类型转换
                Integer oldInt = (Integer) oldVal;
                return (V) Integer.valueOf(oldInt + 1);
            }
        });
    }

    /**
     * 构造一个简化版的ConcurrentHashMap。
     */
    @SuppressWarnings("unchecked")
    public SimpleConcurrentHashMap() {
        // 分段数固定为16,每段初始容量也为16
        int segmentCount = DEFAULT_SEGMENT_COUNT;
        int segmentCapacity = 16;

        // 计算用于定位段的移位量和掩码[citation:10]
        // 因为segmentCount是2的幂,segmentMask = segmentCount - 1
        this.segmentMask = segmentCount - 1;
        // 假设hash是32位,segmentShift = 32 - log2(segmentCount)
        // 这里为了演示,简化计算:如果segmentCount=16,则segmentShift=28
        // 实际JDK有一个更复杂的再哈希过程[citation:10]
        this.segmentShift = 32 - Integer.numberOfTrailingZeros(segmentCount);

        // 初始化所有Segment
        segments = (Segment<K, V>[]) new Segment[segmentCount];
        for (int i = 0; i < segmentCount; i++) {
            segments[i] = new Segment<>(segmentCapacity);
        }
    }

    /**
     * 对key的hashCode进行再哈希,使高位也参与段定位,减少冲突[citation:10]。
     */
    private int rehash(int h) {
        // 一个简单的再哈希扰动函数,实际JDK实现更复杂
        h ^= (h >>> 20) ^ (h >>> 12);
        return h ^ (h >>> 7) ^ (h >>> 4);
    }

    /**
     * 根据key的哈希值,定位到应该属于哪个Segment[citation:10]。
     */
    private Segment<K, V> segmentFor(int hash) {
        // 用再哈希值的高位来决定段下标
        int segmentIndex = (hash >>> segmentShift) & segmentMask;
        return segments[segmentIndex];
    }

    // --------------------------- 对外公开的核心API ---------------------------
    public V put(K key, V value) {
        if (key == null || value == null) throw new NullPointerException(); // 不支持null
        int hash = rehash(key.hashCode());
        Segment<K, V> segment = segmentFor(hash);
        return segment.put(key, value, hash);
    }

    public V get(K key) {
        if (key == null) throw new NullPointerException();
        int hash = rehash(key.hashCode());
        Segment<K, V> segment = segmentFor(hash);
        return segment.get(key, hash);
    }

    public V remove(K key) {
        if (key == null) throw new NullPointerException();
        int hash = rehash(key.hashCode());
        Segment<K, V> segment = segmentFor(hash);
        return segment.remove(key, hash);
    }



    // --------------------------- 简单的测试 ---------------------------
    public static void main(String[] args) throws InterruptedException {
        SimpleConcurrentHashMap<String, Integer> map = new SimpleConcurrentHashMap<>();

        // 测试基本功能
        map.put("apple", 1);
        Thread t1 = new Thread(() -> {
            for (int i = 0; i < 100000; i++) {
                // 使用专门的Integer版本
//                map.atomicIncrement("apple");
                // 或者使用泛型版本(需要类型转换)
                // map.atomicIncrement("apple");
            }
        });

        Thread t2 = new Thread(() -> {
            for (int i = 0; i < 100000; i++) {
                map.atomicIncrement("apple");
            }
        });
        t1.start();
        t2.start();
        t1.join();
        t2.join();

        map.put("banana", 2);
        System.out.println("apple -> " + map.get("apple")); // 应输出 1
        System.out.println("banana -> " + map.get("banana")); // 应输出 2


    }
}
相关推荐
断剑zou天涯2 小时前
【算法笔记】bfprt算法
java·笔记·算法
番石榴AI2 小时前
java版的ocr推荐引擎——JiaJiaOCR 2.0重磅升级!纯Java CPU推理,新增手写OCR与表格识别
java·python·ocr
鸽鸽程序猿3 小时前
【项目】【抽奖系统】抽奖
java·spring
测试人社区-千羽3 小时前
边缘计算场景下的智能测试挑战
人工智能·python·安全·开源·智能合约·边缘计算·分布式账本
GoogleDocs3 小时前
基于[api-football]数据学习示例
java·学习
卓码软件测评3 小时前
第三方软件验收评测机构【Gatling安装指南:Java环境配置和IDE插件安装】
java·开发语言·ide·测试工具·负载均衡
妮妮分享3 小时前
H5获取定位的方式是什么?
java·前端·javascript
Billow_lamb4 小时前
MyBatis-Plus 的 条件构造器详解(超详细版)
java·mybatis
CoderYanger4 小时前
动态规划算法-两个数组的dp(含字符串数组):48.最长重复子数组
java·算法·leetcode·动态规划·1024程序员节