Java并发包深度解析:从AQS到线程池的完全指南

引言:并发编程的现代武器库

Java并发包(java.util.concurrent)是Java平台提供的并发编程工具集。它不仅解决了线程安全问题,还提供了高性能、可扩展的并发解决方案。理解这些工具的内部实现,能让我们在编写并发程序时更加得心应手。

一、AbstractQueuedSynchronizer(AQS):并发框架的基石

1.1 AQS的核心原理与实现

示例代码(Java):
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Custom synchronizers built on {@link AbstractQueuedSynchronizer} (AQS), used to explore its
 * wait queue and state management:
 *
 * <ul>
 *   <li>{@link Mutex}: a <em>non-reentrant</em> exclusive lock (state 0 = free, 1 = held);
 *   <li>{@link SharedLock}: a semaphore-style lock admitting up to N concurrent holders;
 *   <li>{@link AQSInspector}: prints a synchronizer's state and queued threads.
 * </ul>
 */
public class AQSDeepDive {

    /**
     * Non-reentrant mutual-exclusion lock.
     *
     * <p>The state is only ever 0 or 1. As a result, a thread that already holds the lock and
     * calls {@link #lock()} again blocks forever. Only the owning thread may {@link #unlock()};
     * any other caller gets an {@link IllegalMonitorStateException}.
     */
    static class Mutex implements Lock {
        private final Sync sync = new Sync();

        /** AQS subclass implementing the exclusive acquire/release protocol. */
        private static class Sync extends AbstractQueuedSynchronizer {

            /** Attempts to move the state from 0 to 1 with a CAS and records the owner on success. */
            @Override
            protected boolean tryAcquire(int acquires) {
                assert acquires == 1; // a mutex hands out exactly one permit
                if (compareAndSetState(0, 1)) {
                    setExclusiveOwnerThread(Thread.currentThread());
                    return true;
                }
                return false;
            }

            /**
             * Releases the lock held by the calling thread.
             *
             * @throws IllegalMonitorStateException if the calling thread is not the owner
             */
            @Override
            protected boolean tryRelease(int releases) {
                assert releases == 1;
                // Checking ownership (not merely state != 0) stops other threads from releasing
                // a lock they do not hold. It is also what makes the plain, non-CAS writes below
                // safe: only the owner can reach them.
                if (!isHeldExclusively()) {
                    throw new IllegalMonitorStateException();
                }
                setExclusiveOwnerThread(null);
                setState(0);
                return true;
            }

            /** Returns {@code true} if the calling thread owns the lock. */
            @Override
            protected boolean isHeldExclusively() {
                return getExclusiveOwnerThread() == Thread.currentThread();
            }

            /** Returns {@code true} if any thread currently holds the lock. */
            boolean isLocked() {
                return getState() != 0;
            }

            /** Creates a condition bound to this synchronizer. */
            Condition newCondition() {
                return new ConditionObject();
            }
        }

        @Override
        public void lock() {
            sync.acquire(1);
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            sync.acquireInterruptibly(1);
        }

        @Override
        public boolean tryLock() {
            return sync.tryAcquire(1);
        }

        @Override
        public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
            return sync.tryAcquireNanos(1, unit.toNanos(time));
        }

        /** @throws IllegalMonitorStateException if the calling thread does not hold the lock */
        @Override
        public void unlock() {
            sync.release(1);
        }

        @Override
        public Condition newCondition() {
            return sync.newCondition();
        }

        /** Returns {@code true} if the lock is held by any thread (not only the caller). */
        public boolean isLocked() {
            return sync.isLocked();
        }

        /** Returns {@code true} if some thread may be waiting to acquire the lock. */
        public boolean hasQueuedThreads() {
            return sync.hasQueuedThreads();
        }
    }

    /**
     * Semaphore-style shared lock: up to {@code permits} threads may hold it at once.
     * Releasing more permits than the lock was created with is rejected.
     */
    static class SharedLock {
        private final Sync sync;

        /**
         * @param permits maximum number of concurrent holders; must be positive
         * @throws IllegalArgumentException if {@code permits <= 0}
         */
        public SharedLock(int permits) {
            if (permits <= 0) {
                throw new IllegalArgumentException("permits must be positive: " + permits);
            }
            sync = new Sync(permits);
        }

        /** AQS subclass whose state is the number of currently available permits. */
        private static class Sync extends AbstractQueuedSynchronizer {
            private final int maxPermits;

            Sync(int permits) {
                maxPermits = permits;
                setState(permits); // initially every permit is available
            }

            /**
             * Takes {@code acquires} permits.
             *
             * @return the permits left afterwards; a negative value tells AQS to enqueue the caller
             */
            @Override
            protected int tryAcquireShared(int acquires) {
                for (;;) {
                    int available = getState();
                    int remaining = available - acquires;
                    // Fail fast when short of permits; otherwise retry until the CAS succeeds.
                    if (remaining < 0 || compareAndSetState(available, remaining)) {
                        return remaining;
                    }
                }
            }

            /**
             * Returns {@code releases} permits.
             *
             * @throws IllegalStateException if this would exceed the permit count given at construction
             */
            @Override
            protected boolean tryReleaseShared(int releases) {
                for (;;) {
                    int current = getState();
                    int next = current + releases;

                    // int overflow
                    if (next < current) {
                        throw new Error("Maximum permit count exceeded");
                    }
                    // More releases than acquires: a caller bug that would silently widen the lock.
                    if (next > maxPermits) {
                        throw new IllegalStateException(
                            "Released more permits than the maximum of " + maxPermits);
                    }

                    if (compareAndSetState(current, next)) {
                        return true;
                    }
                }
            }
        }

        /** Takes one permit, blocking (uninterruptibly) until one is available. */
        public void acquire() {
            sync.acquireShared(1);
        }

        /** Returns one permit and wakes a waiting thread if there is one. */
        public void release() {
            sync.releaseShared(1);
        }
    }

    /** Prints a synchronizer's state and wait queue using only public AQS APIs. */
    static class AQSInspector {
        /** Matches the documented "State = N" fragment of {@link AbstractQueuedSynchronizer#toString()}. */
        private static final Pattern STATE_PATTERN = Pattern.compile("State = (-?\\d+)");

        /**
         * Prints the synchronizer's state and the threads currently queued on it.
         *
         * <p>{@code getState()} is protected, and the queue nodes are private JDK internals
         * whose layout differs across JDK versions and which are strongly encapsulated on
         * recent JDKs. This method therefore reads the state from {@code toString()} (whose
         * format is specified by the AQS javadoc) and the queue from
         * {@link AbstractQueuedSynchronizer#getQueuedThreads()}.
         *
         * @param aqs the synchronizer to inspect
         * @throws Exception not thrown in practice; kept for compatibility with existing callers
         */
        public static void printAQSState(AbstractQueuedSynchronizer aqs) throws Exception {
            System.out.println("\n=== AQS内部状态 ===");

            String description = aqs.toString();
            Matcher matcher = STATE_PATTERN.matcher(description);
            System.out.println("State: " + (matcher.find() ? matcher.group(1) : description));

            // Best-effort snapshot; AQS gives no ordering guarantee for this collection.
            List<Thread> waiting = new ArrayList<>(aqs.getQueuedThreads());
            if (waiting.isEmpty()) {
                System.out.println("等待队列: 空");
                return;
            }
            System.out.println("等待线程:");
            for (int i = 0; i < waiting.size(); i++) {
                System.out.printf("  [%d] %s%n", i, waiting.get(i).getName());
            }
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println("=== AQS深度解析 ===\n");

        System.out.println("1. 自定义互斥锁测试:");
        testMutex();

        System.out.println("\n2. 自定义共享锁测试:");
        testSharedLock();

        System.out.println("\n3. AQS状态监控演示:");
        testAQSStateInspection();
    }

    /** Thread 1 holds the mutex for 1 s; thread 2 gives up after a 500 ms timed tryLock. */
    private static void testMutex() throws Exception {
        Mutex lock = new Mutex();

        Thread t1 = new Thread(() -> {
            lock.lock();
            try {
                System.out.println("线程1获取锁");
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                lock.unlock();
                System.out.println("线程1释放锁");
            }
        });

        Thread t2 = new Thread(() -> {
            try {
                Thread.sleep(100); // give t1 time to take the lock first
                System.out.println("线程2尝试获取锁...");

                if (lock.tryLock(500, TimeUnit.MILLISECONDS)) {
                    try {
                        System.out.println("线程2获取锁成功");
                    } finally {
                        lock.unlock();
                    }
                } else {
                    System.out.println("线程2获取锁超时");
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });

        t1.start();
        t2.start();

        t1.join();
        t2.join();

        System.out.println("是否有线程在等待队列中: " + lock.hasQueuedThreads());
    }

    /** Six threads compete for three permits; at most three run concurrently. */
    private static void testSharedLock() throws InterruptedException {
        SharedLock lock = new SharedLock(3);

        List<Thread> threads = new ArrayList<>();
        CountDownLatch latch = new CountDownLatch(6);

        for (int i = 1; i <= 6; i++) {
            int threadId = i;
            Thread t = new Thread(() -> {
                System.out.printf("线程%d尝试获取许可...%n", threadId);
                // Acquire outside the try so the finally block only releases a permit we hold.
                lock.acquire();
                try {
                    System.out.printf("线程%d获取许可成功,执行任务%n", threadId);
                    Thread.sleep(500);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    lock.release();
                    System.out.printf("线程%d释放许可%n", threadId);
                    latch.countDown();
                }
            });
            threads.add(t);
        }

        threads.forEach(Thread::start);
        latch.await();
        System.out.println("所有线程执行完成");
    }

    /** Shows the AQS state/queue before, during and after contention on a {@link Mutex}. */
    private static void testAQSStateInspection() throws Exception {
        Mutex lock = new Mutex();

        // Private members of nested classes are accessible anywhere in the top-level class,
        // so no reflection is needed to reach the synchronizer.
        AbstractQueuedSynchronizer sync = lock.sync;

        System.out.println("初始状态:");
        AQSInspector.printAQSState(sync);

        Thread[] threads = new Thread[3];
        CountDownLatch startLatch = new CountDownLatch(1);
        CountDownLatch endLatch = new CountDownLatch(threads.length);

        for (int i = 0; i < threads.length; i++) {
            int threadId = i;
            threads[i] = new Thread(() -> {
                try {
                    startLatch.await();
                    System.out.printf("线程%d尝试获取锁%n", threadId);
                    lock.lock();
                    try {
                        System.out.printf("线程%d获取锁成功%n", threadId);
                        Thread.sleep(1000);
                    } finally {
                        lock.unlock();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    endLatch.countDown();
                }
            });
            threads[i].start();
        }

        // The main thread takes the lock first, so the released workers must queue.
        lock.lock();
        try {
            System.out.println("\n主线程获取锁,其他线程进入等待队列:");
            AQSInspector.printAQSState(sync);

            startLatch.countDown();
            Thread.sleep(500); // give the workers time to enqueue

            System.out.println("\n竞争线程进入队列后:");
            AQSInspector.printAQSState(sync);
        } finally {
            lock.unlock();
        }

        endLatch.await();
        System.out.println("\n所有线程执行完成后的状态:");
        AQSInspector.printAQSState(sync);
    }
}

二、线程池深度优化与监控

2.1 线程池的核心参数与动态调整

示例代码(Java):
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import java.util.*;

/**
 * Thread-pool demos: an instrumented {@link ThreadPoolExecutor}, per-task timing, resizing a
 * pool at runtime, and side-by-side comparisons of work queues and rejection policies.
 */
public class ThreadPoolDeepDive {

    /**
     * {@link ThreadPoolExecutor} that records counts and timings for every task passed to
     * {@link #execute(Runnable)}, and counts rejected tasks via its own rejection handler.
     */
    static class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
        // Custom metrics on top of what ThreadPoolExecutor already tracks.
        private final AtomicLong totalTasks = new AtomicLong();
        private final AtomicLong completedTasks = new AtomicLong();
        private final AtomicLong failedTasks = new AtomicLong();
        private final AtomicInteger maxPoolSize = new AtomicInteger();  // largest pool size seen right after a submit
        private final LongAdder totalQueueTime = new LongAdder();       // ns from execute() to task start
        private final LongAdder totalExecutionTime = new LongAdder();   // ns spent running tasks

        // Number of times the rejection handler was invoked.
        private final AtomicLong rejectedTasks = new AtomicLong();

        /** Names threads "{poolName}-thread-N" and makes them daemon threads. */
        private static class MonitoredThreadFactory implements ThreadFactory {
            private final AtomicInteger threadNumber = new AtomicInteger(1);
            private final String namePrefix;

            MonitoredThreadFactory(String poolName) {
                namePrefix = poolName + "-thread-";
            }

            @Override
            public Thread newThread(Runnable r) {
                Thread t = new Thread(r, namePrefix + threadNumber.getAndIncrement());
                t.setDaemon(true); // daemon, so pool threads never keep the JVM alive
                return t;
            }
        }

        /**
         * Creates a monitored pool; the remaining parameters are as for {@link ThreadPoolExecutor}.
         *
         * @param poolName prefix used in worker thread names
         */
        public MonitoredThreadPoolExecutor(int corePoolSize,
                                          int maximumPoolSize,
                                          long keepAliveTime,
                                          TimeUnit unit,
                                          BlockingQueue<Runnable> workQueue,
                                          String poolName) {
            // The monitoring handler is an inner class and needs `this`, which may not be
            // referenced inside super(...). Start with a placeholder policy and replace it
            // before the constructor returns.
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue,
                  new MonitoredThreadFactory(poolName),
                  new ThreadPoolExecutor.AbortPolicy());
            setRejectedExecutionHandler(new MonitoringRejectedExecutionHandler());
        }

        /**
         * Wraps {@code command} to record queue time, execution time and outcome, then submits it.
         *
         * @throws NullPointerException if {@code command} is null (checked before any metric changes)
         */
        @Override
        public void execute(Runnable command) {
            Objects.requireNonNull(command, "command");
            totalTasks.incrementAndGet();
            final long enqueuedAt = System.nanoTime();

            super.execute(() -> {
                long startedAt = System.nanoTime();
                totalQueueTime.add(startedAt - enqueuedAt);
                try {
                    command.run();
                    completedTasks.incrementAndGet();
                } catch (Throwable t) {
                    // Count Errors as failures too; precise rethrow keeps the original type.
                    failedTasks.incrementAndGet();
                    throw t;
                } finally {
                    totalExecutionTime.add(System.nanoTime() - startedAt);
                }
            });

            int currentPoolSize = getPoolSize();
            maxPoolSize.getAndUpdate(old -> Math.max(old, currentPoolSize));
        }

        /**
         * Counts the rejection, then (unless the pool is shut down) tries to put the task
         * straight into the work queue. The submitting thread may be blocked for up to one
         * second; if the queue is still full after that, the task is dropped.
         */
        private class MonitoringRejectedExecutionHandler implements RejectedExecutionHandler {
            @Override
            public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
                long rejectedSoFar = rejectedTasks.incrementAndGet();
                System.err.println("任务被拒绝,当前拒绝总数: " + rejectedSoFar);

                if (!executor.isShutdown()) {
                    try {
                        boolean offered = executor.getQueue().offer(r, 1, TimeUnit.SECONDS);
                        if (!offered) {
                            System.err.println("重新提交到队列失败,任务被丢弃");
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        System.err.println("重新提交被中断");
                    }
                }
            }
        }

        /**
         * Returns a point-in-time snapshot of built-in and custom metrics, in insertion order.
         * The averages are taken over tasks that have finished, whether they succeeded or failed.
         */
        public Map<String, Object> getMetrics() {
            Map<String, Object> metrics = new LinkedHashMap<>();

            metrics.put("corePoolSize", getCorePoolSize());
            metrics.put("maximumPoolSize", getMaximumPoolSize());
            metrics.put("poolSize", getPoolSize());
            metrics.put("activeCount", getActiveCount());
            metrics.put("largestPoolSize", getLargestPoolSize());
            metrics.put("taskCount", getTaskCount());
            metrics.put("completedTaskCount", getCompletedTaskCount());
            metrics.put("queueSize", getQueue().size());
            metrics.put("queueRemainingCapacity", getQueue().remainingCapacity());

            long completed = completedTasks.get();
            long failed = failedTasks.get();
            metrics.put("totalTasks", totalTasks.get());
            metrics.put("completedTasks", completed);
            metrics.put("failedTasks", failed);
            metrics.put("rejectedTasks", rejectedTasks.get());
            metrics.put("maxPoolSizeObserved", maxPoolSize.get());

            // Queue/execution time is recorded for failed tasks as well, so average over both.
            long finished = completed + failed;
            if (finished > 0) {
                metrics.put("avgQueueTimeNs", totalQueueTime.sum() / finished);
                metrics.put("avgExecutionTimeNs", totalExecutionTime.sum() / finished);
            }

            return metrics;
        }

        /**
         * Changes the core pool size at runtime.
         *
         * @throws IllegalArgumentException if the value is negative or above the maximum pool size
         */
        public void adjustCorePoolSize(int newCorePoolSize) {
            if (newCorePoolSize < 0 || newCorePoolSize > getMaximumPoolSize()) {
                throw new IllegalArgumentException("无效的核心线程数");
            }

            int oldCorePoolSize = getCorePoolSize(); // capture before the setter overwrites it
            setCorePoolSize(newCorePoolSize);
            System.out.printf("核心线程数从 %d 调整为 %d%n", oldCorePoolSize, newCorePoolSize);
        }

        /**
         * Changes the maximum pool size at runtime.
         *
         * @throws IllegalArgumentException if the value is below 1 or below the core pool size
         */
        public void adjustMaximumPoolSize(int newMaximumPoolSize) {
            if (newMaximumPoolSize < 1 || newMaximumPoolSize < getCorePoolSize()) {
                throw new IllegalArgumentException("无效的最大线程数");
            }

            int oldMaximumPoolSize = getMaximumPoolSize(); // capture before the setter overwrites it
            setMaximumPoolSize(newMaximumPoolSize);
            System.out.printf("最大线程数从 %d 调整为 %d%n", oldMaximumPoolSize, newMaximumPoolSize);
        }
    }

    /** Runnable wrapper that records per-task timing (wall-clock ms), thread and outcome. */
    static class MonitoredTask implements Runnable {
        private final Runnable task;
        private final String taskId;
        private final long creationTime;
        private volatile long startTime;
        private volatile long endTime;
        private volatile Thread executionThread;
        private volatile boolean completed;
        private volatile Throwable failureCause;

        public MonitoredTask(Runnable task, String taskId) {
            this.task = task;
            this.taskId = taskId;
            this.creationTime = System.currentTimeMillis();
        }

        @Override
        public void run() {
            startTime = System.currentTimeMillis();
            executionThread = Thread.currentThread();

            try {
                task.run();
                completed = true;
            } catch (Throwable t) {
                failureCause = t;
                throw t;
            } finally {
                endTime = System.currentTimeMillis();
            }
        }

        /** Returns an immutable snapshot of this task's current metrics. */
        public TaskMetrics getMetrics() {
            return new TaskMetrics(
                taskId,
                creationTime,
                startTime,
                endTime,
                completed,
                failureCause != null,
                executionThread != null ? executionThread.getName() : null
            );
        }
    }

    /** Immutable snapshot of a {@link MonitoredTask}; times are epoch milliseconds, 0 = not yet. */
    static class TaskMetrics {
        final String taskId;
        final long creationTime;
        final long startTime;
        final long endTime;
        final boolean completed;
        final boolean failed;
        final String executionThread;

        TaskMetrics(String taskId, long creationTime, long startTime,
                   long endTime, boolean completed, boolean failed,
                   String executionThread) {
            this.taskId = taskId;
            this.creationTime = creationTime;
            this.startTime = startTime;
            this.endTime = endTime;
            this.completed = completed;
            this.failed = failed;
            this.executionThread = executionThread;
        }

        /** Milliseconds between creation and start, or 0 if the task has not started. */
        public long getWaitTime() {
            return startTime > 0 ? startTime - creationTime : 0;
        }

        /** Milliseconds between start and end, or 0 if the task has not finished. */
        public long getExecutionTime() {
            return (startTime > 0 && endTime > 0) ? endTime - startTime : 0;
        }

        @Override
        public String toString() {
            String status;
            if (completed) {
                status = "成功";
            } else if (failed) {
                status = "失败";
            } else if (startTime > 0) {
                status = "执行中"; // started but not finished; previously mislabelled "not started"
            } else {
                status = "未开始";
            }
            return String.format("任务%s: 等待%dms, 执行%dms, 线程: %s, 状态: %s",
                taskId, getWaitTime(), getExecutionTime(), executionThread, status);
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println("=== 线程池深度优化与监控 ===\n");

        System.out.println("1. 可监控线程池测试:");
        testMonitoredThreadPool();

        System.out.println("\n2. 线程池动态调整测试:");
        testDynamicAdjustment();

        System.out.println("\n3. 不同队列策略对比测试:");
        testQueueStrategies();

        System.out.println("\n4. 拒绝策略测试:");
        testRejectionPolicies();
    }

    /** Runs 20 random-length tasks while printing pool metrics every 500 ms. */
    private static void testMonitoredThreadPool() throws InterruptedException {
        MonitoredThreadPoolExecutor executor = new MonitoredThreadPoolExecutor(
            2,  // corePoolSize
            4,  // maximumPoolSize
            60, // keepAliveTime
            TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(10), // queue capacity 10
            "TestPool"
        );

        List<MonitoredTask> tasks = new ArrayList<>();
        CountDownLatch latch = new CountDownLatch(20);

        // With 4 threads and 10 queue slots, the last submissions are rejected and then
        // re-queued by the monitoring handler, which briefly blocks this thread.
        for (int i = 0; i < 20; i++) {
            int taskId = i;
            MonitoredTask task = new MonitoredTask(() -> {
                try {
                    Thread.sleep(100 + (int) (Math.random() * 400)); // simulated work
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown();
                }
            }, "Task-" + taskId);

            tasks.add(task);
            executor.execute(task);
        }

        ScheduledExecutorService monitor = Executors.newSingleThreadScheduledExecutor();
        try {
            monitor.scheduleAtFixedRate(() -> {
                System.out.println("\n--- 监控快照 ---");
                Map<String, Object> metrics = executor.getMetrics();
                metrics.forEach((key, value) ->
                    System.out.printf("%-25s: %s%n", key, value));
            }, 0, 500, TimeUnit.MILLISECONDS);

            latch.await();
        } finally {
            monitor.shutdown(); // also on interruption, so the monitor thread cannot leak
        }

        System.out.println("\n任务级别指标:");
        tasks.stream()
            .map(MonitoredTask::getMetrics)
            .forEach(System.out::println);

        executor.shutdown();
        executor.awaitTermination(5, TimeUnit.SECONDS);

        System.out.println("\n最终监控指标:");
        executor.getMetrics().forEach((key, value) ->
            System.out.printf("%-25s: %s%n", key, value));
    }

    /** Grows/shrinks the core size from a monitor thread while tasks are being submitted. */
    private static void testDynamicAdjustment() throws InterruptedException {
        MonitoredThreadPoolExecutor executor = new MonitoredThreadPoolExecutor(
            2, 4, 60, TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(20), "DynamicPool"
        );

        Thread monitorThread = new Thread(() -> {
            try {
                for (int i = 0; i < 10; i++) {
                    Map<String, Object> metrics = executor.getMetrics();
                    int queueSize = (int) metrics.get("queueSize");
                    int activeCount = (int) metrics.get("activeCount");

                    // Grow the core size when the queue backs up; shrink it again when it drains.
                    if (queueSize > 15 && activeCount < executor.getMaximumPoolSize()) {
                        executor.adjustCorePoolSize(Math.min(
                            executor.getCorePoolSize() + 2,
                            executor.getMaximumPoolSize()
                        ));
                    } else if (queueSize < 5 && executor.getCorePoolSize() > 2) {
                        executor.adjustCorePoolSize(Math.max(2, executor.getCorePoolSize() - 1));
                    }

                    Thread.sleep(1000);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });

        // Start monitoring before submitting; starting it afterwards would only observe the
        // tail end of the load and never react to it.
        monitorThread.start();

        for (int i = 0; i < 50; i++) {
            executor.execute(() -> {
                try {
                    Thread.sleep(200);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });

            // Throttle submission: pause after every batch of 10 tasks.
            if (i % 10 == 9) {
                Thread.sleep(500);
            }
        }

        monitorThread.join();

        executor.shutdown();
        executor.awaitTermination(5, TimeUnit.SECONDS);
    }

    /** Runs the same 20-task load through pools backed by different work queues. */
    private static void testQueueStrategies() {
        System.out.println("测试不同队列策略的性能特征:\n");

        Map<String, BlockingQueue<Runnable>> queues = new LinkedHashMap<>();
        queues.put("SynchronousQueue", new SynchronousQueue<>());               // direct hand-off, no buffering
        queues.put("LinkedBlockingQueue(有界)", new LinkedBlockingQueue<>(10)); // bounded, linked
        queues.put("ArrayBlockingQueue", new ArrayBlockingQueue<>(10));         // bounded, array-backed
        queues.put("LinkedBlockingQueue(无界)", new LinkedBlockingQueue<>());    // unbounded: pool never grows past core size
        // Lambdas are not Comparable, so a PriorityBlockingQueue of Runnables needs a comparator
        // (all tasks get equal priority here); without one the first offer throws ClassCastException.
        queues.put("PriorityBlockingQueue", new PriorityBlockingQueue<Runnable>(11, (a, b) -> 0));

        for (Map.Entry<String, BlockingQueue<Runnable>> entry : queues.entrySet()) {
            System.out.println("测试队列: " + entry.getKey());

            ThreadPoolExecutor executor = new ThreadPoolExecutor(
                2, 4, 60, TimeUnit.SECONDS, entry.getValue()
            );

            long startTime = System.currentTimeMillis();
            CountDownLatch latch = new CountDownLatch(20);
            int rejected = 0;

            for (int j = 0; j < 20; j++) {
                try {
                    executor.execute(() -> {
                        try {
                            Thread.sleep(100);
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        } finally {
                            latch.countDown();
                        }
                    });
                } catch (RejectedExecutionException e) {
                    // Default AbortPolicy: bounded set-ups saturate at 4 threads + queue capacity.
                    // Count the task as done so the latch below cannot wait forever.
                    rejected++;
                    latch.countDown();
                }
            }

            try {
                latch.await();
                long duration = System.currentTimeMillis() - startTime;
                System.out.printf("  完成时间: %dms, 队列大小: %d, 最大线程数: %d, 拒绝任务数: %d%n",
                    duration, executor.getQueue().size(), executor.getLargestPoolSize(), rejected);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                executor.shutdown();
            }
        }
    }

    /** Submits 5 long tasks to a saturated single-thread pool under each built-in policy. */
    private static void testRejectionPolicies() {
        System.out.println("测试不同拒绝策略:\n");

        RejectedExecutionHandler[] policies = {
            new ThreadPoolExecutor.AbortPolicy(),
            new ThreadPoolExecutor.CallerRunsPolicy(),
            new ThreadPoolExecutor.DiscardPolicy(),
            new ThreadPoolExecutor.DiscardOldestPolicy()
        };

        String[] policyNames = {
            "AbortPolicy(抛出异常)",
            "CallerRunsPolicy(调用者运行)",
            "DiscardPolicy(静默丢弃)",
            "DiscardOldestPolicy(丢弃最老)"
        };

        for (int i = 0; i < policies.length; i++) {
            System.out.println("测试策略: " + policyNames[i]);

            // One worker plus a one-slot queue: the first task runs, the second waits, and the
            // rest hit the policy. A SynchronousQueue is deliberately avoided: with it,
            // DiscardOldestPolicy keeps re-submitting a task that can never be accepted and
            // recurses until StackOverflowError.
            ThreadPoolExecutor executor = new ThreadPoolExecutor(
                1, 1, 0, TimeUnit.SECONDS,
                new ArrayBlockingQueue<>(1),
                policies[i]
            );

            int submitted = 0;
            int rejected = 0;

            for (int j = 0; j < 5; j++) {
                try {
                    executor.execute(() -> {
                        try {
                            Thread.sleep(1000); // long-running task
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        }
                    });
                    submitted++;
                } catch (RejectedExecutionException e) {
                    rejected++;
                }
            }

            System.out.printf("  提交: %d, 拒绝: %d%n", submitted, rejected);
            executor.shutdown();
        }
    }
}

三、并发集合与原子类的内部机制

3.1 ConcurrentHashMap与CopyOnWriteArrayList深度分析

示例代码(Java):
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import java.util.concurrent.locks.*;
import java.lang.reflect.*;

public class ConcurrentCollectionsDeepDive {
    
    // 1. ConcurrentHashMap分段锁实现(模拟Java 7的实现方式)
    static class SegmentedHashMap<K, V> {
        // 分段(Segment)类
        static final class Segment<K, V> extends ReentrantLock {
            volatile HashEntry<K, V>[] table;
            volatile int count;
            int modCount;
            int threshold;
            final float loadFactor;
            
            Segment(int initialCapacity, float lf) {
                loadFactor = lf;
                setTable(new HashEntry[initialCapacity]);
                threshold = (int)(initialCapacity * loadFactor);
            }
            
            @SuppressWarnings("unchecked")
            void setTable(HashEntry<K, V>[] newTable) {
                table = newTable;
            }
            
            V put(K key, int hash, V value, boolean onlyIfAbsent) {
                lock();
                try {
                    HashEntry<K, V>[] tab = table;
                    int index = hash & (tab.length - 1);
                    HashEntry<K, V> first = tab[index];
                    
                    // 遍历链表
                    for (HashEntry<K, V> e = first; e != null; e = e.next) {
                        if (e.hash == hash && key.equals(e.key)) {
                            V oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                            }
                            return oldValue;
                        }
                    }
                    
                    // 添加到链表头部
                    modCount++;
                    tab[index] = new HashEntry<>(key, hash, first, value);
                    
                    // 检查是否需要扩容
                    if (++count > threshold) {
                        rehash();
                    }
                    
                    return null;
                } finally {
                    unlock();
                }
            }
            
            @SuppressWarnings("unchecked")
            private void rehash() {
                HashEntry<K, V>[] oldTable = table;
                int oldCapacity = oldTable.length;
                int newCapacity = oldCapacity << 1;
                HashEntry<K, V>[] newTable = new HashEntry[newCapacity];
                int newThreshold = (int)(newCapacity * loadFactor);
                
                for (int i = 0; i < oldCapacity; i++) {
                    HashEntry<K, V> e = oldTable[i];
                    while (e != null) {
                        HashEntry<K, V> next = e.next;
                        int newIndex = e.hash & (newCapacity - 1);
                        e.next = newTable[newIndex];
                        newTable[newIndex] = e;
                        e = next;
                    }
                }
                
                table = newTable;
                threshold = newThreshold;
            }
        }
        
        // 哈希条目
        static final class HashEntry<K, V> {
            final K key;
            final int hash;
            volatile V value;
            final HashEntry<K, V> next;
            
            HashEntry(K key, int hash, HashEntry<K, V> next, V value) {
                this.key = key;
                this.hash = hash;
                this.next = next;
                this.value = value;
            }
        }
        
        // ConcurrentHashMap主体
        private final Segment<K, V>[] segments;
        private static final int DEFAULT_CONCURRENCY_LEVEL = 16;
        private static final int DEFAULT_INITIAL_CAPACITY = 16;
        private static final float DEFAULT_LOAD_FACTOR = 0.75f;
        
        @SuppressWarnings("unchecked")
        public SegmentedHashMap() {
            this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
        }
        
        @SuppressWarnings("unchecked")
        public SegmentedHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
            // 确保分段数是2的幂
            int sshift = 0;
            int ssize = 1;
            while (ssize < concurrencyLevel) {
                ++sshift;
                ssize <<= 1;
            }
            
            segments = new Segment[ssize];
            int segmentCapacity = initialCapacity / ssize;
            if (segmentCapacity * ssize < initialCapacity) {
                ++segmentCapacity;
            }
            
            int segmentSize = 1;
            while (segmentSize < segmentCapacity) {
                segmentSize <<= 1;
            }
            
            for (int i = 0; i < segments.length; i++) {
                segments[i] = new Segment<>(segmentSize, loadFactor);
            }
        }
        
        private Segment<K, V> segmentFor(int hash) {
            return segments[(hash >>> 28) & (segments.length - 1)];
        }
        
        public V put(K key, V value) {
            int hash = hash(key.hashCode());
            return segmentFor(hash).put(key, hash, value, false);
        }
        
        private static int hash(int h) {
            h += (h << 15) ^ 0xffffcd7d;
            h ^= (h >>> 10);
            h += (h << 3);
            h ^= (h >>> 6);
            h += (h << 2) + (h << 14);
            return h ^ (h >>> 16);
        }
    }
    
    // 2. CopyOnWriteArrayList的写时复制机制分析
    static class CopyOnWriteAnalysis {

        /**
         * Shows that iterating a {@link CopyOnWriteArrayList} works on a snapshot. Adding an
         * element mid-iteration neither throws nor shows up in the ongoing loop, and each write
         * installs a new backing array.
         *
         * <p>The backing-array identity is read reflectively from the private field
         * {@code array}, which is a JDK implementation detail. On JDK 16+ strong encapsulation of
         * {@code java.base} makes {@code setAccessible} fail by default. In that case only the
         * identity printouts are skipped; the iteration and concurrency demos still run.
         */
        public static void analyzeCOWBehavior() {
            System.out.println("\n=== CopyOnWriteArrayList行为分析 ===\n");

            CopyOnWriteArrayList<String> list = new CopyOnWriteArrayList<>();
            list.add("A");
            list.add("B");
            list.add("C");

            System.out.println("初始列表: " + list);

            // Reference to the internal array (via reflection, if the JDK permits it).
            Field arrayField = findArrayField();
            Object[] originalArray = readArray(arrayField, list);
            if (originalArray != null) {
                System.out.println("内部数组引用: " + System.identityHashCode(originalArray));
            } else {
                System.out.println("(无法通过反射访问内部数组,跳过数组引用比较)");
            }

            // Modify the list while iterating: the iterator keeps walking its snapshot.
            System.out.println("\n遍历过程中添加元素:");
            int iteration = 0;
            for (String item : list) {
                System.out.printf("  迭代%d: 元素=%s%n", ++iteration, item);
                if (iteration == 2) {
                    list.add("D");
                    System.out.println("  添加了新元素D");

                    // Check whether the array was copied.
                    Object[] newArray = readArray(arrayField, list);
                    if (originalArray != null && newArray != null) {
                        System.out.printf("  数组引用变化: %d -> %d%n",
                            System.identityHashCode(originalArray),
                            System.identityHashCode(newArray));
                    }
                }
            }

            System.out.println("\n遍历完成后的列表: " + list);

            // Concurrent reader/writer test.
            System.out.println("\n并发修改测试:");
            try {
                testConcurrentModification();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Returns the accessible {@code array} field of CopyOnWriteArrayList, or {@code null} if
         * it does not exist or reflection into {@code java.base} is denied.
         */
        private static Field findArrayField() {
            try {
                Field field = CopyOnWriteArrayList.class.getDeclaredField("array");
                field.setAccessible(true);
                return field;
            } catch (NoSuchFieldException | RuntimeException e) {
                // RuntimeException covers InaccessibleObjectException (JDK 9+) and SecurityException.
                return null;
            }
        }

        /** Reads the backing array through {@code arrayField}; returns {@code null} if unavailable. */
        private static Object[] readArray(Field arrayField, CopyOnWriteArrayList<?> list) {
            if (arrayField == null) {
                return null;
            }
            try {
                return (Object[]) arrayField.get(list);
            } catch (IllegalAccessException e) {
                return null;
            }
        }

        /**
         * Runs one reader thread that prints the list every 100 ms (10 times) and one writer
         * thread that appends an element every 200 ms (5 times), then waits for both.
         *
         * @throws InterruptedException if interrupted while joining the threads
         */
        private static void testConcurrentModification() throws InterruptedException {
            CopyOnWriteArrayList<Integer> list = new CopyOnWriteArrayList<>();
            for (int i = 0; i < 5; i++) {
                list.add(i);
            }

            // Reader thread: toString() iterates a snapshot, so it never sees a torn list.
            Thread reader = new Thread(() -> {
                try {
                    for (int i = 0; i < 10; i++) {
                        System.out.println("读线程: " + list);
                        Thread.sleep(100);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });

            // Writer thread: the only mutator, so size() followed by add() is consistent here.
            Thread writer = new Thread(() -> {
                try {
                    for (int i = 0; i < 5; i++) {
                        Thread.sleep(200);
                        // Compute the value once so the log shows exactly what was added.
                        int value = list.size() * 10;
                        list.add(value);
                        System.out.println("写线程添加: " + value);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });

            reader.start();
            writer.start();

            reader.join();
            writer.join();
        }
    }
    
    // 3. 原子类的CAS操作与ABA问题
    static class AtomicAnalysis {
        
        // ABA问题演示
        static class ABAProblem {

            /**
             * Prints a deterministic ABA scenario on a fresh {@link AtomicReference} holding "A".
             *
             * <p>Thread 1 snapshots the value, and thread 2 then changes it A -> B -> A. Only
             * after that does thread 1 attempt CAS(snapshot -> "C"). The CAS compares references
             * only, so it succeeds even though the value changed in between; that is the ABA
             * problem.
             */
            public static void demonstrate() {
                System.out.println("\n=== ABA问题演示 ===\n");

                AtomicReference<String> ref = new AtomicReference<>("A");
                boolean staleCasSucceeded = runScenario(ref);

                System.out.println("\n最终值: " + ref.get());
                if (staleCasSucceeded) {
                    System.out.println("尽管值从A->B->A变化了,但CAS操作仍然成功,这就是ABA问题!");
                }
            }

            /**
             * Runs the ABA race on {@code ref} in a fixed order enforced by latches rather than
             * sleeps.
             *
             * @param ref reference to operate on; expected to hold a non-null value
             * @return whether thread 1's CAS, based on its stale snapshot, succeeded
             */
            static boolean runScenario(AtomicReference<String> ref) {
                CountDownLatch snapshotTaken = new CountDownLatch(1);
                CountDownLatch abaDone = new CountDownLatch(1);
                AtomicBoolean staleCasSucceeded = new AtomicBoolean(false);

                Thread t1 = new Thread(() -> {
                    String expected = ref.get();
                    System.out.println("线程1: 读取到 " + expected + ", 稍后尝试将 " + expected + " -> C");
                    snapshotTaken.countDown();
                    try {
                        abaDone.await(); // let thread 2 finish A -> B -> A first
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        return;
                    }
                    boolean success = ref.compareAndSet(expected, "C");
                    staleCasSucceeded.set(success);
                    System.out.println("线程1: CAS结果 = " + success);
                });

                Thread t2 = new Thread(() -> {
                    try {
                        snapshotTaken.await();
                        // Restore the very same object so the reference comparison matches again.
                        String original = ref.get();
                        String intermediate = "B";
                        System.out.println("线程2: 将 " + original + " -> B");
                        ref.compareAndSet(original, intermediate);
                        System.out.println("线程2: 将 B -> " + original);
                        ref.compareAndSet(intermediate, original);
                        System.out.println("线程2: 已恢复为 " + original);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } finally {
                        abaDone.countDown(); // never leave thread 1 waiting
                    }
                });

                t1.start();
                t2.start();

                try {
                    t1.join();
                    t2.join();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                return staleCasSucceeded.get();
            }
        }
        
        // 解决ABA问题的AtomicStampedReference
        static class ABASolution {

            /**
             * Replays the ABA scenario on an {@link AtomicStampedReference} ("A", stamp 0).
             *
             * <p>Thread 2's A -> B -> A changes bump the stamp to 2, so thread 1's CAS, which still
             * expects stamp 0, fails. The version stamp exposes the intermediate writes that a
             * plain reference CAS cannot see.
             */
            public static void demonstrate() {
                System.out.println("\n=== 使用AtomicStampedReference解决ABA问题 ===\n");

                AtomicStampedReference<String> ref = new AtomicStampedReference<>("A", 0);
                int[] stampHolder = new int[1];
                String current = ref.get(stampHolder);
                System.out.println("初始值: " + current + ", 版本: " + stampHolder[0]);

                boolean staleCasSucceeded = runScenario(ref);

                String finalValue = ref.get(stampHolder);
                System.out.println("\n最终值: " + finalValue + ", 最终版本: " + stampHolder[0]);
                if (!staleCasSucceeded) {
                    System.out.println("通过版本戳,成功检测到了中间的状态变化!");
                }
            }

            /**
             * Runs the race on {@code ref}. Latches force this order: thread 1 snapshots
             * value+stamp, thread 2 then performs A -> B -> A (stamp +2), and finally thread 1
             * attempts its CAS.
             *
             * @param ref reference to operate on; expected to hold a non-null value
             * @return whether thread 1's stale-stamp CAS succeeded (false when the stamp works)
             */
            static boolean runScenario(AtomicStampedReference<String> ref) {
                CountDownLatch snapshotTaken = new CountDownLatch(1);
                CountDownLatch abaDone = new CountDownLatch(1);
                AtomicBoolean staleCasSucceeded = new AtomicBoolean(false);

                Thread t1 = new Thread(() -> {
                    int[] holder = new int[1];
                    String expected = ref.get(holder);
                    int expectedStamp = holder[0];
                    System.out.println("线程1: 读取值=" + expected + ", 版本=" + expectedStamp);
                    snapshotTaken.countDown();

                    try {
                        abaDone.await(); // simulated processing delay, ordered deterministically
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        return;
                    }

                    boolean success = ref.compareAndSet(
                        expected, "B",
                        expectedStamp, expectedStamp + 1
                    );
                    staleCasSucceeded.set(success);
                    System.out.println("线程1: CAS(版本" + expectedStamp + " -> " + (expectedStamp + 1) +
                                       ") 结果 = " + success);
                });

                Thread t2 = new Thread(() -> {
                    try {
                        snapshotTaken.await();
                        int[] holder = new int[1];
                        String original = ref.get(holder);
                        int stamp = holder[0];
                        System.out.println("线程2: 当前值=" + original + ", 版本=" + stamp);

                        // Change to B, then back to the original object; each write bumps the stamp.
                        String intermediate = "B";
                        if (ref.compareAndSet(original, intermediate, stamp, stamp + 1)) {
                            System.out.println("线程2: 修改为 B, 版本=" + (stamp + 1));
                        }
                        if (ref.compareAndSet(intermediate, original, stamp + 1, stamp + 2)) {
                            System.out.println("线程2: 修改回 A, 版本=" + (stamp + 2));
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } finally {
                        abaDone.countDown(); // never leave thread 1 waiting
                    }
                });

                t1.start();
                t2.start();

                try {
                    t1.join();
                    t2.join();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                return staleCasSucceeded.get();
            }
        }
        
        // 高性能计数器比较
        static class CounterBenchmark {
            /** Minimal counter abstraction shared by the three implementations under test. */
            interface Counter {
                void increment();
                long get();
            }

            /** Counter guarded by the object's intrinsic lock. */
            static class SynchronizedCounter implements Counter {
                private long value = 0;

                @Override
                public synchronized void increment() {
                    value++;
                }

                @Override
                public synchronized long get() {
                    return value;
                }
            }

            /** Counter backed by a single CAS-updated {@link AtomicLong}. */
            static class AtomicCounter implements Counter {
                private final AtomicLong value = new AtomicLong(0);

                @Override
                public void increment() {
                    value.incrementAndGet();
                }

                @Override
                public long get() {
                    return value.get();
                }
            }

            /** Counter backed by {@link LongAdder} (Java 8+); get() sums its internal cells. */
            static class LongAdderCounter implements Counter {
                private final LongAdder value = new LongAdder();

                @Override
                public void increment() {
                    value.increment();
                }

                @Override
                public long get() {
                    return value.sum();
                }
            }

            /** Runs the default workload: 10 threads x 1,000,000 increments per counter. */
            public static void benchmark() throws InterruptedException {
                benchmark(10, 1000000);
            }

            /**
             * Times {@code threadCount} threads each incrementing a shared counter
             * {@code iterations} times, once per counter implementation, and prints the result.
             *
             * <p>Elapsed time comes from {@link System#nanoTime()}, which is monotonic. There is
             * no JIT warm-up phase, so treat the numbers as rough relative indications.
             *
             * @throws IllegalArgumentException if threadCount is not positive or iterations is negative
             * @throws InterruptedException     if interrupted while waiting for the worker threads
             */
            public static void benchmark(int threadCount, int iterations) throws InterruptedException {
                if (threadCount <= 0 || iterations < 0) {
                    throw new IllegalArgumentException(
                        "threadCount=" + threadCount + ", iterations=" + iterations);
                }
                System.out.println("\n=== 高性能计数器基准测试 ===\n");

                Counter[] counters = {
                    new SynchronizedCounter(),
                    new AtomicCounter(),
                    new LongAdderCounter()
                };

                String[] names = {
                    "SynchronizedCounter",
                    "AtomicCounter",
                    "LongAdderCounter"
                };

                // Computed in long arithmetic so large workloads cannot overflow int.
                long totalOps = (long) threadCount * iterations;

                for (int i = 0; i < counters.length; i++) {
                    System.out.println("测试: " + names[i]);

                    long elapsedNanos = timeIncrements(counters[i], threadCount, iterations);
                    // Clamp to 1 ns so a tiny run cannot divide by zero.
                    double seconds = Math.max(elapsedNanos, 1L) / 1e9;

                    System.out.printf("  结果: %,d, 时间: %,d ms, 吞吐量: %,d ops/s%n",
                        counters[i].get(),
                        TimeUnit.NANOSECONDS.toMillis(elapsedNanos),
                        (long) (totalOps / seconds));
                }
            }

            /** Starts the worker threads, waits for all of them, and returns the elapsed nanoseconds. */
            private static long timeIncrements(Counter counter, int threadCount, int iterations)
                    throws InterruptedException {
                List<Thread> threads = new ArrayList<>(threadCount);
                for (int t = 0; t < threadCount; t++) {
                    threads.add(new Thread(() -> {
                        for (int j = 0; j < iterations; j++) {
                            counter.increment();
                        }
                    }));
                }

                long start = System.nanoTime();
                for (Thread thread : threads) {
                    thread.start();
                }
                for (Thread thread : threads) {
                    thread.join();
                }
                return System.nanoTime() - start;
            }
        }
        
        /**
         * Runs the atomic-class demos one after another: the ABA problem with a plain
         * AtomicReference, the AtomicStampedReference countermeasure, then the counter
         * throughput comparison.
         *
         * @throws InterruptedException if interrupted while the counter benchmark waits for its threads
         */
        public static void runAll() throws InterruptedException {
            ABAProblem.demonstrate();      // CAS on an AtomicReference
            ABASolution.demonstrate();     // CAS guarded by a version stamp
            CounterBenchmark.benchmark();  // synchronized vs AtomicLong vs LongAdder
        }
    }
    
    // 4. 并发队列的性能比较
    static class ConcurrentQueueComparison {
        
        static class Producer implements Runnable {
            private final BlockingQueue<Integer> queue;
            private final int count;
            private final CountDownLatch latch;
            
            Producer(BlockingQueue<Integer> queue, int count, CountDownLatch latch) {
                this.queue = queue;
                this.count = count;
                this.latch = latch;
            }
            
            @Override
            public void run() {
                try {
                    for (int i = 0; i < count; i++) {
                        queue.put(i);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown();
                }
            }
        }
        
        static class Consumer implements Runnable {
            private final BlockingQueue<Integer> queue;
            private final CountDownLatch latch;
            private final LongAdder counter;
            
            Consumer(BlockingQueue<Integer> queue, CountDownLatch latch, LongAdder counter) {
                this.queue = queue;
                this.latch = latch;
                this.counter = counter;
            }
            
            @Override
            public void run() {
                try {
                    while (!Thread.currentThread().isInterrupted()) {
                        Integer value = queue.poll(10, TimeUnit.MILLISECONDS);
                        if (value != null) {
                            counter.increment();
                        }
                        
                        // 检查是否所有生产者都完成了
                        if (latch.getCount() == 0 && queue.isEmpty()) {
                            break;
                        }
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
        }
        
        public static void compare() throws InterruptedException {
            System.out.println("\n=== 并发队列性能比较 ===\n");
            
            BlockingQueue<Integer>[] queues = new BlockingQueue[] {
                new LinkedBlockingQueue<>(10000),
                new ArrayBlockingQueue<>(10000),
                new ConcurrentLinkedQueue<>(),
                new LinkedTransferQueue<>(),
                new PriorityBlockingQueue<>(10000)
            };
            
            String[] queueNames = {
                "LinkedBlockingQueue",
                "ArrayBlockingQueue",
                "ConcurrentLinkedQueue",
                "LinkedTransferQueue",
                "PriorityBlockingQueue"
            };
            
            int producerCount = 5;
            int consumerCount = 5;
            int itemsPerProducer = 200000;
            
            for (int i = 0; i < queues.length; i++) {
                System.out.println("测试队列: " + queueNames[i]);
                
                BlockingQueue<Integer> queue = queues[i];
                CountDownLatch producerLatch = new CountDownLatch(producerCount);
                LongAdder consumedCounter = new LongAdder();
                
                // 创建生产者
                List<Thread> producers = new ArrayList<>();
                for (int p = 0; p < producerCount; p++) {
                    producers.add(new Thread(
                        new Producer(queue, itemsPerProducer, producerLatch)
                    ));
                }
                
                // 创建消费者
                List<Thread> consumers = new ArrayList<>();
                for (int c = 0; c < consumerCount; c++) {
                    consumers.add(new Thread(
                        new Consumer(queue, producerLatch, consumedCounter)
                    ));
                }
                
                long startTime = System.currentTimeMillis();
                
                // 启动所有线程
                consumers.forEach(Thread::start);
                producers.forEach(Thread::start);
                
                // 等待生产者完成
                producerLatch.await();
                
                // 等待消费者处理完所有项目
                while (consumedCounter.sum() < producerCount * itemsPerProducer) {
                    Thread.sleep(10);
                }
                
                long endTime = System.currentTimeMillis();
                
                // 中断消费者线程
                consumers.forEach(Thread::interrupt);
                for (Thread consumer : consumers) {
                    consumer.join(1000);
                }
                
                long duration = endTime - startTime;
                long totalItems = producerCount * itemsPerProducer;
                
                System.out.printf("  生产: %,d 项目, 消费: %,d 项目%n", 
                    totalItems, consumedCounter.sum());
                System.out.printf("  时间: %,d ms, 吞吐量: %,d ops/s%n",
                    duration, (totalItems * 1000L) / duration);
                System.out.printf("  队列最大大小: %,d%n", queue.size());
            }
        }
    }
    
    /**
     * Entry point: runs every demo of this class in sequence (copy-on-write list, atomic
     * classes, blocking-queue comparison, segmented hash map).
     */
    public static void main(String[] args) throws Exception {
        System.out.println("=== 并发集合与原子类深度分析 ===\n");

        CopyOnWriteAnalysis.analyzeCOWBehavior();   // CopyOnWriteArrayList behaviour
        AtomicAnalysis.runAll();                    // atomic classes and CAS
        ConcurrentQueueComparison.compare();        // concurrent queue comparison

        System.out.println("\n=== 分段HashMap测试 ===");
        testSegmentedHashMap();                     // segmented hash map under concurrency
    }
    
    /**
     * Concurrency smoke test for SegmentedHashMap: 10 threads each insert 10,000 distinct keys.
     *
     * <p>A {@code put} that returns {@code null} created a new mapping, so the printed count is
     * the number of fresh insertions. With distinct keys it should be 100,000.
     */
    private static void testSegmentedHashMap() {
        SegmentedHashMap<String, Integer> map = new SegmentedHashMap<>();

        // Concurrent insertion test.
        int threadCount = 10;
        int operationsPerThread = 10000;

        List<Thread> threads = new ArrayList<>(threadCount);
        AtomicInteger successCount = new AtomicInteger();

        for (int t = 0; t < threadCount; t++) {
            // Stable per-thread key prefix (Thread#getId is deprecated since JDK 19).
            final int threadIndex = t;
            threads.add(new Thread(() -> {
                for (int i = 0; i < operationsPerThread; i++) {
                    String key = "key-" + threadIndex + "-" + i;
                    if (map.put(key, i) == null) {
                        successCount.incrementAndGet();
                    }
                }
            }));
        }

        threads.forEach(Thread::start);
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers; the count would be incomplete anyway.
                Thread.currentThread().interrupt();
                System.out.println("等待插入线程时被中断");
                return;
            }
        }

        System.out.printf("并发插入完成,成功操作数: %,d%n", successCount.get());
    }
}

总结:并发编程的最佳实践

关键要点:

  1. 理解AQS原理:ReentrantLock、ReentrantReadWriteLock、Semaphore、CountDownLatch等大多数同步器都基于AQS构建,理解其等待队列和state状态管理是基础
  2. 合理配置线程池:根据任务特性(CPU密集型 vs I/O密集型)选择合适的参数
  3. 监控与调优:使用监控工具了解线程池状态,动态调整参数
  4. 避免锁竞争:使用并发集合、原子类减少锁竞争
  5. 注意内存可见性:正确使用volatile和原子操作

工具选择指南:

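| 场景 | 推荐工具 | 理由 |
| --- | --- | --- |
| 高并发计数器 | LongAdder | 高竞争下吞吐量高于AtomicLong;但sum()不是原子快照,需要精确的实时读数时仍应使用AtomicLong |
| 读多写少的集合 | CopyOnWriteArrayList | 读操作完全无锁;代价是每次写操作都会复制整个数组 |
| 高并发映射 | ConcurrentHashMap | Java 7采用分段锁;Java 8+采用CAS + synchronized锁定桶的头节点 |
| 任务调度 | ScheduledThreadPoolExecutor | 支持延迟执行和周期性定时任务 |
| 生产者消费者 | LinkedBlockingQueue | 可设置容量上限,提供阻塞的put/take操作 |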

性能陷阱与解决方案:

  1. 线程池使用无界队列:任务无限堆积可能导致内存溢出(OOM),应使用有界队列并配置合适的拒绝策略
  2. 锁竞争激烈:使用读写锁或并发集合
  3. 上下文切换过多:减少平台线程数量;I/O密集型场景可使用虚拟线程(Java 21+)
  4. CAS的ABA问题:使用AtomicStampedReference
相关推荐
xing-xing7 小时前
Java集合Map总结
java
古城小栈7 小时前
性能边界:何时用 Go 何时用 Java 的技术选型指南
java·后端·golang
古城小栈7 小时前
Go 异步编程:无锁数据结构实现原理
java·数据结构·golang
黄旺鑫7 小时前
系统安全设计规范 · 短信风控篇【参考】
java·经验分享·系统·验证码·设计规范·短信·风控
算法与双吉汉堡8 小时前
【短链接项目笔记】Day1 用户模块
java·spring boot·笔记·后端
一念一花一世界8 小时前
Arbess从基础到实践(23) - 集成GitLab+Hadess实现Java项目构建并上传制品
java·gitlab·cicd·arbess·制品库
啃火龙果的兔子8 小时前
Java 学习路线及学习周期
java·开发语言·学习
Selegant8 小时前
Quarkus vs Spring Boot:谁更适合云原生时代的 Java 开发?
java·spring boot·云原生