Redis实现分布式锁

1、 锁状态管理

java 复制代码
// Per-thread registry of held locks: lock key -> LockEntry (lock value, reentry count, lease time).
// withInitial gives each thread its own empty map on first access.
private static final ThreadLocal<Map<String, LockEntry>> currentThreadLocks =
        ThreadLocal.withInitial(ConcurrentHashMap::new);

ThreadLocal存储当前线程持有的锁信息

结构:Map<锁Key, LockEntry>

LockEntry包含:

  • 锁值(lockValue)

  • 重入计数器(AtomicInteger)

  • 租约时间(leaseTime)

2、 续期机制

java 复制代码
// Scheduler that runs the periodic lease-renewal tasks.
private ScheduledExecutorService executorService;
// One renewal task per held lock (lock key -> scheduled task handle), kept so it can be cancelled.
private final ConcurrentHashMap<String, ScheduledFuture<?>> renewalTasks = new ConcurrentHashMap<>();
  • ScheduledExecutorService:续期任务线程池

  • renewalTasks:存储每个锁的续期任务

  • 续期间隔 = 租约时间 / 3

java 复制代码
// Renewal task, first run after renewalInterval ms and then every renewalInterval ms:
// while the key still holds this owner's value, push its TTL back out to leaseTime.
executorService.scheduleAtFixedRate(() -> {
    // NOTE(review): GET and EXPIRE are two separate Redis commands, not one atomic step. If the lock
    // expires and another client acquires it in between, this EXPIRE extends the other client's lock.
    // A Lua script that compares the value and calls PEXPIRE in one go would close that window.
    String currentValue = redisTemplate.opsForValue().get(key);
    if (value.equals(currentValue)) {
        Boolean expire = redisTemplate.expire(key, leaseTime, TimeUnit.MILLISECONDS);
        if (!Boolean.TRUE.equals(expire)) {
            // Renewal was not applied (e.g. the key vanished meanwhile): log and stop renewing.
            LOG.warn("锁续期失败: {}", key);
            stopRenewalTask(key);
        }
    } else {
        // Key is gone or holds another owner's value: this owner no longer holds the lock.
        stopRenewalTask(key);
    }
}, renewalInterval, renewalInterval, TimeUnit.MILLISECONDS);

3、线程工厂

java 复制代码
    /**
     * Thread factory for the renewal scheduler: names threads "redis-lock-renewal-N"
     * (N counting up from 1) and marks each thread as a daemon.
     */
    private static class RenewalThreadFactory implements ThreadFactory {
        // Sequence number appended to thread names, starting at 1.
        private final AtomicInteger threadNumber = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, "redis-lock-renewal-" + threadNumber.getAndIncrement());
            // Daemon threads do not by themselves keep the JVM running.
            t.setDaemon(true);
            return t;
        }
    }
  • 自定义线程名称：redis-lock-renewal-1、redis-lock-renewal-2……（编号从 1 开始递增）

  • 设置为守护线程(daemon)

4、代码

MyRedisLock

java 复制代码
package com.redislock;

import io.netty.util.concurrent.DefaultThreadFactory;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;

import java.lang.management.LockInfo;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Reentrant distributed lock backed by Redis, with automatic lease renewal.
 *
 * <p>A lock is acquired with {@code SET key value NX PX leaseTime}, where {@code value} is a random
 * UUID identifying the owner. Reentrancy is tracked per thread in a {@link ThreadLocal}; while a
 * lock is held, a scheduled task extends its TTL every {@code leaseTime / 3} milliseconds.
 *
 * <p>NOTE(review): the Lua scripts pass the lock value and the lease time as script arguments,
 * which assumes the injected template uses a plain string value serializer (as
 * {@code StringRedisTemplate} does) — confirm against the Redis configuration.
 *
 * @author xiaoman
 * @date 2025/7/23 23:08
 */
@Component
public class MyRedisLock {

    private static final Logger LOG = LogManager.getLogger(MyRedisLock.class);

