简单实现一个分布式锁

一,背景

在日常项目中经常会遇到秒杀之类的场景,需要保证不超卖、库存等数据正确,这时就需要引入分布式锁。下面介绍基于 Redisson 的分布式锁实现。

二,实现

1,引入依赖
复制代码
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
复制代码
<dependency>
    <groupId>org.aspectj</groupId>
    <artifactId>aspectjweaver</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>
复制代码
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
复制代码
<!-- Redisson Spring Boot starter (provides RedissonClient and the RLock family used below).
     NOTE(review): Redisson does not appear to be version-managed by the Spring Boot BOM, so
     unless a parent POM or <dependencyManagement> supplies it, an explicit <version> matching
     your Spring Boot generation is required here - confirm before building. -->
<dependency>
    <groupId>org.redisson</groupId>
    <artifactId>redisson-spring-boot-starter</artifactId>
    <!-- Exclude logback-classic brought in transitively
         (presumably because the project uses another logging backend - confirm). -->
    <exclusions>
        <exclusion>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </exclusion>
    </exclusions>
</dependency>
2,添加注解
java 复制代码
import com.xxx.common.enumeration.LockFailStrategyEnum;
import com.xxx.common.enumeration.LockTypeEnum;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.concurrent.TimeUnit;


/**
 * Marks a method whose execution should be guarded by a distributed lock.
 * The lock key is derived from {@link #prefix()} and {@link #key()}; the
 * remaining attributes control lock type, waiting, lease and failure handling.
 * Only methods are allowed as targets, and the annotation is retained at runtime
 * so an aspect can read it.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DistributedLock {
    /**
     * Scenario prefix used to distinguish business cases; may be empty.
     */
    String prefix() default "";
    /**
     * Lock key; supports SpEL expressions.
     */
    String key();

    /**
     * Type of lock to acquire.
     */
    LockTypeEnum lockType() default LockTypeEnum.REENTRANT;

    /**
     * Maximum time to wait when acquiring the lock.
     * A negative value means wait indefinitely.
     */
    long waitTime() default -1;

    /**
     * Lease time: how long the lock is held before it is released automatically.
     * A negative value means there is no automatic release after a fixed time and
     * the lock must be released explicitly.
     */
    long leaseTime() default -1;

    /**
     * Time unit applied to {@link #waitTime()} and {@link #leaseTime()}.
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

    /**
     * Strategy applied when the lock cannot be acquired.
     */
    LockFailStrategyEnum failStrategy() default LockFailStrategyEnum.EXCEPTION;

    /**
     * Message of the exception thrown when the lock cannot be acquired.
     */
    String failMessage() default "获取分布式锁失败";
}
3, 实现注解
java 复制代码
import lombok.RequiredArgsConstructor;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.redisson.config.Config;
import org.redisson.Redisson;
import org.redisson.api.RedissonClient;


/**
 * Creates the {@link RedissonClient} used by the distributed-lock components,
 * configured in single-server mode from {@link RedissonProperties}.
 */
@Configuration
@EnableConfigurationProperties(RedissonProperties.class)
@RequiredArgsConstructor
public class RedissonConfig {
    private final RedissonProperties redissonProperties;

    /**
     * Builds a single-server Redisson client from the bound properties.
     *
     * @return the configured client
     * @throws IllegalStateException if the single-server section is not configured
     */
    @Bean
    public RedissonClient redissonClient() {
        if (redissonProperties.getSingleServerConfig() == null) {
            // Fail fast with a readable message instead of an NPE inside the builder chain.
            throw new IllegalStateException(
                "Redisson single-server configuration is missing: "
                    + "set spring.redis.redisson.config.single-server-config.*");
        }

        // Redisson authenticates whenever a non-null password is configured. A blank
        // value (e.g. "password:" left empty in YAML binds to "") would send AUTH with an
        // empty password and fail against a password-less Redis, so treat blank as "none".
        String password = redissonProperties.getSingleServerConfig().getPassword();
        if (password != null && password.trim().isEmpty()) {
            password = null;
        }

        Config config = new Config();
        // Single-server mode
        config.useSingleServer()
            .setIdleConnectionTimeout(redissonProperties.getSingleServerConfig().getIdleConnectionTimeout())
            .setConnectTimeout(redissonProperties.getSingleServerConfig().getConnectTimeout())
            .setTimeout(redissonProperties.getSingleServerConfig().getTimeout())
            .setRetryAttempts(redissonProperties.getSingleServerConfig().getRetryAttempts())
            .setRetryInterval(redissonProperties.getSingleServerConfig().getRetryInterval())
            .setPassword(password)
            .setAddress(redissonProperties.getSingleServerConfig().getAddress());

        return Redisson.create(config);
    }
}
java 复制代码
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;


