导读
- 理解分布式锁的使用场景以及分布式锁实现过程中的一些常见问题
- 掌握如何使用 Redis 实现分布式锁
代码版本
- JDK运行版本:JDK8
引言
- 在系统中修改已有数据时,需要先读取,然后进行修改保存,由于读取、修改和写入不是原子操作,在并发场景下,部分对数据的操作可能会丢失。在单服务器系统我们常用本地锁来避免并发带来的问题,然而,当服务采用集群方式部署时,本地锁无法在多个服务器之间生效,这时候保证数据的一致性就需要分布式锁来实现。

实现

加锁和解锁(如何保证原子性)
- 加锁我们通常使用 setnx + expire 或者 set 方法实现(set 具备 setnx 的所有能力推荐使用 set)。
lua
未来的版本中 setnx 可能被移除:
因为 SET 命令可以通过参数来实现 SETNX 、 SETEX 以及 PSETEX 命令的效果, 所以 Redis 将来的版本可能会移除并废弃 SETNX 、 SETEX 和 PSETEX 这三个命令。
http://redisdoc.com/string/set.html
// 加锁 setnx
redis.call('setnx', KEYS[1], ARGV[1])
// 获取锁后 设置锁的失效时间为 30s 避免线程意外挂掉之后无法释放
// 若线程未释放锁每20s执行一次续期操作(需要定时任务)
redis.call('expire', KEYS[1], 30)
上述操作无法保证原子操作,我们需要使用lua脚本来实现:用户可以向服务器发送 lua 脚本来执行自定义动作,获取脚本的响应数据。
Redis 服务器会单线程原子性执行 lua 脚本,保证 lua 脚本在处理的过程中不会被任意其它请求打断。
虽然不会被其它请求打断,但比如语法错误、参数类型错误、服务器宕机等会导致执行中断,返回错误。
if (redis.call('setnx', KEYS[1], ARGV[1]) < 1)
then return 0;
end;
redis.call('expire', KEYS[1], tonumber(ARGV[2]));
return 1;
// 加锁 set
上述操作等同于:SET key "value" EX 1000 NX
- 解锁我们通过使用 del 方法删除锁键值,以便其它线程可以获取锁
lua
redis.call('del',KEYS[1])
存在的问题
- 上面是最基本的实现,但实际的复杂情况中存在以下几个问题:
加锁 - 如何等待锁释放
- 抢占锁时,如果我们第一次未未成功获取到锁,这时候需要等到其它线程释放锁,实现方案和本地锁的实现逻辑基本一致,可以进行轮询或者线程挂起,结合分布式场景可以使用推荐 发布订阅 + 线程阻塞唤醒的方式。

