基于令牌桶算法实现一个限流器


序言:本文章基于令牌桶算法实现了简单的一个限流器

1 令牌桶算法

实现原理

  • 令牌生成:在固定的时间间隔内,算法会向一个桶中放入一定数量的令牌。令牌的生成速率是固定的,通常以每秒钟生成的令牌数来表示。
    • 桶的容量:桶有一个最大容量,如果桶满了,新的令牌将被丢弃。这意味着即使在高流量情况下,系统也不会无限制地增加请求。
  • 请求处理:每当一个请求到达时,它需要从桶中获取一个令牌。如果桶中有令牌,请求就可以被处理,桶中的令牌数量减一。如果没有令牌,请求将被拒绝或被延迟,具体取决于实现。
  • 流量控制:通过调整令牌的生成速率和桶的容量,可以控制流量的平稳性和最大流量。

算法优点:

  • 可以承载一定的突发流量情况。
  • 限流窗口的变化相对平稳。

coding

根据令牌桶算法原理,可以先定义出三个变量。桶容量令牌产生速率当前桶中的令牌数量。同时定义一个rateLimiter类和对应的构造方法:

java 复制代码
public class RateLimiter {
    // 自己写的日志打印工具(在线程池的文章中有贴)
    static Logger log = new Logger(Logger.LogLevel.DEBUG, RateLimiter.class);
    // 桶容量
    private final int maxPermit;
    // 令牌产生速率 / 秒
    private final int rate;
    // 当前桶中的令牌数量(考虑到这个变量会多线程使用, 使用原子类来实现)
    private final AtomicInteger holdPermits;
    
    public RateLimiter(int maxPermit, int rate, int initPermits) {
        if (rate < 1) throw new IllegalArgumentException("the rate must be greater than 1");
        if (initPermits < 0) throw new IllegalArgumentException("the initPermits must be greater than 0");
        if (maxPermit < 1) throw new IllegalArgumentException("the maxPermit must be greater than 1");
        this.maxPermit = maxPermit;
        this.rate = rate;
        this.holdPermits = new AtomicInteger(initPermits);
    }
}

然后我们需要给这个类添加一个 boolean tryAcquire(int permit) 方法。表示获取permit数量个令牌。如果获取成功返回true,获取失败则返回false。

可以写出这个方法:

java 复制代码
/**
 * 尝试获取 permit 数量的令牌
 *
 * @param permit 令牌数量
 * @return 获取到 permit 数量的令牌则返回 true, 否则返回 false
 */
public boolean tryAcquire(int permit) {
    if (permit > maxPermit) throw new IllegalArgumentException("the permit must be smaller than maxPermit:" + maxPermit);
    if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
    int currentPermits;
    do {
        currentPermits = holdPermits.get();
        if (currentPermits < permit) {
            return false;
        }
    } while (!holdPermits.compareAndSet(currentPermits, currentPermits - permit));
    // 日志打印
    log.debug("原令牌数:" + currentPermits + ", 减少:" + permit + ", 当前总数:" + (currentPermits - permit));
    return true;
}

这个方法中借助了原子类的compareAndSet操作和自旋来实现令牌的扣减。当当前桶中的令牌数量大于等于请求获取的令牌数时,使用compareAndSet来实现令牌的扣减。

但是这个方法可能由于其他线程的并发执行(提前扣减了令牌)而失败。所以需要自旋操作保证令牌数量足够时可以正确获得令牌许可。只有当桶中的令牌数小于请求要求的令牌数量时才会返回失败。

至此,令牌桶的构造方法获取令牌的方法已经实现完成。但是何时且如何向桶中放入令牌呢?

如果使用定时任务的话,那么就需要创建额外的线程对象来完成。

这里借鉴前人的智慧,在每一次获取令牌时顺便计算和更新令牌数量,这样的话,我们还需要一个变量记住上一次计算令牌的时间。

所以在类中加一个变量 lastFreshTime 记录上一次计算更新令牌的时间,同时由于这个变量可能被多个线程更改,使用原子类对象保证线程安全。

java 复制代码
// 上次计算更新令牌的时间
private final AtomicLong lastFreshTime;