    /**
     * Deletes KEYS[1] only if it still holds ARGV[1], so an owner can never delete a lock that has
     * since been acquired by someone else. Returns 1 when deleted, 0 otherwise.
     * (The previous inline script ended in "return 0end", a Lua syntax error.)
     */
    private static final DefaultRedisScript<Long> UNLOCK_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) == ARGV[1] then "
                    + "return redis.call('del', KEYS[1]) "
                    + "else "
                    + "return 0 "
                    + "end",
            Long.class);

    /**
     * Sets the TTL of KEYS[1] to ARGV[2] milliseconds only if it still holds ARGV[1].
     * Comparing and extending inside one script makes the pair atomic; a separate GET + EXPIRE
     * could extend a lock that another client acquired in between. Returns 1 when extended, 0 otherwise.
     */
    private static final DefaultRedisScript<Long> RENEW_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) == ARGV[1] then "
                    + "return redis.call('pexpire', KEYS[1], ARGV[2]) "
                    + "else "
                    + "return 0 "
                    + "end",
            Long.class);

    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    /**
     * Locks held by the current thread: lock key -> {@link LockEntry} (owner value and hold count).
     */
    private static final ThreadLocal<Map<String, LockEntry>> currentThreadLocks =
            ThreadLocal.withInitial(ConcurrentHashMap::new);


    /**
     * Scheduler that runs the lease-renewal tasks; created in {@link #init()}.
     */
    private ScheduledExecutorService executorService;

    /**
     * Renewal task per held lock, keyed by the lock's unique value rather than its key.
     * Keying by value means a stale task of a previous owner can only cancel itself, never the
     * renewal task of a newer owner of the same key.
     */
    private final ConcurrentHashMap<String, ScheduledFuture<?>> renewalTasks = new ConcurrentHashMap<>();


    /**
     * Creates the renewal scheduler once the bean has been constructed.
     */
    @PostConstruct
    public void init() {
        this.executorService = Executors.newScheduledThreadPool(
                4, new RenewalThreadFactory());
    }

    /**
     * Shuts down the renewal scheduler, waiting up to 5 seconds before forcing termination.
     */
    @PreDestroy
    public void destroy() {
        executorService.shutdown();
        try {
            if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
                executorService.shutdownNow();
            }
        } catch (InterruptedException e) {
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }



    /**
     * Acquires the lock, retrying with a small random back-off until it succeeds or
     * {@code waitTime} elapses.
     *
     * <p>Reentrant: if the current thread already holds {@code key}, its hold count is incremented
     * and the existing lock value is returned immediately.
     *
     * @param key       the Redis key used as the lock
     * @param waitTime  maximum time to keep retrying, in milliseconds
     * @param leaseTime lock TTL in milliseconds; renewed automatically while the lock is held
     * @return the lock's unique value (pass it to {@link #unlock}), or {@code null} on timeout
     * @throws IllegalArgumentException if the lock must be acquired and {@code leaseTime <= 0}
     * @throws RuntimeException         if interrupted while waiting (the interrupt flag is restored)
     */
    public String tryLock(String key, long waitTime, long leaseTime) {
        Map<String, LockEntry> threadLocks = currentThreadLocks.get();
        LockEntry lockEntry = threadLocks.get(key);
        if (lockEntry != null) {
            lockEntry.incrementCount();
            return lockEntry.getLockValue();
        }
        if (leaseTime <= 0) {
            // Redis rejects non-positive expiry times; fail fast with a clear message instead.
            throw new IllegalArgumentException("leaseTime must be positive: " + leaseTime);
        }

        // Measure elapsed time instead of computing now + waitTime, which overflows for huge waits.
        long start = System.currentTimeMillis();
        String lockValue = lock(key, leaseTime);
        while (lockValue == null) {
            long remaining = waitTime - (System.currentTimeMillis() - start);
            if (remaining <= 0) {
                break;
            }
            // Jitter keeps competing clients from retrying in lock-step; never sleep past the deadline.
            long sleepTime = Math.min(remaining, 100 + ThreadLocalRandom.current().nextLong(50));
            try {
                TimeUnit.MILLISECONDS.sleep(sleepTime);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOG.error("休眠失败", e);
                throw new RuntimeException("休眠失败", e);
            }
            lockValue = lock(key, leaseTime);
        }
        if (lockValue == null) {
            return null;
        }

        try {
            startRenewalTask(key, lockValue, leaseTime);
        } catch (RuntimeException e) {
            // E.g. RejectedExecutionException after shutdown: do not keep a lock nobody will renew.
            try {
                doUnlock(key, lockValue);
            } catch (RuntimeException releaseFailure) {
                e.addSuppressed(releaseFailure);
            }
            throw e;
        }
        threadLocks.put(key, new LockEntry(lockValue, new AtomicInteger(1), leaseTime));
        return lockValue;
    }

    /**
     * Schedules periodic renewal of the lock's TTL for as long as Redis still holds {@code value}.
     *
     * @param key       the lock key
     * @param value     the owner's unique lock value
     * @param leaseTime TTL (milliseconds) to reapply on each renewal
     */
    private void startRenewalTask(String key, String value, long leaseTime) {
        // Renew at a third of the lease; at least 1 ms because scheduleAtFixedRate rejects a zero period.
        long renewalInterval = Math.max(1L, leaseTime / 3);
        String leaseArg = String.valueOf(leaseTime);
        ScheduledFuture<?> future = executorService.scheduleAtFixedRate(() -> {
            try {
                Long renewed = redisTemplate.execute(
                        RENEW_SCRIPT, Collections.singletonList(key), value, leaseArg);
                if (renewed == null || renewed != 1L) {
                    // Lock expired, was released, or now belongs to another owner: stop renewing.
                    stopRenewalTask(value);
                }
            } catch (Exception e) {
                LOG.warn("锁续期失败: {}", key, e);
                stopRenewalTask(value);
            }
        }, renewalInterval, renewalInterval, TimeUnit.MILLISECONDS);

        renewalTasks.put(value, future);
    }

    /**
     * Cancels the renewal task registered for the given lock value, if any.
     *
     * @param lockValue the owner's unique lock value
     */
    private void stopRenewalTask(String lockValue) {
        ScheduledFuture<?> future = renewalTasks.remove(lockValue);
        if (future != null) {
            future.cancel(false);
        }
    }

    /**
     * Makes a single acquisition attempt with {@code SET key uuid NX PX leaseTime}.
     *
     * @param key       the lock key
     * @param leaseTime TTL in milliseconds
     * @return the new unique lock value, or {@code null} if the key already exists
     */
    private String lock(String key, long leaseTime) {
        String lockValue = UUID.randomUUID().toString();
        Boolean success = redisTemplate.opsForValue().setIfAbsent(key, lockValue, leaseTime, TimeUnit.MILLISECONDS);
        return Boolean.TRUE.equals(success) ? lockValue : null;
    }

    /**
     * Releases one hold of a lock owned by the current thread.
     *
     * <p>Only when the reentrant hold count reaches zero is the renewal task stopped and the Redis
     * key deleted (and only if it still holds {@code value}).
     *
     * @param key   the lock key
     * @param value the value returned by {@link #tryLock}
     * @return {@code true} if the Redis key was deleted; {@code false} if holds remain or the key
     *         no longer held {@code value} (for example because it had already expired)
     * @throws IllegalMonitorStateException if the current thread does not hold this lock with this value
     */
    public boolean unlock(String key, String value) {
        if (value == null || value.isEmpty()) {
            throw new IllegalMonitorStateException("当前线程未持有锁: " + key);
        }

        Map<String, LockEntry> entryMap = currentThreadLocks.get();
        LockEntry lockEntry = entryMap.get(key);
        if (lockEntry == null || !value.equals(lockEntry.getLockValue())) {
            throw new IllegalMonitorStateException("当前线程未持有锁: " + key);
        }
        int count = lockEntry.decrementCount();

        if (count > 0) {
            // Still held reentrantly: keep the Redis lock.
            return false;
        }
        entryMap.remove(key);
        if (entryMap.isEmpty()) {
            currentThreadLocks.remove();
        }

        stopRenewalTask(lockEntry.getLockValue());

        return doUnlock(key, lockEntry.getLockValue());
    }

    /**
     * Atomically deletes the key if it still holds {@code lockValue}.
     *
     * @return {@code true} if the key was deleted
     */
    private boolean doUnlock(String key, String lockValue) {
        Long result = redisTemplate.execute(UNLOCK_SCRIPT, Collections.singletonList(key), lockValue);
        return result != null && result == 1L;
    }

    /**
     * Thread factory for the renewal scheduler (named daemon threads).
     */
    private static class RenewalThreadFactory implements ThreadFactory {
        private final AtomicInteger threadNumber = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, "redis-lock-renewal-" + threadNumber.getAndIncrement());
            t.setDaemon(true);
            return t;
        }
    }
}

