1、 锁状态管理
java
// Per-thread registry of held locks: lock key -> LockEntry (lock value, re-entrant count, lease time).
private static final ThreadLocal<Map<String, LockEntry>> currentThreadLocks =
ThreadLocal.withInitial(ConcurrentHashMap::new);
ThreadLocal存储当前线程持有的锁信息
结构:Map<锁Key, LockEntry>
LockEntry包含:
- 锁值(lockValue)
- 重入计数器(AtomicInteger)
- 租约时间(leaseTime)
2、 续期机制
java
// Scheduler that runs the periodic lease-renewal tasks.
private ScheduledExecutorService executorService;
// Currently scheduled renewal task per lock key, kept so it can be cancelled on unlock.
private final ConcurrentHashMap<String, ScheduledFuture<?>> renewalTasks = new ConcurrentHashMap<>();
- ScheduledExecutorService:续期任务线程池
- renewalTasks:存储每个锁的续期任务
- 续期间隔 = 租约时间 / 3
java
executorService.scheduleAtFixedRate(() -> {
String currentValue = redisTemplate.opsForValue().get(key);
if (value.equals(currentValue)) {
Boolean expire = redisTemplate.expire(key, leaseTime, TimeUnit.MILLISECONDS);
if (!Boolean.TRUE.equals(expire)) {
LOG.warn("锁续期失败: {}", key);
stopRenewalTask(key);
}
} else {
stopRenewalTask(key);
}
}, renewalInterval, renewalInterval, TimeUnit.MILLISECONDS);
3、线程工厂
java
/**
 * Creates the renewal worker threads: daemon threads named
 * {@code redis-lock-renewal-<n>}, with {@code n} counting up from 1.
 */
private static class RenewalThreadFactory implements ThreadFactory {

    private static final String NAME_PREFIX = "redis-lock-renewal-";

    private final AtomicInteger sequence = new AtomicInteger(1);

    @Override
    public Thread newThread(Runnable task) {
        int seq = sequence.getAndIncrement();
        Thread worker = new Thread(task, NAME_PREFIX + seq);
        worker.setDaemon(true);
        return worker;
    }
}
- 自定义线程命名:redis-lock-renewal-N(N 从 1 开始递增)
- 设置为守护线程(daemon),不会阻止 JVM 退出
4、代码
MyRedisLock
java
package com.redislock;
import io.netty.util.concurrent.DefaultThreadFactory;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import java.lang.management.LockInfo;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
* @author xiaoman
* @date 2025/7/23 23:08
*/
@Component
public class MyRedisLock {

    private static final Logger LOG = LogManager.getLogger(MyRedisLock.class);

