[Preventing Duplicate Submissions] Implementing a Distributed Lock with Redis + AOP + Annotations

How it works

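In a distributed environment, a user may click the same endpoint several times in a row. To avoid the problems caused by duplicate submissions, a simple distributed lock built on Redis can be used.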

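In Redis, the SETNX command provides mutual exclusion. SETNX stands for SET if Not eXists (its counterpart in Java is the setIfAbsent method): the value is set only if the key does not exist. If the key already exists, SETNX does nothing.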
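The set-if-absent behaviour described above can be illustrated without a Redis server. The following is only a minimal sketch: `SetNxSketch` is a hypothetical in-memory stand-in backed by a `ConcurrentHashMap`, not the Redis client used later in this article, and it ignores expiry.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/** Hypothetical in-memory stand-in for Redis, used only to illustrate SETNX semantics. */
final class SetNxSketch {
    private final ConcurrentMap<String, String> store = new ConcurrentHashMap<>();

    /** Returns true only if the key was absent and has now been set. */
    boolean setIfAbsent(String key, String value) {
        // putIfAbsent returns null when there was no previous mapping for the key.
        return store.putIfAbsent(key, value) == null;
    }

    String get(String key) {
        return store.get(key);
    }

    public static void main(String[] args) {
        SetNxSketch redis = new SetNxSketch();
        System.out.println(redis.setIfAbsent("lock", "client-a")); // first caller sets the key
        System.out.println(redis.setIfAbsent("lock", "client-b")); // key exists, nothing changes
        System.out.println(redis.get("lock"));                     // still the first caller's value
    }
}
```

Only the first caller gets `true`, which is exactly the property a lock needs: whoever sets the key owns it, and everyone else is told it is already taken.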

Implementation requirements

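  1. Define a custom annotation that prevents duplicate submissions; it carries an expiry time and the name of one request parameter to use in the key.
  2. Add the annotation to every endpoint that needs duplicate-submission protection.
  3. An AOP aspect intercepts requests to annotated endpoints, acquires and releases the lock, and applies the key timeout configured on the annotation.
  4. The Redis key is token + "-" + path + "-" + param_value (for example: 17800000001 + /api/subscribe/ + zhangsan).
  5. If an annotated endpoint is called again while the key still exists (not yet expired or released), a duplicate-submission Result is returned.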
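The key format from step 4 can be sketched as a small helper. `LockKeySketch` and `buildKey` are hypothetical names introduced here for illustration; the fallback to token + path when no parameter value is available is a simplification of what the aspect below does.

```java
/** Hypothetical helper that builds the Redis key described in step 4. */
final class LockKeySketch {

    /**
     * Joins token, path and parameter value with "-".
     * Assumption: when there is no parameter value, the key is just token + "-" + path.
     */
    static String buildKey(String token, String path, String paramValue) {
        if (paramValue == null || paramValue.isEmpty()) {
            return token + "-" + path;
        }
        return token + "-" + path + "-" + paramValue;
    }

    public static void main(String[] args) {
        // Values taken from the example in step 4.
        System.out.println(buildKey("17800000001", "/api/subscribe/", "zhangsan"));
        System.out.println(buildKey("17800000001", "/api/subscribe/", ""));
    }
}
```

Including the user token in the key means two different users calling the same endpoint with the same parameter never block each other.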

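1) Define the duplicate-submission annotation

Define a custom annotation for preventing duplicate submissions. It lets you set a timeout and the request parameter to scan (a parameter of the request whose value becomes part of the Redis key).
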
package com.lihw.lihwtestboot.noRepeatSubmit;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Annotation that prevents duplicate submissions
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface NoRepeatSubmit {

    /**
     * Lock expiry time, in seconds
     */
    int seconds() default 5;

    /**
     * Name of the request parameter to scan
     */
    String scanParam() default "";
}
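Because the annotation is retained at runtime, its settings can be read back through reflection, which is what an aspect ultimately relies on. The sketch below assumes a local copy of the annotation nested inside a hypothetical `AnnotationLookupSketch` class so that it compiles on its own; the two annotated methods are placeholders.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/** Hypothetical demo class: reads annotation attributes back via reflection. */
final class AnnotationLookupSketch {

    // Local copy of the annotation, nested here only to keep the sketch self-contained.
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface NoRepeatSubmit {
        int seconds() default 5;
        String scanParam() default "";
    }

    @NoRepeatSubmit(seconds = 10, scanParam = "username")
    public void subscribe() { }

    @NoRepeatSubmit // relies on the defaults: 5 seconds, no scanned parameter
    public void ping() { }

    /** Looks up the annotation on the named no-argument method of this class. */
    static NoRepeatSubmit settingsOf(String methodName) {
        try {
            return AnnotationLookupSketch.class
                    .getDeclaredMethod(methodName)
                    .getAnnotation(NoRepeatSubmit.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("no such method: " + methodName, e);
        }
    }

    public static void main(String[] args) {
        NoRepeatSubmit s = settingsOf("subscribe");
        System.out.println(s.seconds() + " " + s.scanParam());
        System.out.println(settingsOf("ping").seconds());
    }
}
```

With `RetentionPolicy.SOURCE` or `CLASS` the lookup would return null, which is why the annotation above is declared with `RUNTIME` retention.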

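2) Define the duplicate-submission AOP aspect

@Pointcut("@annotation(noRepeatSubmit)") is the pointcut expression. It uses annotation matching to select the methods marked with @NoRepeatSubmit and binds the annotation instance to the advice parameter of the same name.
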
package com.lihw.lihwtestboot.noRepeatSubmit;

import com.alibaba.fastjson.JSONObject;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.UUID;
/**
 * Duplicate-submission AOP aspect
 */
@Aspect
@Component
public class RepeatSubmitAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(RepeatSubmitAspect.class);

    @Autowired
    private RedisLock redisLock;

    @Pointcut("@annotation(noRepeatSubmit)")
    public void pointCut(NoRepeatSubmit noRepeatSubmit) {
    }

    @Around("pointCut(noRepeatSubmit)")
    public Object around(ProceedingJoinPoint pjp, NoRepeatSubmit noRepeatSubmit) throws Throwable {

        // Gather basic request information
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        Assert.notNull(attributes, "request attributes must not be null");
        HttpServletRequest request = attributes.getRequest();
        int lockSeconds = noRepeatSubmit.seconds(); // lock expiry time
        String threadName = Thread.currentThread().getName(); // current thread name
        String param = noRepeatSubmit.scanParam(); // name of the request parameter to scan
        String path = request.getServletPath();
        String type = request.getMethod();
        String param_value = "";

        // Only look up a value when a parameter name was configured on the annotation
        if (param != null && !param.isEmpty()) {
            if ("POST".equals(type)) {
                // Relies on ChannelFilter having wrapped the request so the body can be read again
                JSONObject body = JSONObject.parseObject(new BodyReaderHttpServletRequestWrapper(request).getBodyString());
                param_value = body == null ? null : body.getString(param);
            } else if ("GET".equals(type)) {
                param_value = request.getParameter(param);
            }
        }

        String token = request.getHeader("uid");
        LOGGER.info("Thread: {}, endpoint: {}, duplicate-submission check", threadName, path);
        String key;
        if (param != null && !param.isEmpty()) {
            key = token + "-" + path + "-" + param_value; // build the key

        } else {
            key = token + "-" + path; // build the key
        }

        String clientId = getClientId(); // temporary per-call value (UUID) that identifies the lock owner

        // Try to acquire the lock: true on success, false on failure
        boolean isSuccess = redisLock.tryLock(key, clientId, lockSeconds);

        ApiResult result = new ApiResult();
        if (isSuccess) {
            LOGGER.info("Lock acquired: endpoint = {}, key = {}", path, key);
            // Lock acquired
            Object obj;
            try {
                // Run the intercepted method
                obj = pjp.proceed(); // continue down the AOP proxy chain
            } finally {
                // releaseLock deletes the key only if it still holds our clientId
                if (redisLock.releaseLock(key, clientId)) {
                    LOGGER.info("Lock released: endpoint = {}, key = {}", path, key);
                }
            }
            return obj;
        } else {
            // Failed to acquire the lock, so treat this as a duplicate submission
            LOGGER.info("Duplicate request: endpoint = {}, key = {}", path, key);
            result.setData("Duplicate submission");
            return result;
        }
    }


    private String getClientId() {
        return UUID.randomUUID().toString();
    }

    public static String getRequestBodyData(HttpServletRequest request) throws IOException{
        BufferedReader bufferReader = new BufferedReader(request.getReader());
        StringBuilder sb = new StringBuilder();
        String line = null;
        while ((line = bufferReader.readLine()) != null) {
            sb.append(line);
        }
        return sb.toString();
    }
}
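Stripped of Spring and Redis, the control flow of the around advice is: try to take the lock, run the target, release in a finally block, and short-circuit with a duplicate result when the lock is taken. The sketch below models only that flow. `AroundFlowSketch`, `guard` and the `DUPLICATE` constant are hypothetical names, the lock table is an in-memory map, and key expiry is deliberately left out.

```java
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Supplier;

/** Hypothetical model of the around-advice flow, with an in-memory lock table. */
final class AroundFlowSketch {
    static final String DUPLICATE = "Duplicate submission";

    private final ConcurrentMap<String, String> locks = new ConcurrentHashMap<>();

    /** Runs the action under the lock for key, or returns DUPLICATE if the lock is taken. */
    String guard(String key, Supplier<String> action) {
        String clientId = UUID.randomUUID().toString(); // identifies this call as the owner
        if (locks.putIfAbsent(key, clientId) != null) {
            return DUPLICATE; // someone else holds the lock
        }
        try {
            return action.get(); // stands in for pjp.proceed()
        } finally {
            // Remove the entry only if it still maps to our clientId.
            locks.remove(key, clientId);
        }
    }

    public static void main(String[] args) {
        AroundFlowSketch sketch = new AroundFlowSketch();
        // The inner call arrives while the outer call still holds the lock.
        System.out.println(sketch.guard("k", () -> sketch.guard("k", () -> "inner")));
        // The lock was released in finally, so a later call goes through.
        System.out.println(sketch.guard("k", () -> "ok"));
    }
}
```

Releasing in `finally` matters: if the target throws, the key would otherwise block the user until it expires.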

3) RedisLock utility class

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;


@Service
public class RedisLock {

    private static final Logger logger = LoggerFactory.getLogger(RedisLock.class);

    /** Do not set an expiry */
    public final static long NOT_EXPIRE = -1;

    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * @param lockKey   lock key
     * @param clientId  unique identifier of the locking client (a UUID)
     * @param seconds   lock expiry time, in seconds
     * @return true if the lock was acquired
     */
    public boolean tryLock(String lockKey, String clientId, long seconds) {
        // setIfAbsent returns a Boolean that may be null, so compare explicitly
        return Boolean.TRUE.equals(
                redisTemplate.opsForValue().setIfAbsent(lockKey, clientId, seconds, TimeUnit.SECONDS));
    }

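    /**
     * Counterpart of tryLock; releases the lock.
     *
     * @param lockKey   lock key
     * @param clientId  unique identifier of the client that holds the lock
     * @return true if the lock was held by clientId and has been deleted
     */
    public boolean releaseLock(String lockKey, String clientId) {
        try {
            // Note: GET and DEL are separate commands here, so the ownership check is not atomic
            String currentValue = redisTemplate.opsForValue().get(lockKey);
            if (!StringUtils.isEmpty(currentValue) && currentValue.equals(clientId)) {
                redisTemplate.delete(lockKey);
                return true;
            } else {
                return false;
            }
        } catch (Exception e) {
            logger.error("Failed to release lock, key = {}", lockKey, e);
            return false;
        }
    }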

    /**
     * Gets the value stored under a key
     * @param key the key to read
     * @return the stored value, or null if the key does not exist
     */
    public String get(String key) {
        return get(key, NOT_EXPIRE);
    }

    public String get(String key, long expire) {
        String value = redisTemplate.opsForValue().get(key);
        if(expire != NOT_EXPIRE){
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value;
    }

    /**
     * Deletes a key
     * @param key the key to delete
     */
    public void delete(String key) {
        redisTemplate.delete(key);
    }
}
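The two properties this class depends on, expiry and owner-checked release, can be modelled in memory to make them easier to reason about. The sketch below is an assumption-laden illustration, not the Redis implementation: `ExpiringLockSketch` is a hypothetical class, the current time is passed in explicitly so the behaviour is deterministic, and `synchronized` stands in for the atomicity that, on a real Redis server, is commonly obtained by running the compare-and-delete as a single Lua script.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical in-memory model of a lock with expiry and owner-checked release. */
final class ExpiringLockSketch {

    private static final class Entry {
        final String owner;
        final long expiresAtMillis;

        Entry(String owner, long expiresAtMillis) {
            this.owner = owner;
            this.expiresAtMillis = expiresAtMillis;
        }
    }

    private final Map<String, Entry> entries = new HashMap<>();

    /** Acquires the lock if it is free or expired; nowMillis is supplied by the caller. */
    synchronized boolean tryLock(String key, String clientId, long seconds, long nowMillis) {
        Entry current = entries.get(key);
        if (current != null && current.expiresAtMillis > nowMillis) {
            return false; // still held by someone
        }
        entries.put(key, new Entry(clientId, nowMillis + seconds * 1000L));
        return true;
    }

    /** Releases the lock only if it is unexpired and owned by clientId. */
    synchronized boolean releaseLock(String key, String clientId, long nowMillis) {
        Entry current = entries.get(key);
        if (current == null || current.expiresAtMillis <= nowMillis || !current.owner.equals(clientId)) {
            return false;
        }
        entries.remove(key);
        return true;
    }

    public static void main(String[] args) {
        ExpiringLockSketch lock = new ExpiringLockSketch();
        System.out.println(lock.tryLock("k", "a", 5, 0));     // a acquires
        System.out.println(lock.tryLock("k", "b", 5, 1000));  // b is rejected while a holds it
        System.out.println(lock.tryLock("k", "b", 5, 5000));  // a's lock has expired, b acquires
        System.out.println(lock.releaseLock("k", "a", 5500)); // a must not delete b's lock
    }
}
```

The last call shows why the stored value is a per-call UUID: a slow caller whose lock has already expired must not delete the lock that a later caller now holds.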

4) Filter + request utility class

Filter class

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.ServletComponentScan;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;


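// Note: @ServletComponentScan only takes effect on a configuration class
// (for example the @SpringBootApplication class); declare it there so this @WebFilter is registered.
@WebFilter(urlPatterns = "/*", filterName = "channelFilter")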
public class ChannelFilter implements Filter {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        logger.info("-----------------------Execute filter start---------------------");
        // The request stream can only be read once, so pass a wrapper that can replay the body down the chain
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        ServletRequest requestWrapper = new BodyReaderHttpServletRequestWrapper(httpServletRequest);
        filterChain.doFilter(requestWrapper, servletResponse);
    }

}

BodyReaderHttpServletRequestWrapper

Wraps the request and caches its body, so that parameters can be read for both GET and POST requests without consuming the stream.

package com.lihw.lihwtestboot.noRepeatSubmit;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper{

    /**
     * Cached request body, so the body can be read more than once
     */
    private final byte[] body;

    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        String sessionStream = getBodyString(request);
        body = sessionStream.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Reads the request body as a string
     *
     * @param request the incoming request
     * @return the body decoded as UTF-8, with line breaks removed
     */
    private String getBodyString(final ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        InputStream inputStream = null;
        BufferedReader reader = null;
        try {
            inputStream = cloneInputStream(request.getInputStream());
            reader = new BufferedReader(new InputStreamReader(inputStream, Charset.forName("UTF-8")));
            String line = "";
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (inputStream != null) {
                try {
                    inputStream.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        return sb.toString();
    }
    public String getBodyString() {
        return new String(body, StandardCharsets.UTF_8);
    }
    /**
     * Copies the input stream into memory
     *
     * @param inputStream the request input stream
     * @return an in-memory stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        try {
            while ((len = inputStream.read(buffer)) > -1) {
                byteArrayOutputStream.write(buffer, 0, len);
            }
            byteArrayOutputStream.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
        InputStream byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray());
        return byteArrayInputStream;
    }

    @Override
    public BufferedReader getReader() throws IOException {
        // Decode with UTF-8 explicitly, matching how the body was cached
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);

        return new ServletInputStream() {

            @Override
            public int read() throws IOException {
                return bais.read();
            }

            @Override
            public boolean isFinished() {
                // Finished once every cached byte has been consumed
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The body is already in memory, so reads never block
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }
        };
    }
}
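The reason for the wrapper is that an input stream can be consumed only once; caching the bytes lets every later reader get a fresh stream. The following stdlib-only sketch shows that idea in isolation. `CachedBodySketch` is a hypothetical class that does not depend on the servlet API, and the JSON payload in `main` is an invented example.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

/** Hypothetical, servlet-free illustration of caching a one-shot stream. */
final class CachedBodySketch {
    private final byte[] body;

    /** Drains the source once and keeps the bytes. */
    CachedBodySketch(InputStream source) {
        this.body = readFully(source);
    }

    /** Every call returns a new stream positioned at the start of the cached body. */
    InputStream newStream() {
        return new ByteArrayInputStream(body);
    }

    String bodyAsString() {
        return new String(body, StandardCharsets.UTF_8);
    }

    static byte[] readFully(InputStream in) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        try {
            while ((len = in.read(buffer)) != -1) {
                out.write(buffer, 0, len);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return out.toByteArray();
    }

    public static void main(String[] args) {
        byte[] raw = "{\"username\":\"zhangsan\"}".getBytes(StandardCharsets.UTF_8);
        CachedBodySketch cached = new CachedBodySketch(new ByteArrayInputStream(raw));
        // Both reads see the full body because each one gets its own stream.
        System.out.println(new String(readFully(cached.newStream()), StandardCharsets.UTF_8));
        System.out.println(new String(readFully(cached.newStream()), StandardCharsets.UTF_8));
    }
}
```

In the article's setup the filter plays the role of the constructor call: it wraps the request once, and both the aspect and the controller then read from the cached copy.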

5) Test Controller

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import javax.validation.constraints.NotEmpty;

@RestController
@RequestMapping("/api")
@Validated
public class noRepeatSubmitController {

    // Note: for GET requests the aspect reads scanParam with request.getParameter,
    // so "username" only becomes part of the key when it is sent as a query parameter.
    @GetMapping("/subscribe/{channel}")
    @NoRepeatSubmit(seconds = 10, scanParam = "username")
    public ApiResult subscribe(@RequestHeader(name = "uid") String phone, @RequestHeader(name = "username") String username, @PathVariable("channel") @NotEmpty(message = "channel must not be empty") String channel) {

        System.out.println("phone=" + phone);
        System.out.println("username=" + username);
        System.out.println("channel=" + channel);

        try {
            Thread.sleep(5000); // simulate a slow operation
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        return new ApiResult("success","data");
    }
}

6) Test results

Repeated clicks