LockEntry

java 复制代码
package com.redislock;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Per-thread bookkeeping for one held lock: the owner's unique lock value, the reentrant hold
 * count, and the lease time the lock was acquired with.
 *
 * @author xiaoman
 * @date 2025/7/24 21:52
 */

@NoArgsConstructor
@Data
@AllArgsConstructor
public class LockEntry {


    /**
     * Unique value stored in Redis under the lock key; identifies the lock owner.
     */
    private String lockValue;

    /**
     * Reentrant hold count.
     */
    private AtomicInteger count = new AtomicInteger(0);

    /**
     * Lease time the lock was acquired with (milliseconds, as passed by MyRedisLock).
     * Previously misspelled {@code lessTime}; the deprecated accessors below keep old callers working.
     */
    private Long leaseTime;

    /**
     * Returns the lease time.
     *
     * @return the lease time
     * @deprecated misspelled alias kept for backward compatibility; use {@code getLeaseTime()}.
     */
    @Deprecated
    public Long getLessTime() {
        return leaseTime;
    }

    /**
     * Sets the lease time.
     *
     * @param lessTime the lease time
     * @deprecated misspelled alias kept for backward compatibility; use {@code setLeaseTime(Long)}.
     */
    @Deprecated
    public void setLessTime(Long lessTime) {
        this.leaseTime = lessTime;
    }

    /**
     * Increments the hold count.
     *
     * @return the hold count after incrementing
     */
    public int incrementCount() {
        return count.incrementAndGet();
    }

    /**
     * Decrements the hold count.
     *
     * @return the hold count after decrementing
     */
    public int decrementCount() {
        return count.decrementAndGet();
    }
}