/**
 * Binds Redisson settings under {@code spring.redis.redisson.config}.
 *
 * <p>This class is registered as a bean through
 * {@code @EnableConfigurationProperties(RedissonProperties.class)} on the configuration
 * class that consumes it. It is intentionally not annotated with
 * {@code @Configuration}/{@code @Component}: combining both registration styles creates
 * two beans of this type, which makes injection by type ambiguous.
 */
@Data
@ConfigurationProperties(prefix = "spring.redis.redisson.config")
public class RedissonProperties {
    /** Settings for Redisson's single-server mode ({@code single-server-config}). */
    private RedissonSingleServerConfig singleServerConfig;
}

@Data
class RedissonSingleServerConfig {
    private int idleConnectionTimeout;
    private int connectTimeout;
    private int timeout;
    private int retryAttempts;
    private int retryInterval;
    private String password;
    private String address;
}
java 复制代码
import com.xxx.common.enumeration.LockTypeEnum;
import org.redisson.api.RLock;
import org.redisson.api.RReadWriteLock;
import org.redisson.api.RedissonClient;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;


/**
 * Obtains Redisson locks of a requested {@link LockTypeEnum} and wraps the
 * acquire/release calls used by the distributed-lock aspect.
 */
@Component
public class RedissonLockManager {

    private final RedissonClient redissonClient;

    public RedissonLockManager(RedissonClient redissonClient) {
        this.redissonClient = redissonClient;
    }

    /**
     * Returns the lock of the given type for {@code key}.
     *
     * @param key      lock key
     * @param lockType lock type; any unhandled value falls back to a reentrant lock
     * @return the corresponding lock object
     */
    public Lock getLock(String key, LockTypeEnum lockType) {
        switch (lockType) {
            case FAIR:
                return redissonClient.getFairLock(key);
            case READ: {
                RReadWriteLock readWriteLock = redissonClient.getReadWriteLock(key);
                return readWriteLock.readLock();
            }
            case WRITE: {
                RReadWriteLock readWriteLock = redissonClient.getReadWriteLock(key);
                return readWriteLock.writeLock();
            }
            case REENTRANT:
            default:
                return redissonClient.getLock(key);
        }
    }

