基于 AQS 快速实现可重入锁

在 Java 中, AQS(抽象的队列同步器) 是重量级基础框架及整个 JUC 体系的基石,主要是用于解决锁分配给谁的问题。

其制定了一套多线程场景下访问共享资源的方案,其统一规范并简化了锁的实现,屏蔽了同步状态管理、同步队列的管理和维护、阻塞线程排队和通知、唤醒机制等,Java 中很多同步类底层都是使用 AQS 实现,比如:ReentrantLockCountDownLatchReentrantReadWriteLock,它们里头有一个 Sync 类的属性,该类是继承 AQS 的。

简单来说,AQS 主要是维护着抽象的FIFO队列来完成资源获取线程的排队工作,并通过一个 int 类型变量 state 来表示持有锁的状态,state 是通过 volatile 修饰的,确保可见性。

为了加深对 AQS 的理解,可以基于 AQS 自实现一个可重入锁,类似于 ReentrantLock 的简化版实现。

继承 AbstractQueuedSynchronizer,主要是需要实现 tryAcquire 和 tryRelease 这两个钩子,分别是加锁和释放锁的逻辑。

主要实现

java 复制代码
package com.txing.aqs;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * 基于AQS实现的可重入锁
 * 类似于ReentrantLock的简化版实现
 */
public class CustomReentrantLock implements Lock {
    
    // 内部同步器类,继承自AQS
    private static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = -5179523762034025860L;
        
        // 是否为独占模式
        protected final boolean isHeldExclusively() {
            return getState() > 0 && getExclusiveOwnerThread() == Thread.currentThread();
        }
        
        // 尝试获取锁(非公平模式)
        protected final boolean tryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            
            // 如果锁未被持有
            if (c == 0) {
                // 使用CAS尝试设置状态并设置独占线程
                if (compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } 
            // 如果当前线程已经持有锁,实现可重入
            else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                // 检查是否溢出
                if (nextc < 0) {
                    throw new Error("Maximum lock count exceeded");
                }
                setState(nextc);
                return true;
            }
            return false;
        }
        
        // 尝试释放锁
        protected final boolean tryRelease(int releases) {
            int c = getState() - releases;
            
            // 如果当前线程不是锁的持有者,抛出异常
            if (Thread.currentThread() != getExclusiveOwnerThread()) {
                throw new IllegalMonitorStateException();
            }
            
            boolean free = false;
            // 如果状态为0,表示锁已完全释放
            if (c == 0) {
                free = true;
                setExclusiveOwnerThread(null);
            }
            setState(c);
            return free;
        }
        
        // 创建新的条件变量
        Condition newCondition() {
            return new ConditionObject();
        }
        
        // 获取锁的状态,供外部使用
        final int getHoldCount() {
            return getState();
        }
    }
    
    // 同步器实例
    private final Sync sync = new Sync();
    
    @Override
    public void lock() {
        sync.acquire(1);
    }
    
    @Override
    public boolean tryLock() {
        return sync.tryAcquire(1);
    }
    
    @Override
    public void unlock() {
        sync.release(1);
    }
    
    @Override
    public Condition newCondition() {
        return sync.newCondition();
    }
    
    @Override
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireInterruptibly(1);
    }
    
    @Override
    public boolean tryLock(long timeout, TimeUnit unit) throws InterruptedException {
        return sync.tryAcquireNanos(1, unit.toNanos(timeout));
    }
    
    // 获取当前锁状态,主要用于测试
    public int getHoldCount() {
        return sync.getHoldCount();
    }
    
    // 判断当前线程是否持有锁,主要用于测试
    public boolean isHeldByCurrentThread() {
        return sync.isHeldExclusively();
    }
    
    // 判断锁是否被任何线程持有,主要用于测试
    public boolean isLocked() {
        return sync.getHoldCount() > 0;
    }
} 

使用示例

java 复制代码
package com.txing.aqs;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * 自定义锁使用示例
 */
public class LockExample {
    // 共享资源
    private static int counter = 0;
    // 自定义锁
    private static final CustomReentrantLock lock = new CustomReentrantLock();
    