    /**
     * Deletes KEYS[1] only if it still holds ARGV[1] (this owner's lock value), so a client can
     * never release a lock it does not own. Returns 1 when the key was deleted, 0 otherwise.
     * Built once and reused so the script's SHA1 is not recomputed on every call.
     */
    private static final DefaultRedisScript<Long> UNLOCK_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) == ARGV[1] then "
                    + "return redis.call('del', KEYS[1]) "
                    + "else "
                    + "return 0 "
                    + "end",
            Long.class);

    /**
     * Resets the TTL of KEYS[1] to ARGV[2] milliseconds only if it still holds ARGV[1].
     * The ownership check and PEXPIRE run in one script, so a lock that expired and was taken
     * by another client between "check" and "extend" is never extended by us.
     * Returns 1 when the TTL was reset, 0 otherwise.
     * NOTE(review): ARGV[2] must reach Redis as a decimal string; this assumes redisTemplate
     * serializes script arguments as plain strings (e.g. a StringRedisTemplate) — confirm.
     */
    private static final DefaultRedisScript<Long> RENEW_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) == ARGV[1] then "
                    + "return redis.call('pexpire', KEYS[1], ARGV[2]) "
                    + "else "
                    + "return 0 "
                    + "end",
            Long.class);

    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    /**
     * Locks held by the current thread: lock key -> LockEntry (lock value, re-entrant count,
     * lease time). Only ever touched by its owning thread.
     */
    private static final ThreadLocal<Map<String, LockEntry>> currentThreadLocks =
            ThreadLocal.withInitial(ConcurrentHashMap::new);

    /**
     * Scheduler that runs the periodic lease-renewal tasks.
     */
    private ScheduledExecutorService executorService;

    /**
     * Currently scheduled renewal task per lock key (lock key -> renewal task future).
     */
    private final ConcurrentHashMap<String, ScheduledFuture<?>> renewalTasks = new ConcurrentHashMap<>();

    @PostConstruct
    public void init() {
        // Create the renewal thread pool.
        this.executorService = Executors.newScheduledThreadPool(
                4, new RenewalThreadFactory());
    }

    @PreDestroy
    public void destroy() {
        // Stop the renewal pool: graceful first, forced after 5 seconds.
        executorService.shutdown();
        try {
            if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
                executorService.shutdownNow();
            }
        } catch (InterruptedException e) {
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Acquires the lock, retrying with a short randomized back-off until {@code waitTime}
     * elapses. Re-entrant: if the current thread already holds {@code key}, the hold count is
     * incremented and the existing lock value is returned immediately.
     *
     * @param key       the lock key
     * @param waitTime  maximum time to keep retrying, in milliseconds
     * @param leaseTime lock TTL in milliseconds; renewed automatically while the lock is held
     * @return the lock value (needed for {@link #unlock}), or {@code null} if the lock could not
     *         be acquired in time
     * @throws RuntimeException if interrupted while waiting (the interrupt flag is restored)
     */
    public String tryLock(String key, long waitTime, long leaseTime) {
        Map<String, LockEntry> threadLocks = currentThreadLocks.get();
        LockEntry lockEntry = threadLocks.get(key);
        if (lockEntry != null) {
            lockEntry.incrementCount();
            return lockEntry.getLockValue();
        }

        long endTime = System.currentTimeMillis() + waitTime;
        String res;
        while ((res = lock(key, leaseTime)) == null) {
            long remaining = endTime - System.currentTimeMillis();
            if (remaining <= 0) {
                break;
            }
            backOff(remaining);
        }

        if (res == null) {
            // Don't keep an empty per-thread map alive (mirrors the cleanup in unlock).
            if (threadLocks.isEmpty()) {
                currentThreadLocks.remove();
            }
            return null;
        }

        try {
            startRenewalTask(key, res, leaseTime);
        } catch (RuntimeException e) {
            // Renewal could not be scheduled (e.g. the executor is already shut down): release
            // the Redis lock rather than leave it held with no local record of it.
            try {
                doUnlock(key, res);
            } catch (RuntimeException unlockError) {
                e.addSuppressed(unlockError);
            }
            throw e;
        }
        threadLocks.put(key, new LockEntry(res, new AtomicInteger(1), leaseTime));
        return res;
    }

    /**
     * Sleeps for a random 100-149 ms, but never longer than {@code maxMillis}, between
     * acquisition attempts. The jitter keeps competing clients from retrying in lockstep.
     */
    private static void backOff(long maxMillis) {
        long sleepTime = Math.min(ThreadLocalRandom.current().nextLong(100, 150), maxMillis);
        try {
            TimeUnit.MILLISECONDS.sleep(sleepTime);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            LOG.error("休眠失败", e);
            throw new RuntimeException("休眠失败", e);
        }
    }

    /**
     * Schedules the periodic lease renewal ("watchdog") for a freshly acquired lock.
     * The renewal runs every leaseTime / 3 ms and stops itself once the lock is no longer ours
     * or Redis reports an error.
     */
    private void startRenewalTask(String key, String value, long leaseTime) {
        // At least 1 ms: scheduleAtFixedRate rejects a period <= 0 (leaseTime < 3 would give 0).
        long renewalInterval = Math.max(1L, leaseTime / 3);
        // The task keeps a handle to its own future so that, when it stops itself, it removes
        // only its own registration and never a newer task registered for the same key.
        AtomicReference<ScheduledFuture<?>> self = new AtomicReference<>();
        ScheduledFuture<?> future = executorService.scheduleAtFixedRate(() -> {
            boolean renewed;
            try {
                renewed = renew(key, value, leaseTime);
                if (!renewed) {
                    // Lock was released, expired, or is now owned by someone else.
                    LOG.warn("锁续期失败: {}", key);
                }
            } catch (Exception e) {
                LOG.warn("锁续期失败: {}", key, e);
                renewed = false;
            }
            if (!renewed) {
                cancelOwnRenewal(key, self.get());
            }
        }, renewalInterval, renewalInterval, TimeUnit.MILLISECONDS);
        self.set(future);
        ScheduledFuture<?> previous = renewalTasks.put(key, future);
        if (previous != null) {
            previous.cancel(false);
        }
    }

    /**
     * Atomically extends the TTL if {@code key} still holds {@code value}.
     *
     * @return {@code true} if the TTL was reset, {@code false} if the lock is no longer ours
     */
    private boolean renew(String key, String value, long leaseTime) {
        Long result = redisTemplate.execute(
                RENEW_SCRIPT,
                Collections.singletonList(key),
                value,
                String.valueOf(leaseTime));
        return Long.valueOf(1L).equals(result);
    }

    /**
     * Cancels a renewal task from inside the task itself. The map entry is removed only if it
     * still refers to this exact task.
     */
    private void cancelOwnRenewal(String key, ScheduledFuture<?> future) {
        if (future == null) {
            // Ran before startRenewalTask stored the handle; the next run will cancel.
            return;
        }
        renewalTasks.remove(key, future);
        future.cancel(false);
    }

    /**
     * Stops the renewal task registered for {@code key}, if any.
     */
    private void stopRenewalTask(String key) {
        ScheduledFuture<?> future = renewalTasks.remove(key);
        if (future != null) {
            future.cancel(false);
        }
    }

    /**
     * Makes one attempt to create the lock key with SET NX and a TTL.
     *
     * @param key       the lock key
     * @param leaseTime TTL in milliseconds
     * @return a fresh random lock value on success, {@code null} if the key already exists
     */
    private String lock(String key, long leaseTime) {
        String lockValue = UUID.randomUUID().toString();
        Boolean success = redisTemplate.opsForValue().setIfAbsent(key, lockValue, leaseTime, TimeUnit.MILLISECONDS);
        return Boolean.TRUE.equals(success) ? lockValue : null;
    }

    /**
     * Releases one hold on the lock. The Redis key is deleted only when the re-entrant count
     * drops to zero.
     *
     * @param key   the lock key
     * @param value the lock value returned by {@link #tryLock}
     * @return {@code true} if this call deleted the Redis key; {@code false} if the lock is still
     *         held re-entrantly or the key no longer held {@code value}
     * @throws IllegalMonitorStateException if the current thread does not hold the lock
     */
    public boolean unlock(String key, String value) {
        if (value == null || value.isEmpty()) {
            throw new IllegalMonitorStateException("当前线程未持有锁: " + key);
        }
        Map<String, LockEntry> entryMap = currentThreadLocks.get();
        LockEntry lockEntry = entryMap.get(key);
        if (lockEntry == null || !value.equals(lockEntry.getLockValue())) {
            throw new IllegalMonitorStateException("当前线程未持有锁: " + key);
        }
        int count = lockEntry.decrementCount();
        if (count > 0) {
            // Still held re-entrantly: keep the Redis lock.
            return false;
        }
        entryMap.remove(key);
        if (entryMap.isEmpty()) {
            currentThreadLocks.remove();
        }
        // Stop renewing before deleting, so the watchdog cannot extend a lock we are releasing.
        stopRenewalTask(key);
        return doUnlock(key, lockEntry.getLockValue());
    }

    /**
     * Atomically deletes {@code key} if it still holds {@code lockValue}.
     */
    private boolean doUnlock(String key, String lockValue) {
        Long result = redisTemplate.execute(
                UNLOCK_SCRIPT,
                Collections.singletonList(key),
                lockValue);
        return Long.valueOf(1L).equals(result);
    }

    /**
     * Thread factory for the renewal pool: daemon threads named redis-lock-renewal-N.
     */
    private static class RenewalThreadFactory implements ThreadFactory {
        private final AtomicInteger threadNumber = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, "redis-lock-renewal-" + threadNumber.getAndIncrement());
            t.setDaemon(true);
            return t;
        }
    }
}
LockEntry
java
package com.redislock;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author xiaoman
* @date 2025/7/24 21:52
*/
@NoArgsConstructor
@Data
@AllArgsConstructor
public class LockEntry {

    /**
     * Random value written to Redis when the lock was acquired; it identifies the owner.
     */
    private String lockValue;

    /**
     * Re-entrant hold count for the owning thread.
     */
    private AtomicInteger count = new AtomicInteger();

    /**
     * Lease time the lock was acquired with (MyRedisLock passes milliseconds).
     * The field name is a misspelling of "leaseTime". It is kept as-is because Lombok derives
     * getLessTime/setLessTime from it.
     */
    private Long lessTime;

    /**
     * Records one more re-entrant acquisition.
     *
     * @return the hold count after the increment
     */
    public int incrementCount() {
        return this.count.addAndGet(1);
    }

    /**
     * Records one release.
     *
     * @return the hold count after the decrement
     */
    public int decrementCount() {
        return this.count.addAndGet(-1);
    }
}