    /**
     * Acquires {@code lock}.
     *
     * <p>A {@code waitTime >= 0} means a bounded attempt ({@code 0} = try once without
     * waiting); a negative {@code waitTime} blocks until the lock is acquired.
     * A {@code leaseTime > 0} releases the lock automatically after that time; otherwise
     * the lock stays held until {@link #unlock(Lock)} is called. Lease time is only
     * honoured for Redisson locks.
     *
     * @param lock      lock to acquire
     * @param waitTime  maximum wait; negative = wait indefinitely
     * @param leaseTime automatic release time; non-positive = release manually
     * @param timeUnit  unit of {@code waitTime} and {@code leaseTime}
     * @return {@code true} if the lock was acquired
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public boolean tryLock(Lock lock, long waitTime, long leaseTime, TimeUnit timeUnit) throws InterruptedException {
        if (lock instanceof RLock) {
            RLock rLock = (RLock) lock;
            if (waitTime >= 0) {
                // Bounded attempt; 0 must not fall through to the blocking lock() below.
                return leaseTime > 0
                    ? rLock.tryLock(waitTime, leaseTime, timeUnit)
                    : rLock.tryLock(waitTime, timeUnit);
            }
            // Negative wait time: block until acquired.
            if (leaseTime > 0) {
                rLock.lock(leaseTime, timeUnit);
            } else {
                rLock.lock();
            }
            return true;
        }

        // Plain java.util.concurrent Lock: no lease-time support.
        if (waitTime >= 0) {
            return lock.tryLock(waitTime, timeUnit);
        }
        lock.lock();
        return true;
    }

    /**
     * Releases {@code lock}. For Redisson locks the release is skipped when the
     * current thread no longer holds the lock (e.g. its lease time already expired),
     * so releasing an expired lock does not fail.
     *
     * @param lock lock to release
     */
    public void unlock(Lock lock) {
        if (lock instanceof RLock) {
            RLock rLock = (RLock) lock;
            if (rLock.isHeldByCurrentThread()) {
                rLock.unlock();
            }
        } else {
            lock.unlock();
        }
    }
}
java 复制代码
import com.xxx.annotation.DistributedLock;
import com.xxx.common.constant.RedisKey;
import com.xxx.common.exception.RedissonLockException;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.concurrent.locks.Lock;

/**
 * Aspect that acquires a distributed lock around every method annotated with
 * {@link DistributedLock}, runs the method while holding it, and releases it afterwards.
 */
@Aspect
@Component
@Slf4j
public class RedissonLockAspect {
    private final RedissonLockManager lockManager;
    private final ExpressionParser parser = new SpelExpressionParser();

    public RedissonLockAspect(RedissonLockManager lockManager) {
        this.lockManager = lockManager;
    }

    /**
     * Builds the lock key, acquires the lock, and either proceeds with the target
     * method or applies the configured failure strategy. The lock is released in
     * {@code finally} only if it was actually acquired.
     */
    @Around("@annotation(distributedLock)")
    public Object around(ProceedingJoinPoint joinPoint, DistributedLock distributedLock) throws Throwable {
        // Key layout: LOCK_BASE + prefix + "_" + evaluated key expression
        String key = RedisKey.LOCK_BASE + distributedLock.prefix() + "_" + parseKey(distributedLock.key(), joinPoint);
        log.info("获取锁: {}", key);
        Lock lock = lockManager.getLock(key, distributedLock.lockType());

        boolean locked = false;
        try {
            locked = lockManager.tryLock(
                lock,
                distributedLock.waitTime(),
                distributedLock.leaseTime(),
                distributedLock.timeUnit()
            );

            if (locked) {
                log.info("获取锁成功,开始执行目标方法");
                return joinPoint.proceed();
            }
            log.info("获取锁失败,返回失败信息");
            return handleLockFail(joinPoint, distributedLock);
        } finally {
            if (locked) {
                lockManager.unlock(lock);
            }
        }
    }

    /**
     * Evaluates the lock key. Expressions without '#' are used verbatim; otherwise the
     * expression is evaluated with SpEL, where each argument is available by parameter
     * name ({@code #userId}) and by position ({@code #p0} / {@code #a0}).
     */
    private String parseKey(String keyExpression, ProceedingJoinPoint joinPoint) {
        if (!keyExpression.contains("#")) {
            return keyExpression;
        }

        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        String[] paramNames = resolveParameterNames(signature);
        Object[] args = joinPoint.getArgs();

        EvaluationContext context = new StandardEvaluationContext();
        // Positional aliases first, so that real parameter names registered below
        // always take precedence over an alias with the same spelling.
        for (int i = 0; i < args.length; i++) {
            context.setVariable("p" + i, args[i]);
            context.setVariable("a" + i, args[i]);
        }
        if (paramNames != null) {
            for (int i = 0; i < paramNames.length && i < args.length; i++) {
                context.setVariable(paramNames[i], args[i]);
            }
        }

        return parser.parseExpression(keyExpression).getValue(context, String.class);
    }