    public static void main(String[] args) throws InterruptedException {
        int numThreads = 10;
        int incrementsPerThread = 1000;
        
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        CountDownLatch latch = new CountDownLatch(numThreads);
        
        System.out.println("开始测试自定义锁...");
        long startTime = System.currentTimeMillis();
        
        // 创建多个线程,每个线程对计数器进行递增操作
        for (int i = 0; i < numThreads; i++) {
            executor.submit(() -> {
                try {
                    for (int j = 0; j < incrementsPerThread; j++) {
                        // 使用锁保护临界区
                        lock.lock();
                        try {
                            // 临界区:递增计数器
                            counter++;
                            
                            // 演示可重入性
                            incrementNestedCounter();
                        } finally {
                            lock.unlock();
                        }
                    }
                } finally {
                    latch.countDown();
                }
            });
        }
        
        // 等待所有线程完成
        latch.await();
        long endTime = System.currentTimeMillis();
        
        // 关闭线程池
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.SECONDS);
        
        // 输出结果
        System.out.println("预期计数器值: " + (numThreads * incrementsPerThread * 2));
        System.out.println("实际计数器值: " + counter);
        System.out.println("执行时间: " + (endTime - startTime) + " ms");
    }
    
    // 演示锁的可重入性
    private static void incrementNestedCounter() {
        // 再次获取锁(可重入)
        lock.lock();
        try {
            // 这里可以安全地访问共享资源,因为锁是可重入的
            counter++;
        } finally {
            lock.unlock();
        }
    }
} 

单元测试

java 复制代码
package com.txing.aqs;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.jupiter.api.Assertions.*;

/**
 * CustomReentrantLock的单元测试类
 */
public class CustomReentrantLockTest {

    /**
     * 测试基本的锁获取和释放功能
     */
    @Test
    public void testBasicLockUnlock() {
        CustomReentrantLock lock = new CustomReentrantLock();
        
        // 获取锁
        lock.lock();
        assertTrue(lock.isLocked());
        assertTrue(lock.isHeldByCurrentThread());
        assertEquals(1, lock.getHoldCount());
        
        // 释放锁
        lock.unlock();
        assertFalse(lock.isLocked());
        assertFalse(lock.isHeldByCurrentThread());
        assertEquals(0, lock.getHoldCount());
    }
    
    /**
     * 测试锁的可重入性
     */
    @Test
    public void testReentrant() {
        CustomReentrantLock lock = new CustomReentrantLock();
        
        // 第一次获取锁
        lock.lock();
        assertEquals(1, lock.getHoldCount());
        
        // 第二次获取锁(可重入)
        lock.lock();
        assertEquals(2, lock.getHoldCount());
        
        // 第一次释放
        lock.unlock();
        assertEquals(1, lock.getHoldCount());
        assertTrue(lock.isLocked());
        
        // 第二次释放
        lock.unlock();
        assertEquals(0, lock.getHoldCount());
        assertFalse(lock.isLocked());
    }
    
    /**
     * 测试tryLock方法
     */
    @Test
    public void testTryLock() {
        CustomReentrantLock lock = new CustomReentrantLock();
        
        // 尝试获取锁应该成功
        assertTrue(lock.tryLock());
        assertTrue(lock.isLocked());
        
        // 释放锁
        lock.unlock();
        assertFalse(lock.isLocked());
    }
    
    /**
     * 测试带超时的tryLock方法
     */
    @Test
    public void testTryLockTimeout() throws InterruptedException {
        CustomReentrantLock lock = new CustomReentrantLock();
        
        // 尝试获取锁应该成功
        assertTrue(lock.tryLock(1000, TimeUnit.MILLISECONDS));
        assertTrue(lock.isLocked());
        
        // 释放锁
        lock.unlock();
        assertFalse(lock.isLocked());
    }
    
