Java基于数据库的分布式可重入锁(带等待时间和过期时间)

文章目录

项目代码

技术背景介绍

一般分布式锁使用最方便的就是使用redis实现,因为他自带超时过期机制、发布订阅模式、高吞吐高性能的优势,但是有些项目里只有mysql数据库,很多数据库都是没有数据超时过期机制和发布订阅模式的,当然也不是所有的,这里我只针对mysql数据库作为基础组件。

代码实现

数据库表结构

sql 复制代码
DROP TABLE IF EXISTS `distributed_lock`;
CREATE TABLE `distributed_lock` (
  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'id',
  `lock_name` varchar(255) NOT NULL COMMENT '锁名',
  `machine_id` varchar(255) DEFAULT NULL COMMENT '服务器id',
  `expire_time` datetime DEFAULT NULL COMMENT '过期时间,服务里会有一个看门狗续期,如果过期了就说明服务挂了,解锁会设置为空',
  `is_locked` tinyint(4) NOT NULL DEFAULT '0' COMMENT '当前是否锁定状态',
  `state` int(11) NOT NULL DEFAULT '0' COMMENT '锁标记位 类似次数',
  `thread_id` varchar(255) DEFAULT NULL COMMENT '当前获得锁的线程id',
  `gmt_create` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `gmt_modified` datetime DEFAULT NULL ON UPDATE CURRENT_TIMESTAMP COMMENT '修改时间',
  `is_deleted` tinyint(4) NOT NULL DEFAULT '0' COMMENT '是否删除',
  PRIMARY KEY (`id`) USING BTREE,
  UNIQUE KEY `idx_lock_name` (`lock_name`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

尝试获取锁

使用乐观锁模式更新锁记录。如果获取失败,则加入订阅列表中,等待被唤醒或者到达超时时间自动唤醒,待获取到锁后再从订阅列表中移除。他的具体等待时间取决于用户输入的等待时间和锁超时过期的时间,这里使用JUC的Semaphore来实现等待功能。

java 复制代码
public boolean tryLock(String lockName, Long waitTime, Long leaseTime, TimeUnit timeUnit) {
        long startTime = System.currentTimeMillis();
        String threadId = getCurrentThreadId();
        Long ttl = tryAcquire(lockName, leaseTime, timeUnit);
        // lock acquired
        if (ttl == null) {
            return true;
        }

        long time = timeUnit.toMillis(waitTime);
        if (waitTime != -1 && System.currentTimeMillis() - startTime < time) {
            //没有获取到锁,也没到等待时长,执行订阅释放锁的任务
            LockEntry lockEntry = subscribe(lockName, threadId, () -> {
            });

            try {
                while (true) {
                    ttl = tryAcquire(lockName, leaseTime, timeUnit);
                    // lock acquired
                    if (ttl == null) {
                        return true;
                    }
                    long remainTtl = time - System.currentTimeMillis() + startTime;
                    if (remainTtl < 0) {
                        return false;
                    }

                    // waiting for message
                    lockEntry.getLatch().tryAcquire(ttl >= 0 && ttl < remainTtl ? ttl : remainTtl, TimeUnit.MILLISECONDS);
                }
            } catch (InterruptedException e) {
                log.error("thread interrupted", e);
                throw new RuntimeException(e);
            } finally {
                unsubscribe(lockEntry, lockName);
            }
        } else {
            return false;
        }
    }

private Long tryAcquire(String lockName, long leaseTime, TimeUnit unit) {
        String currentThreadId = getCurrentThreadId();
        //设定了自动释放锁的时间
        if (leaseTime != -1) {
            return tryLockInner(leaseTime, unit, lockName, currentThreadId);
        }
        //没有设置自动过期时间,就需要在获取到之后使用看门狗续期
        Long remainTtl = tryLockInner(internalLockLeaseTime, TimeUnit.MILLISECONDS, lockName, currentThreadId);
        // lock acquired
        if (remainTtl == null) {
            scheduleExpirationRenewal(lockName, currentThreadId);
        }
        return remainTtl;
    }

    /**
     * 加锁成功返回null,否则返回锁的过期时间
     *
     * @param leaseTime
     * @param unit
     * @param lockName
     * @param threadId
     * @return
     */
    private Long tryLockInner(long leaseTime, TimeUnit unit, String lockName, String threadId) {
        long internalLockLeaseTime = unit.toMillis(leaseTime);

        //查询是否存在锁
        LockObject existLock = lockRepository.queryLock(lockName);
        LockObject lockObject = new LockObject();
        lockObject.setLockName(lockName);
        lockObject.setThreadId(threadId);
        lockObject.setMachineId(machineId);
        lockObject.setIsLocked(true);
        lockObject.setExpireTime(new Date(System.currentTimeMillis() + internalLockLeaseTime));
        if (existLock == null) {
            //保存锁
            lockObject.setState(1);
            try {
                lockRepository.save(lockObject);
            } catch (Exception e) {
                //抛出数据重复异常,说明被其他线程锁定了
                //返回需要等待的时间
                log.error("lock other thread occupy", e);
                return reCheckTtl(leaseTime, unit, lockName, threadId);
            }
        } else {
            //存在的锁会判断是否是当前线程的,如果是也允许加锁成功,支持可重入
            //如果正好其他锁释放了,那也会抢锁,具体是否公平由各数据库的内部锁决定
            int updateNum = lockRepository.reentrantLock(lockObject);
            if (updateNum == 0) {
                //返回需要等待的时间
                return reCheckTtl(leaseTime, unit, lockName, threadId);
            }
        }
        //加锁成功
        return null;
    }

    private Long reCheckTtl(long leaseTime, TimeUnit unit, String lockName, String threadId) {
        Long ttl = queryLockTtl(lockName);
        if (ttl == null) {
            //如果返回null,那就是获取锁的时候失败了,但是执行查询锁的过期时间的时候释放了
            //就需要重新执行上锁逻辑
            return tryLockInner(leaseTime, unit, lockName, threadId);
        } else {
            return ttl;
        }
    }

    /**
     * 获取锁的释放时间,单位毫秒,
     * 如果锁不存在 或者 未上锁 或者 已过期 则返回null
     *
     * @param lockName
     * @return
     */
    private Long queryLockTtl(String lockName) {
        LockObject lockObject = lockRepository.queryLock(lockName);
        if (lockObject != null && lockObject.getExpireTime() != null) {
            long intervalTime = lockObject.getExpireTime().getTime() - System.currentTimeMillis();
            if (intervalTime > 0) {
                return intervalTime;
            }
        }
        return null;
    }
sql 复制代码
<update id="updateReentrantLock">
        update distributed_lock
        <set>
            is_locked   = true,
            machine_id   = #{machineId,jdbcType=VARCHAR},
            thread_id   = #{threadId,jdbcType=VARCHAR},
            state       = if(expire_time &lt; NOW(), 1, state + 1),
            expire_time = #{expireTime,jdbcType=TIMESTAMP}
        </set>
        where is_deleted = 0
          and lock_name = #{lockName,jdbcType=VARCHAR}
          and (
                expire_time &lt; NOW()
                or is_locked = false
                or (machine_id = #{machineId,jdbcType=VARCHAR}
                and thread_id = #{threadId,jdbcType=VARCHAR})
            )
    </update>

续约

如果锁没有设置过期时间,那么就需要设置自动续期,使用过期和续期的目的也是为了防止服务宕机导致锁无法释放的问题。如果续期失败说明锁已经释放了,那么会自动停止锁的续约任务。

java 复制代码
private void scheduleExpirationRenewal(String lockName, String threadId) {
        ExpirationEntry entry = new ExpirationEntry(lockName, threadId);
        ExpirationEntry oldEntry = expirationRenewalMap.putIfAbsent(expirationRenewalKey(lockName, threadId), entry);
        if (oldEntry != null) {
            oldEntry.addCount();
        } else {
            //只对第一次获取锁的线程续约,后面的属于重入
            renewExpiration(lockName, threadId);
        }
    }

    private void renewExpiration(String lockName, String threadId) {
        String keyName = expirationRenewalKey(lockName, threadId);
        ExpirationEntry ee = expirationRenewalMap.get(keyName);
        if (ee == null) {
            return;
        }

        //获取到锁后过1/3时间开启续约任务
        scheduledExecutor.schedule(() -> {
            ExpirationEntry ent = expirationRenewalMap.get(keyName);
            if (ent == null) {
                return;
            }

            boolean renewResult = renewExpirationLock(lockName, ent.getThreadId());
            if (!renewResult) {
                //更新失败说明锁被释放了
                log.error("Can't update lock " + lockName + " expiration");
                expirationRenewalMap.remove(keyName);
                return;
            }
            // reschedule itself
            renewExpiration(lockName, threadId);
        }, internalLockLeaseTime / 3, TimeUnit.MILLISECONDS);

    }

    private void cancelExpirationRenewal(String lockName, String threadId) {
        String keyName = expirationRenewalKey(lockName, threadId);
        ExpirationEntry task = expirationRenewalMap.get(keyName);
        if (task == null) {
            return;
        }
        Integer count = task.reduceCount();

        if (count == 0) {
            expirationRenewalMap.remove(keyName);
        }
    }

    private String expirationRenewalKey(String lockName, String threadId) {
        return lockName + "_" + threadId;
    }

    /**
     * 续期
     *
     * @param lockName
     * @param threadId
     */
    private boolean renewExpirationLock(String lockName, String threadId) {
        LockObject lockObject = new LockObject();
        lockObject.setLockName(lockName);
        lockObject.setThreadId(threadId);
        lockObject.setMachineId(machineId);
        lockObject.setExpireTime(new Date(System.currentTimeMillis() + internalLockLeaseTime));
        int updateNum = lockRepository.renewExpirationLock(lockObject);
        return updateNum != 0;
    }
sql 复制代码
<update id="updateRenewExpirationLock">
        update distributed_lock
        set expire_time = #{expireTime,jdbcType=TIMESTAMP}
        where is_deleted = 0
        and is_locked = true
        and lock_name = #{lockName,jdbcType=VARCHAR}
        and machine_id   = #{machineId,jdbcType=VARCHAR}
        and thread_id   = #{threadId,jdbcType=VARCHAR}
        and expire_time &gt; NOW()
    </update>

阻塞式获取锁

阻塞式获取锁和非阻塞的区别就是等待锁释放的过程,没有获取到锁的线程会一直等待下去。

java 复制代码
public void lock(String lockName, long leaseTime, TimeUnit unit) {
        LockEntry lockEntry = null;

        try {
            while (true) {
                // 尝试获取锁
                Long ttl = tryAcquire(lockName, leaseTime, unit);

                if (ttl == null) {
                    // 成功获取到锁,直接退出
                    break;
                }

                // 未获取到锁,订阅锁释放通知(如果还没订阅)
                if (lockEntry == null) {
                    lockEntry = subscribe(lockName, getCurrentThreadId(), () -> {
                    });
                }

                // 等待锁释放通知,直到TTL时间结束
                try {
                    lockEntry.getLatch().tryAcquire(ttl, TimeUnit.MILLISECONDS);
                } catch (InterruptedException e) {
                    // 恢复线程的中断状态
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Thread was interrupted while waiting for the lock", e);
                }
            }
        } finally {
            // 确保在退出时释放锁并取消订阅
            if (lockEntry != null) {
                unsubscribe(lockEntry, getCurrentThreadId());
            }
        }
    }

解锁

获取锁的线程释放锁的时候,state会减1,直到减到0,锁才会真正的释放。这里需要移除锁续约的任务,并且唤醒等待当前锁的线程

java 复制代码
public void unlock(String lockName) {
        if (releaseLock(lockName)) {
            //释放锁成功后去除看门狗的续期
            //如果解锁失败,比如自己获取到锁过期了,然后又去释放锁,因为他没有续约任务所以不需要移除
            cancelExpirationRenewal(lockName, getCurrentThreadId());

            //发送锁释放的通知
            // 这里只处理本机维护的等待锁的线程,其他的机器数据库没法主动发出通知,需要轮训或者由获取锁的线程下次获取锁时自行处理
            LockEntry lockEntry = subscribeMap.get(lockName);
            //要判空,因为如果没有阻塞中的线程,那么lockEntry会为空
            if (lockEntry != null) {
                Semaphore semaphore = lockEntry.getLatch();
                if (semaphore.hasQueuedThreads()) {
                    semaphore.release();
                }
            }
        }
    }
sql 复制代码
<update id="updateReleaseLock">
        update distributed_lock
        <set>
            state       = state - 1,
            expire_time = if(state=0, null, expire_time),
            is_locked   = if(state=0, false, true),
            machine_id   = if(state=0, null, machine_id),
            thread_id   = if(state=0, null, thread_id),
        </set>
        where is_deleted = 0
        and lock_name = #{lockName,jdbcType=VARCHAR}
        and machine_id   = #{machineId,jdbcType=VARCHAR}
        and thread_id   = #{threadId,jdbcType=VARCHAR}
        and expire_time &gt; NOW()
        and is_locked = true
    </update>

检查锁是否过期或者释放

因为mysql数据库没有发布订阅的功能,所以这里采用了定时查询的模式检查锁的状态。如果检测到锁释放了,会发起唤醒等待锁线程的通知,让等待的线程重新尝试获取锁。

java 复制代码
public void process() {
        scheduledExecutor.scheduleAtFixedRate(() -> {
            //执行本机订阅这把锁的检查任务
            List<String> needCheckLockNameList = subscribeMap.entrySet().stream()
                    .filter(entry -> entry.getValue().getCounter().get() != 0)
                    .map(entry -> entry.getKey())
                    .collect(Collectors.toList());
            //查询已经过期或者释放的锁
            List<String> lockNameList = lockRepository.queryAllowObtainLockList(needCheckLockNameList);

            //执行对应锁的唤醒操作
            lockNameList.forEach(lockName -> {
                LockEntry lockEntry = subscribeMap.get(lockName);
                if (lockEntry != null) {
                    //这里最多多唤醒一次,无非就是让等待线程多抢占一次,没什么关系,这种场景发生在tryAcquire正好过期,定时任务正好运行
                    //多一次判断可以大幅度减少冲突时多释放的信号
                    Semaphore semaphore = lockEntry.getLatch();
                    if (semaphore.hasQueuedThreads()) {
                        semaphore.release();
                        log.info("定时任务发起唤醒等待锁的通知");
                    }
                }
            });
        }, 0, 1, TimeUnit.SECONDS);
    }
sql 复制代码
<select id="queryAllowObtainLockList" resultType="java.lang.String">
        select lock_name
        from distributed_lock
        where is_deleted = 0
          and lock_name in
        <foreach collection="list" item="lockName" open="(" close=")" separator=",">
            #{lockName,jdbcType=VARCHAR}
        </foreach>
        and (
                        is_locked = false
                    or expire_time &lt; NOW()
                )
    </select>

使用示例

java 复制代码
public static void main(String[] args) {

        // 第一个Spring容器,加载配置类 Config1
        ApplicationContext context1 = new AnnotationConfigApplicationContext(MybatisPlusConfig.class);

        // 第二个Spring容器,加载配置类 Config2
        ApplicationContext context2 = new AnnotationConfigApplicationContext(MybatisPlusConfig.class);

        DatabaseDistributedLock server1 = context1.getBean(DatabaseDistributedLock.class);
        DatabaseDistributedLock server2 = context2.getBean(DatabaseDistributedLock.class);

        server1.lock("test");
        new Thread(() -> {
            ThreadUtil.sleep(1, TimeUnit.SECONDS);
            if (server2.tryLock("test", 17L, TimeUnit.SECONDS)) {
                System.out.println("我执行了1");
                ThreadUtil.sleep(5, TimeUnit.SECONDS);
                server2.unlock("test");
            }
        }).start();
        new Thread(() -> {
            ThreadUtil.sleep(2, TimeUnit.SECONDS);
            if (server1.tryLock("test", 17L, TimeUnit.SECONDS)) {
                System.out.println("我执行了2");
                ThreadUtil.sleep(5, TimeUnit.SECONDS);
                server1.unlock("test");
            }
        }).start();
        System.out.println("我获取到了锁");
        ThreadUtil.sleep(15, TimeUnit.SECONDS);
        server1.unlock("test");
        ThreadUtil.sleep(100, TimeUnit.SECONDS);
    }

优化方案

订阅通知如果有消息队列的话,可以借助用来实现发布订阅锁通知