然后,我们还需要一个方法,用于每次获取令牌前计算更新桶中的令牌数量,这个方法中首先根据过去时间的纳秒数量计算出应该产生的令牌数量,使用int向下取整,

然后使用令牌数量反向计算产生这些令牌所需的准确时间(因为令牌数量使用int取整了),加上当前lastFreshTime的值即可以得到新的lastFreshTime的值。

这里为了保证线程安全,lastFreshTime 的更新使用compareAndSet保证只有一个线程可以获取更新权限。这个线程在成功更新lastFreshTime后,需要继续更新令牌的数量,

由于在tryAcquire中还可能出现其他线程扣减令牌数量的行为,所以这里还需要保证更新操作的原子性。

java 复制代码
/**
 * 刷新令牌数量
 */
private void freshPermit() {
    long now = System.nanoTime();
    long lastTime = lastFreshTime.get();
    if (now <= lastTime) return;
    int increment = (int) ((now - lastTime) * rate / 1_000_000_000);
    long thisTime = lastTime + increment * 1_000_000_000L / rate;
    if (increment > 0 && lastFreshTime.compareAndSet(lastTime, thisTime)) {
        int current;
        int next;
        do {
            current = holdPermits.get();
            next = Math.min(maxPermit, current + increment);
        } while (!holdPermits.compareAndSet(current, next));
        log.debug("原令牌数:" + current + ", 增加:" + increment + ", 当前总数:" + holdPermits.get());
    }
}

至此,我们就实现了一个简单的令牌桶实现代码,只要每次tryAcquire时先使用freshPermit更新一下令牌数量就可以了。

但通常来说,令牌桶还会有一个带超时时间的boolean tryAcquire(int permit, long timeOut)方法。这里我们做一个简单的实现,使用定期的sleep操作而不是锁机制来完成。

假设要获取的令牌数量为 p,超时时间为 t,那么在tryAcquire(p, t)方法中,如果当前令牌不足p的话,那么线程将会sleep一定时间会再次尝试获取令牌,直到使用时间超过t仍未获取成功才会返回失败。

这里有一个问题就是睡眠时间sleepDuration的确定,在这个代码中,sleepDuration的值为 (p / rate) / 10,且最小为10,最大为100。

实现代码:

java 复制代码
/**
 * 在 timeout 时间内尝试获取 permit 数量的令牌
 *
 * @param permit  令牌数量
 * @param timeOut 超时时间 单位 秒
 * @return 如果在 timout 时间内获取到 permit 数量的令牌则返回 true, 否则返回 false
 */