加锁 - 如何实现可重入
- 当一个线程持有锁的情况下再次加锁,如果支持再次加锁,那么说明这个锁时可重入的。
- 分布式场景下,实现锁重入的逻辑和本地锁也基本一致,但重入次数的存储有两种选择,一种选择时放在本地性能更高,但会增加本地加锁解锁的实现成功,而是存储在远端,但性能会稍差。
本地存储
- 本地存储可以使用 ThreadLocal 实现
Java
/**
* 本地线程可重入计数
*/
private static final ThreadLocal<Map<String, Integer>> LOCK_COUNT = ThreadLocal.withInitial(HashMap::new);
// 加锁
// 可重入:锁已被占用且持有线程为当前线程
if (result == 0 && LOCK_COUNT.get().getOrDefault(lockKey, 0) > 0) {
LOCK_COUNT.get().put(lockKey, LOCK_COUNT.get().get(lockKey) + 1);
return;
}
// 解锁 若当前持有锁标记不为 0 重入次数-1
LOCK_COUNT.get().put(lockKey, remainLock);
if (remainLock > 0) {
return;
}
远端存储
- 结合 lua 脚本实现:
lua
加锁流程:
1、判断锁是否被持有
2、若未被持有,获取锁并设置过期时间、重入次数初始化为 1
3、若被持有,判断是否为当前线程持有,若为当前线程持有重入次数 +1
4、其它情况获取锁失败
if (redis.call('HEXISTS',KEYS[1],'LOCK_THREAD') == 0) then
redis.call('HSET',KEYS[1],'LOCK_THREAD',ARGV[1])
redis.call('HSET',KEYS[1],'COUNT',1)
redis.call('expire',KEYS[1],tonumber(ARGV[2]))
return 1;
elseif (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then
redis.call('HSET',KEYS[1],'COUNT',redis.call('HGET',KEYS[1],'COUNT') + 1)
return 1;
else
return 0;
end;
加锁 - 如何保证锁持有的高可用
- 一般有两种可选方案,但都依赖于 Redis 集群,一种是我们常说的级联锁,即在多台服务器上同时加锁,第二种是仅依赖于 Redis 的高可用机制,相比而言第一种可靠性更高,但需要更多的性能损耗。
解锁 - 如何解决线程挂掉或服务器宕机时锁未释放
-
若线程或服务突然挂机,而锁未释放,会导致其它线程无法获取锁。
-
我们可以通过设置持有锁失效时间避免该问题:
1、锁有失效时间,则直接设置失效时间。
2、锁未设置失效时间,则设置初始失效时间为30,每20s 定时执行续期任务(2/3提前续期提高续期成功率),判断锁是否是当前线程持有并进行续期失效时间为30s(两个操作需要保证原子性) -
即使发生线程挂掉或服务器宕机,最坏情况仅在初始失效时间这段时间内不可获取锁。

解锁 - 如何解决锁误解除、误续期问题
- 思考以下场景,A的有效时间为 30s 但方法的执行时间,远超 30s,30s 后实际上此时相当于锁已经释放此时B获取到锁,A执行完后,尝试释放锁,会导致将B的锁错误释放。

- 因此,可以在抢占锁时设置持有锁线程信息在远端,当锁进行释放和续期时判断锁是否是当前线程持有,若不是,则不进行操作。
lua
// 解锁
1、判断当前锁是否由当前线程持有
2、若为当前线程持有,重入次数减 1,并判断剩余重入次数是否大于 0 ,若大于 0,则更新重入次数,否则释放锁
3、若不为当前线程持有不进行操作
if (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then
local remain_count = redis.call('HGET',KEYS[1],'COUNT') -1 ;
if (remain_count > 0) then
redis.call('HSET',KEYS[1],'COUNT',remain_count)
return 0;
end;
redis.call('del',KEYS[1]);
return 1;
end;
return -1;
// 续期
判断持有锁线程是否当前线程,若是,进行续期,否则返回 0
if (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then
redis.call('expire',KEYS[1],tonumber(ARGV[2]))
return 1;
end;
return 0;
完整代码
- 使用 lua 脚本保证原子性,最终实现因需要存储锁的持有线程和重入次数,使用 哈希表 结构实现
Lua 脚本
lua
加锁流程:
1、判断锁是否被持有
2、若未被持有,获取锁并设置过期时间、重入次数初始化为 1
3、若被持有,判断是否为当前线程持有,若为当前线程持有重入次数 +1
4、其它情况获取锁失败
if (redis.call('HEXISTS',KEYS[1],'LOCK_THREAD') == 0) then
redis.call('HSET',KEYS[1],'LOCK_THREAD',ARGV[1])
redis.call('HSET',KEYS[1],'COUNT',1)
redis.call('expire',KEYS[1],tonumber(ARGV[2]))
return 1;
elseif (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then
redis.call('HSET',KEYS[1],'COUNT',redis.call('HGET',KEYS[1],'COUNT') + 1)
return 1;
else
return 0;
end;
KEYS[1] 锁 key 值
ARGV[1] 持有锁线程
ARGV[2] 锁失效时间
解锁流程:
1、判断当前锁是否由当前线程持有
2、若为当前线程持有,重入次数减 1,并判断剩余重入次数是否大于 0 ,若大于 0,则更新重入次数,否则释放锁
3、若不为当前线程持有不进行操作
if (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then
local remain_count = redis.call('HGET',KEYS[1],'COUNT') -1 ;
if (remain_count > 0) then
redis.call('HSET',KEYS[1],'COUNT',remain_count)
return 0;
end;
redis.call('del',KEYS[1]);
return 1;
end;
return -1;
KEYS[1] 锁 key 值
ARGV[1] 持有锁线程
续期流程:
判断持有锁线程是否当前线程,若是,进行续期,否则返回 0
if (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then
redis.call('expire',KEYS[1],tonumber(ARGV[2]))
return 1;
end;
return 0;
KEYS[1] 锁 key 值
ARGV[1] 持有锁线程
ARGV[2] 锁失效时间
Java
Java
public class RedisDistributedLockTest {
static LockHelper lockHelper = new LockHelper();
private static int shareValue = 0;
private static final String LOCK_STR = "LOCK&STR";
public static void main(String[] args) throws Exception {
// 模拟 service A
Thread serviceA = new Thread(() -> {
RedisDistributedLock lock;
try {
lock = lockHelper.getLock(LOCK_STR);
for (int i = 0; i < 5000; i++) {
lock.lock();
System.out.println("serviceA 获取锁");
lock.lock();
System.out.println("serviceA 获取锁");
if (i == 1) {
// 验证续期
Thread.sleep(50000);
}
increment();
// 可重复入验证
lock.unlock();
System.out.println("serviceA 释放锁,a = " + shareValue);
lock.unlock();
System.out.println("serviceA 释放锁,a = " + shareValue);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});
// 模拟 service B
Thread serviceB = new Thread(() -> {
RedisDistributedLock lock;
try {
lock = lockHelper.getLock(LOCK_STR);
for (int i = 0; i < 5000; i++) {
lock.lock();
System.out.println("ServiceB 获取锁");
increment();
lock.unlock();
System.out.println("ServiceB 释放锁,a = " + shareValue);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});
serviceA.start();
serviceB.start();
serviceA.join();
serviceB.join();
System.out.println("shareValue:" + shareValue);
}
public static void increment() {
shareValue = shareValue + 1;
}
}
interface Lock {
void lock() throws Exception;
void unlock() throws Exception;
}
class RedisDistributedLock implements Lock {
/**
* 当前线程 LockKey 和 RedisDistributedLock
* key :uniqueThreadKey_lockKey
* value :RedisDistributedLock
*/
static ConcurrentHashMap<String, Lock> lockMap = new ConcurrentHashMap<>();
/**
* 存储当前正持有锁的 uniqueThreadKey_lockKey
* 用于锁续期
*/
static List<String> usingThreadKeyList = new CopyOnWriteArrayList<>();
/**
* Redis 互斥状态 key
*/
private final String lockKey;
/**
* 线程唯一ID
*/
private final String uniqueThreadKey;
private final JedisPool jedisPool;
public RedisDistributedLock(JedisPool jedisPool, String lockKey, String uniqueThreadKey) {
this.jedisPool = jedisPool;
this.lockKey = lockKey;
this.uniqueThreadKey = uniqueThreadKey;
}
@Override
public void lock() {
Jedis jedis = jedisPool.getResource();
long result = setKeyAtomic(jedis);
// 获取锁失败轮询尝试获取锁
// 可以优化为发布订阅的模式避免无效轮询操作
while (result == 0) {
/* result = jedis.setnx(lockKey, String.valueOf(1));
// 获取锁后 设置锁的失效时间为 30s 避免线程意外挂掉之后无法释放 若线程未释放锁 每 20 s 执行一次续期操作
jedis.expire(lockKey, 30);
// 但两个操作非原子性 若失效时间未设置成功仍会出现线程意外死亡无法释放的问题 需要使用 lua 脚本保证原子性*/
result = setKeyAtomic(jedis);
}
// 锁重入时不重复加入续期列表
if (!usingThreadKeyList.contains(lockKey + "_" + uniqueThreadKey)) {
usingThreadKeyList.add(lockKey + "_" + uniqueThreadKey);
}
jedis.close();
}
/**
* @param jedis
* @return
*/
private Long setKeyAtomic(Jedis jedis) {
// Lua 脚本
String luaScript = "if (redis.call('HEXISTS',KEYS[1],'LOCK_THREAD') == 0) then\n" +
"\tredis.call('HSET',KEYS[1],'LOCK_THREAD',ARGV[1])\n" +
"\tredis.call('HSET',KEYS[1],'COUNT',1) \n" +
"\tredis.call('expire',KEYS[1],tonumber(ARGV[2]))\n" +
"\treturn 1;\n" +
"elseif (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then\n" +
"\tredis.call('HSET',KEYS[1],'COUNT',redis.call('HGET',KEYS[1],'COUNT') + 1)\n" +
" return 1;\n" +
"else\n" +
"return 0;\n" +
"end;";
// 执行脚本
return (Long) jedis.eval(luaScript, Collections.singletonList(lockKey), Arrays.asList(uniqueThreadKey, "30"));
}
@Override
public void unlock() {
Jedis jedis = jedisPool.getResource();
// 判断锁是否当前线程持有和释放锁需要原子操作(如 判断锁时正持有,正好失效被其它线程占用会造成锁误解除)
/* if (jedis.del(lockKey) == 0) {
// 归还连接池
jedis.close();
throw new Exception("there is no lock can be unlocked");
}*/
delKeyAtomic(jedis);
lockMap.remove(uniqueThreadKey + "_" + lockKey);
// 归还连接池
jedis.close();
}
private Long delKeyAtomic(Jedis jedis) {
// Lua 脚本
String luaScript = "if (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then\n" +
" local remain_count = redis.call('HGET',KEYS[1],'COUNT') -1 ;\n" +
" if (remain_count > 0) then \n" +
" redis.call('HSET',KEYS[1],'COUNT',remain_count) \n" +
" return 0;\n" +
" end;\n" +
" redis.call('del',KEYS[1]);\n" +
" return 1;\n" +
"end;\n" +
"return -1;";
// 执行脚本
return (Long) jedis.eval(luaScript, Collections.singletonList(lockKey), Collections.singletonList(uniqueThreadKey));
}
}
class LockHelper {
private static JedisPool jedisPool = null;
public LockHelper() {
JedisPoolConfig poolConfig = new JedisPoolConfig();
poolConfig.setMaxTotal(10);
poolConfig.setMaxWaitMillis(1000);
jedisPool = new JedisPool(poolConfig, "127.0.0.1", 6379, 2000, "1dBzlYuz1uC4t8mjqskxivO");
System.out.println("连接 Redis 服务成功");
// 定时任务锁续期 获取持有锁的 key + 持有线程 进行续期
// 如何保证 redis 命令阻塞或 命令较多时执行成功 => 剩余 1/3 的时间就进行重试 + 重试
Thread timerRenewalThread = new Thread(() -> {
Jedis jedis = jedisPool.getResource();
while (true) {
for (String lockStr : RedisDistributedLock.usingThreadKeyList) {
String renewalLua = "if (redis.call('HGET',KEYS[1],'LOCK_THREAD') == ARGV[1]) then\n" +
" redis.call('expire',KEYS[1],tonumber(ARGV[2]))\n" +
" return 1;\n" +
"end;\n" +
"return 0;\n";
long result = (long) jedis.eval(renewalLua, Collections.singletonList(lockStr.split("_")[0]), Arrays.asList(lockStr.split("_")[1], "30"));
if (result == 0) {
RedisDistributedLock.usingThreadKeyList.remove(lockStr);
}
}
System.out.println("续期任务执行完成进入休眠");
// 休眠 20s 每 20s 执行一次
try {
Thread.sleep(20000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
});
timerRenewalThread.setDaemon(true);
timerRenewalThread.start();
System.out.println("续期守护线程启动成功");
}
RedisDistributedLock getLock(String lockKey) throws Exception {
/**
* 用于存储每个线程获取的 LockKey => RedisDistributedLock 关系
*/
String uniqueThreadKey = UUID.randomUUID().toString();
if (RedisDistributedLock.lockMap.containsKey(lockKey + "_" + uniqueThreadKey)) {
throw new Exception("Lock is already exist");
}
RedisDistributedLock distributedLock = new RedisDistributedLock(jedisPool, lockKey, uniqueThreadKey);
RedisDistributedLock.lockMap.put(lockKey + "_" + uniqueThreadKey, distributedLock);
return distributedLock;
}
}