    /**
     * 测试多线程环境下的锁竞争
     */
    @Test
    @Timeout(value = 5, unit = TimeUnit.SECONDS)
    public void testMultiThreadedAccess() throws InterruptedException {
        final CustomReentrantLock lock = new CustomReentrantLock();
        final AtomicInteger counter = new AtomicInteger(0);
        final int numThreads = 10;
        final int incrementsPerThread = 1000;
        final CountDownLatch startLatch = new CountDownLatch(1);
        final CountDownLatch finishLatch = new CountDownLatch(numThreads);
        
        // 创建多个线程,每个线程对计数器进行递增操作
        for (int i = 0; i < numThreads; i++) {
            new Thread(() -> {
                try {
                    // 等待所有线程准备就绪
                    startLatch.await();
                    
                    for (int j = 0; j < incrementsPerThread; j++) {
                        lock.lock();
                        try {
                            // 临界区:递增计数器
                            counter.incrementAndGet();
                        } finally {
                            lock.unlock();
                        }
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    finishLatch.countDown();
                }
            }).start();
        }
        
        // 启动所有线程
        startLatch.countDown();
        
        // 等待所有线程完成
        finishLatch.await();
        
        // 验证计数器的值是否正确
        assertEquals(numThreads * incrementsPerThread, counter.get());
    }
    
    /**
     * 测试锁的互斥性
     */
    @Test
    @Timeout(value = 5, unit = TimeUnit.SECONDS)
    public void testMutualExclusion() throws InterruptedException {
        final CustomReentrantLock lock = new CustomReentrantLock();
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicInteger lockHolders = new AtomicInteger(0);
        final AtomicInteger maxConcurrentHolders = new AtomicInteger(0);
        
        // 创建第一个线程,获取锁并持有一段时间
        Thread t1 = new Thread(() -> {
            lock.lock();
            try {
                // 记录当前持有锁的线程数
                int holders = lockHolders.incrementAndGet();
                maxConcurrentHolders.set(Math.max(maxConcurrentHolders.get(), holders));
                
                // 通知第二个线程开始尝试获取锁
                latch.countDown();
                
                // 持有锁一段时间
                try {
                    Thread.sleep(500);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                
                // 减少持有锁的线程计数
                lockHolders.decrementAndGet();
            } finally {
                lock.unlock();
            }
        });
        
        // 创建第二个线程,尝试获取锁
        Thread t2 = new Thread(() -> {
            try {
                // 等待第一个线程获取锁
                latch.await();
                
                // 尝试获取锁
                lock.lock();
                try {
                    // 记录当前持有锁的线程数
                    int holders = lockHolders.incrementAndGet();
                    maxConcurrentHolders.set(Math.max(maxConcurrentHolders.get(), holders));
                    
                    // 减少持有锁的线程计数
                    lockHolders.decrementAndGet();
                } finally {
                    lock.unlock();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        
        // 启动线程
        t1.start();
        t2.start();
        
        // 等待线程完成
        t1.join();
        t2.join();
        
        // 验证同一时间只有一个线程持有锁
        assertEquals(1, maxConcurrentHolders.get());
    }

    /**
     * 测试Condition条件变量功能
     */
    @Test
    @Timeout(value = 5, unit = TimeUnit.SECONDS)
    public void testCondition() throws InterruptedException {
        final CustomReentrantLock lock = new CustomReentrantLock();
        final java.util.concurrent.locks.Condition condition = lock.newCondition();
        final AtomicBoolean conditionMet = new AtomicBoolean(false);
        final AtomicBoolean threadWaited = new AtomicBoolean(false);
        final AtomicBoolean threadSignaled = new AtomicBoolean(false);
        
        // 创建等待线程
        Thread waiter = new Thread(() -> {
            lock.lock();
            try {
                // 如果条件未满足,等待
                if (!conditionMet.get()) {
                    threadWaited.set(true);
                    try {
                        // 等待条件变量的信号
                        condition.await();
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        return;
                    }
                }
                
                // 确认条件已满足
                assertTrue(conditionMet.get(), "条件应该已经满足");
            } finally {
                lock.unlock();
            }
        });
        
        // 创建通知线程
        Thread signaler = new Thread(() -> {
            // 给等待线程一点时间先运行
            try {
                Thread.sleep(200);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            
            lock.lock();
            try {
                // 设置条件已满足
                conditionMet.set(true);
                
                // 发送信号
                condition.signal();
                threadSignaled.set(true);
            } finally {
                lock.unlock();
            }
        });
        
        // 启动线程
        waiter.start();
        signaler.start();
        
        // 等待线程完成
        waiter.join();
        signaler.join();
        
        // 验证线程确实等待了并且收到了信号
        assertTrue(threadWaited.get(), "等待线程应该进入等待状态");
        assertTrue(threadSignaled.get(), "通知线程应该发送了信号");
    }
    
    /**
     * 测试Condition的多线程生产者-消费者模式
     */
    @Test
    @Timeout(value = 5, unit = TimeUnit.SECONDS)
    public void testProducerConsumerWithCondition() throws InterruptedException {
        final CustomReentrantLock lock = new CustomReentrantLock();
        final java.util.concurrent.locks.Condition notEmpty = lock.newCondition();
        final java.util.concurrent.locks.Condition notFull = lock.newCondition();
        
        final int BUFFER_SIZE = 5;
        final int ITEM_COUNT = 20;
        final int[] buffer = new int[BUFFER_SIZE];
        
        // 使用AtomicInteger来跟踪索引和计数,以解决Lambda表达式中的final变量问题
        final AtomicInteger putIndexRef = new AtomicInteger(0);
        final AtomicInteger takeIndexRef = new AtomicInteger(0);
        final AtomicInteger countRef = new AtomicInteger(0);
        
        final AtomicInteger produced = new AtomicInteger(0);
        final AtomicInteger consumed = new AtomicInteger(0);
        
        // 创建生产者线程
        Thread producer = new Thread(() -> {
            for (int i = 0; i < ITEM_COUNT; i++) {
                lock.lock();
                try {
                    // 缓冲区已满,等待
                    while (countRef.get() == BUFFER_SIZE) {
                        try {
                            notFull.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                    }
                    
                    // 放入元素
                    int putIndex = putIndexRef.get();
                    buffer[putIndex] = i;
                    putIndexRef.set((putIndex + 1) % BUFFER_SIZE);
                    countRef.incrementAndGet();
                    produced.incrementAndGet();
                    
                    // 通知消费者有新元素可用
                    notEmpty.signal();
                } finally {
                    lock.unlock();
                }
            }
        });
        
        // 创建消费者线程
        Thread consumer = new Thread(() -> {
            for (int i = 0; i < ITEM_COUNT; i++) {
                lock.lock();
                try {
                    // 缓冲区为空,等待
                    while (countRef.get() == 0) {
                        try {
                            notEmpty.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                    }
                    
                    // 取出元素
                    int takeIndex = takeIndexRef.get();
                    int item = buffer[takeIndex];
                    takeIndexRef.set((takeIndex + 1) % BUFFER_SIZE);
                    countRef.decrementAndGet();
                    consumed.incrementAndGet();
                    
                    // 通知生产者有空间可用
                    notFull.signal();
                } finally {
                    lock.unlock();
                }
            }
        });
        
        // 启动线程
        producer.start();
        consumer.start();
        
        // 等待线程完成
        producer.join();
        consumer.join();
        
        // 验证所有元素都被生产和消费
        assertEquals(ITEM_COUNT, produced.get(), "应该生产指定数量的元素");
        assertEquals(ITEM_COUNT, consumed.get(), "应该消费与生产相同数量的元素");
    }
} 

总结

通过实现自定义锁,可深入理解 AQS 的​​状态管理​ ​、​​线程排队​ ​和​​条件唤醒​​机制,为设计高性能并发组件奠定基础。

相关推荐
Andrew_Ryan2 分钟前
Chisel 工具使用教程
后端
AntBlack20 分钟前
体验了一把 paddleocr , 顺便撸了一个 桌面端 PDF识别工具
后端·python·计算机视觉
Xxtaoaooo24 分钟前
手撕Spring底层系列之:注解驱动的魔力与实现内幕
java·开发语言·后端开发·spring框架·原理解析
街霸星星33 分钟前
使用 vfox 高效配置 Java 开发环境:一份全面指南
java
用户15129054522037 分钟前
C 语言教程
前端·后端
♛暮辞38 分钟前
java程序远程写入字符串到hadoop伪分布式
java·hadoop·分布式
UestcXiye38 分钟前
Rust Web 全栈开发(十):编写服务器端 Web 应用
前端·后端·mysql·rust·actix
用户1512905452201 小时前
基于YOLOv10算法的交通信号灯检测与识别
后端
用户1512905452201 小时前
Netstat命令详解(windows下)
后端
用户1512905452201 小时前
GetTickCount() 函数的作用和用法
后端