public boolean tryAcquire(int permit, long timeOut) {
    if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
    if (timeOut < 0) throw new IllegalArgumentException("the timeOut must be greater than 0");
    timeOut = timeOut * 1_000_000_000 + System.nanoTime();
    long sleepDuration = (long) (1.0 * permit / rate);
    sleepDuration = Math.min(sleepDuration, 100);
    sleepDuration = Math.max(sleepDuration, 10);
    while (System.nanoTime() < timeOut) {
        if (tryAcquire(permit)) return true;
        else {
            try {
                Thread.sleep(sleepDuration);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
    }
    return false;
}

至此,我们就实现了一个简单的使用令牌桶算法的限流器类。

完整代码:

java 复制代码
public class RateLimiter {

    static Logger log = new Logger(Logger.LogLevel.DEBUG, RateLimiter.class);
    /**
     * 最大令牌数量
     */
    private final int maxPermit;
    /**
     * 令牌产生速率 / 每秒
     */
    private final int rate;

    /**
     * 当前可用令牌数量
     */
    private final AtomicInteger holdPermits;
    /**
     * 上次计算令牌的时间
     */
    private final AtomicLong lastFreshTime;


    public RateLimiter(int maxPermit, int rate, int initPermits) {
        if (rate < 1) throw new IllegalArgumentException("the rate must be greater than 1");
        if (initPermits < 0) throw new IllegalArgumentException("the initPermits must be greater than 0");
        if (maxPermit < 1) throw new IllegalArgumentException("the maxPermit must be greater than 1");
        if (maxPermit < rate) throw new IllegalArgumentException("the maxPermit must be greater than rate");
        this.maxPermit = maxPermit;
        this.rate = rate;
        this.holdPermits = new AtomicInteger(initPermits);
        this.lastFreshTime = new AtomicLong(System.nanoTime());
    }

    /**
     * 尝试获取 permit 数量的令牌
     *
     * @param permit 令牌数量
     * @return 获取到 permit 数量的令牌则返回 true, 否则返回 false
     */
    public boolean tryAcquire(int permit) {
        if (permit > maxPermit)
            throw new IllegalArgumentException("the permit must be smaller than maxPermit:" + maxPermit);
        if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
        freshPermit();

        int currentPermits;
        do {
            currentPermits = holdPermits.get();
            if (currentPermits < permit) {
                return false;
            }
        } while (!holdPermits.compareAndSet(currentPermits, currentPermits - permit));

        log.debug("原令牌数:" + currentPermits + ", 减少:" + permit + ", 当前总数:" + (currentPermits - permit));
        return true;
    }

    /**
     * 刷新令牌数量
     */
    private void freshPermit() {
        long now = System.nanoTime();
        long lastTime = lastFreshTime.get();
        if (now <= lastTime) return;
        int increment = (int) ((now - lastTime) * rate / 1_000_000_000);
        long thisTime = lastTime + increment * 1_000_000_000L / rate;
        if (increment > 0 && lastFreshTime.compareAndSet(lastTime, thisTime)) {
            int current;
            int next;
            do {
                current = holdPermits.get();
                next = Math.min(maxPermit, current + increment);
            } while (!holdPermits.compareAndSet(current, next));
            log.debug("原令牌数:" + current + ", 增加:" + increment + ", 当前总数:" + holdPermits.get());
        }
    }

    /**
     * 在 timeout 时间内尝试获取 permit 数量的令牌
     *
     * @param permit  令牌数量
     * @param timeOut 超时时间 单位 秒
     * @return 如果在 timout 时间内获取到 permit 数量的令牌则返回 true, 否则返回 false
     */
    public boolean tryAcquire(int permit, long timeOut) {
        if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
        if (timeOut < 0) throw new IllegalArgumentException("the timeOut must be greater than 0");
        timeOut = timeOut * 1_000_000_000 + System.nanoTime();
        long sleepDuration = (long) (1.0 * permit / rate);
        sleepDuration = Math.min(sleepDuration, 100);
        sleepDuration = Math.max(sleepDuration, 10);
        while (System.nanoTime() < timeOut) {
            if (tryAcquire(permit)) return true;
            else {
                try {
                    Thread.sleep(sleepDuration);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        RateLimiter rateLimiter = new RateLimiter(100, 50, 0);
        log.info("开始");
        for (int i = 0; i < 3; i++) {
            int j = i;
            new Thread(() -> {
                int k = 0;
                while (true) {
                    if (rateLimiter.tryAcquire(20, 1)) {
                        log.info("第" + k++ + "轮, 线程 " + j + " 获取令牌成功");
                    } else log.error("第" + k++ + "轮, 线程 " + j + " 获取令牌失败");
                    try {
                        Thread.sleep(888);
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                }
            }).start();
        }
    }
}

总结

这个案例中其实还是可以学习到不少关于如何线程安全的实现功能的问题。

相关推荐
是小崔啊几秒前
开源轮子 - Apache Common
java·开源·apache
因我你好久不见5 分钟前
springboot java ffmpeg 视频压缩、提取视频帧图片、获取视频分辨率
java·spring boot·ffmpeg
程序员shen1616117 分钟前
抖音短视频saas矩阵源码系统开发所需掌握的技术
java·前端·数据库·python·算法
Ling_suu36 分钟前
SpringBoot3——Web开发
java·服务器·前端
天使day1 小时前
SpringMVC
java·spring·java-ee
CodeClimb1 小时前
【华为OD-E卷-简单的自动曝光 100分(python、java、c++、js、c)】
java·python·华为od
风清云淡_A1 小时前
【java基础系列】实现数字的首位交换算法
java·算法
Gao_xu_sheng1 小时前
Java程序打包成exe,无Java环境也能运行
java·开发语言
大卫小东(Sheldon)1 小时前
Java的HTTP接口测试框架Gatling
java
谢家小布柔2 小时前
java中的继承
java·开发语言