    /**
     * Returns the intercepted method's parameter names. Prefers the signature's own
     * lookup (backed by a parameter-name discoverer in Spring AOP); falls back to
     * reflection, which yields "arg0", "arg1", ... unless compiled with -parameters.
     */
    private String[] resolveParameterNames(MethodSignature signature) {
        String[] names = signature.getParameterNames();
        if (names != null) {
            return names;
        }
        Method method = signature.getMethod();
        Parameter[] parameters = method.getParameters();
        names = new String[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            names[i] = parameters[i].getName();
        }
        return names;
    }

    /**
     * Applies {@link DistributedLock#failStrategy()} after the lock could not be acquired.
     */
    private Object handleLockFail(ProceedingJoinPoint joinPoint, DistributedLock distributedLock) throws Throwable {
        switch (distributedLock.failStrategy()) {
            case RETURN_NULL:
                Class<?> returnType = ((MethodSignature) joinPoint.getSignature()).getReturnType();
                // A primitive result cannot be null. void.class.isPrimitive() is also true,
                // but a void method can simply be skipped, so it must not be treated as one.
                if (returnType.isPrimitive() && returnType != void.class) {
                    throw new RedissonLockException(distributedLock.failMessage());
                }
                return null;
            case CONTINUE:
                // Run without holding the lock.
                return joinPoint.proceed();
            case EXCEPTION:
            default:
                throw new RedissonLockException(distributedLock.failMessage());
        }
    }
}
java 复制代码
/**
 * Kind of distributed lock requested through {@code DistributedLock#lockType()}.
 */
public enum LockTypeEnum {
    /**
     * Reentrant lock.
     */
    REENTRANT,

    /**
     * Fair lock.
     */
    FAIR,

    /**
     * Read side of a read-write lock.
     */
    READ,

    /**
     * Write side of a read-write lock.
     */
    WRITE
}
java 复制代码
/**
 * What to do when a distributed lock cannot be acquired
 * (selected through {@code DistributedLock#failStrategy()}).
 */
public enum LockFailStrategyEnum {
    /**
     * Throw an exception.
     */
    EXCEPTION,

    /**
     * Return null (only applicable to methods that return a value).
     */
    RETURN_NULL,

    /**
     * Ignore the lock and execute the method anyway.
     */
    CONTINUE
}
4,使用注解

下面只是一个例子,实际场景中可能是锁订单号或者锁商品id或仓库id之类的。

java 复制代码
/**
 * Refreshes the OAuth token while holding a reentrant distributed lock keyed by
 * {@code id} and {@code loginCondition.account}, so refreshes for the same
 * id/account pair do not run concurrently.
 */
@Override
@DistributedLock(prefix = "oauth", key = "#id + ':' + #loginCondition.account", lockType = LockTypeEnum.REENTRANT)
public RestResponse refreshToken(String id, LoginCondition loginCondition) {
    log.info("刷新token");
    RestResponse refreshed = oAuthFeign.refreshToken();
    return refreshed;
}

prefix 用于区分业务场景,可以自定义,也可以留空。它只是 Redis key 中的一段(最终的 key 形如 LOCK_BASE + prefix + "_" + key 的解析结果),作用是让不同场景的锁互不冲突。

关键在于 key 参数:key 支持 SpEL 表达式,可以用 #参数名 引用方法参数及其属性,常见写法见下面的示例,可根据自己的需要修改。注意:按参数名引用要求编译时保留参数名(例如开启 javac 的 -parameters 选项),否则反射拿到的参数名会是 arg0、arg1 之类。

// 1. 简单参数引用

@DistributedLock(key = "#userId")

// 2. 多个参数拼接

@DistributedLock(key = "#userId + ':' + #orderId")

// 3. 对象属性引用

@DistributedLock(key = "#user.id")

// 4. 复杂对象属性拼接

@DistributedLock(key = "#user.id + ':' + #order.orderNo")

// 5. 使用方法调用(concat 是 String 的方法,要求 getId() 的返回值和 #type 都是 String)

@DistributedLock(key = "#user.getId().concat('-').concat(#type)")

// 6. 字面量和参数混合

@DistributedLock(key = "'order_lock:' + #orderId")