gin 框架学习之路

文章目录

前言

最近准备面试的时候发现很多职位的描述都说要熟悉gin,我一直没搞懂,什么样才算熟悉?不知道我下面的内容算不算到了熟悉gin的阶段?

我理解能搞懂gin的context,route 和middleware三大组件就算是熟悉gin了。

创建engin有两种方式

  • new() 函数
  • Default()函数

主要是这个Default()在创建的时候会添加两个中间件Logger()和Recovery()以后调用New()

这个engin 其实除了gin.Contex{}意外没什么说的。下面这部分马上要说了

route

这块源码就算了,一个使用者不需要过多的关注。

http 主要使用的method主要是

  • get
  • put
  • post
  • delete
    gin会为每一种method创建一个 radix Tree,这种压缩前缀树可以减少数的高度更快的查询和节省内存空间,构建的路由树如下:
text 复制代码
根节点 /
└── api/
    ├── users
    │   ├── (空)         # 对应 /api/users
    │   └── :id
    │       ├── (空)     # 对应 /api/users/:id,优先级比下面的path低
    │       └── /posts   # 对应 /api/users/:id/posts,优先级较高
    ├── posts            # 对应 /api/posts
    └── static/
        └── *filepath    # 对应 /static/*filepath4
其中 静态路由优先级最高,参数路由次之,通配符路由最后

路由分组

路由分组,其实可以减少冲突,方便路由的管理,主要体现在路由路径的书写和添加路由级别的中间件下面是代码示例

go 复制代码
	api := r.Group("/api")
	api.GET("/test", func(context *gin.Context) {
			fmt.Println("这是一个组级别的中间件注册函数,近对",context.Request.Method, context.Request.URL.Path)

		context.JSON(200, gin.H{"message": "OK"})
	})
	// 下面是一个重点,如果在api的这个路由基础之上再创建一个Group,原来api组的中间件会继承
	// 执行顺序是 globe middleware------>group middleware------>sub-group middleware------>最终执行程序。
	// middleware 的个数限制在63个
	subGroup := api.Group("/demo")
	subGroup.Use(func(c *gin.Context) {
		fmt.Println("======================")
	})
	subGroup.GET("/hello", func(c *gin.Context) {
		fmt.Println(9999999)
	})

gin contex源码整体面貌

gin.Context是什么?有什么用

gin 有自己的 Context gin.Context{},gin对原生context的高级封装,所有的请求,响应处理链控制缓存都被封装到了context了, 中间件流程控制、数据传递等核心能力,是处理器和中间件的核心入参。是单次 HTTP 请求的上下文载体,贯穿请求处理的全生命周期(从请求进入到响应返回)。

主要用处:

  • 它本身是一个sync.Pool的对象池,减少了GC的压力。
  • 中间件的处理顺序控制(context.Next() )
  • get/set进行数据传递:比如分布式链路追踪的requestID
  • 参数预处理
  • 响应处理,比如统一格式(这个对于前端来说非常重要)
  • 请求生命周期管理(超时处理)

源码结构

go 复制代码
// Context is the most important part of gin. It allows us to pass variables between middleware,
// manage the flow, validate the JSON of a request and render a JSON response for example.
type Context struct {
	writermem responseWriter
	Request   *http.Request
	Writer    ResponseWriter
	Params   Params
	handlers HandlersChain
	index    int8
	fullPath string
	engine       *Engine
	params       *Params
	skippedNodes *[]skippedNode
	// This mutex protects Keys map.
	mu sync.RWMutex
	// Keys is a key/value pair exclusively for the context of each request.
	// 每次使用context.Set(k, v)的时候k/v对都放这里,对于keys 的处理是并发安全的,比如我们从Header中获取支持的语言或者时区
	// 链路追踪的时候传递requestID等等场景
	Keys map[any]any
	// Errors is a list of errors attached to all the handlers/middlewares who used this context.
	Errors errorMsgs
	// Accepted defines a list of manually accepted formats for content negotiation.
	Accepted []string
	// queryCache caches the query result from c.Request.URL.Query().
	queryCache url.Values
	// formCache caches c.Request.PostForm, which contains the parsed form data from POST, PATCH,
	// or PUT body parameters.
	formCache url.Values
	// SameSite allows a server to define a cookie attribute making it impossible for
	// the browser to send this cookie along with cross-site requests.
	sameSite http.SameSite
}

我们可以从gin.context中获取到前端传来的参数

  1. 获取url中的Param参数
go 复制代码
	r.GET("/greet/:name", func(c *gin.Context) {
		name := c.Param("name")
		c.String(http.StatusOK, "Hello %s", name)
	})
  1. 获取url中的query参数
go 复制代码
	// URL: /search?q=golang&page=2&limit=20
	// 比如你一次查几页的数据,每个页获取几条数据
	context.Query("page")
	// 但是我们大多数的时候都用下面这个DefaultQuery以防止前端的同学给我们传一个非法值导致后端程序出错
	context.DefaultQuery("page", "20")
	 limit := c.DefaultQuery("limit", "10")
    // 注意:返回的是字符串,需要转换
    page, _ := strconv.Atoi(pageStr)
    limit, _ := strconv.Atoi(limitStr)
// 校验参数是否存在
// URL: /filter?category=books
func FilterHandler(c *gin.Context) {
    // 检查必填参数是否存在
    if category, exists := c.GetQuery("category"); exists {
        // 参数存在,进行处理
        // category 是字符串类型
    } else {
        // 参数不存在
        c.JSON(400, gin.H{"error": "category参数必填"})
        return
    }
}

middleware

其实这个是包含在context里的,单列出来是想把context模块化学习。

middleware的常见功能

  • 参数认证
  • 流量控制(限流)
  • gin崩溃自动recover
  • 统计监控
  • 比如想统计P99 的请求用时
  • 解决跨域问题
  • 路由分组的时候,基于路由组的中间件是继承的。
  • 登录认证/权限管理

常用登录认证方式

  • jwt
  • session
  • Oauth 2.0

c.Next()的作用是什么(中间件怎么保证顺序)

c.Next()是 Gin 框架中间件机制的核心,它控制着中间件的执行流程。

c.Next() 的作用是将控制权传递给下一个中间件,当后续的中间件执行完成以后,返回到当前的中间件继续执行,他的执行顺序是按middleware的注册顺序执行的,它的执行模型是洋葱模型

通过c.Use()的是全局的中间件,如果在具体的调用的时候也可以通过c.Next()函数来转移执行权限

go 复制代码
func Middleware() {
	r := gin.Default()

	// 中间件A
	r.Use(func(c *gin.Context) {
		now := time.Now()
		fmt.Println("Middleware A - Before Next")

		c.Next() // 执行下一个中间件,下一个执行完返回以后继续从这里执行
		fmt.Println("Middleware A - After Next")
		fmt.Println(time.Since(now))
	})
	r.Use(func(context *gin.Context) {
		fmt.Println("这个middleware 一次性执行完,后面不会再返回到这个函数来了")
	})
	// 中间件B
	r.Use(func(c *gin.Context) {
		fmt.Println("Middleware B - Before Next")
		c.Next() // 执行下一个中间件,下一个执行完返回以后继续从这里执行,如果没有下一个Middleware 就行外部调用的接口
		fmt.Println("Middleware B - After Next")
	})
	// 处理函数
	r.GET("/ping", func(c *gin.Context) {
		fmt.Println("Handler - Processing Request")
		time.Sleep(1 * time.Second)
		c.JSON(200, gin.H{"message": "OK"})
	})
	r.Run(":8080")
}

gin 解决跨域问题

gin 是通过middleware解决跨域问题的

text 复制代码
什么是跨域问题?
* 同源:协议 + 域名 + 端口 3个完全相同
* 当一个 Web 应用(源)向另一个源请求资源时,如果协议、域名、端口任意一个不同,就发生了跨域。
* github.com/gin-contrib/cors
go 复制代码
func main(){
	r := gin.Default()

	// 根据环境加载配置
	corsConfig := getCORSConfig(os.Getenv("APP_ENV"))
	r.Use(cors.New(corsConfig))
}
func getCORSConfig(env string) cors.Config {
	config := cors.Config{
		AllowMethods: []string{
			"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS",
		},
		AllowHeaders: []string{
			"Origin", "Content-Type", "Content-Length",
			"Accept-Encoding", "Authorization", "X-CSRF-Token",
		},
		ExposeHeaders:    []string{"Content-Length"},
		AllowCredentials: true,
		MaxAge:           12 * time.Hour,
	}

	switch env {
	case "production":
		config.AllowOrigins = strings.Split(
			os.Getenv("ALLOWED_ORIGINS"), ",")
	case "staging":
		config.AllowOrigins = []string{
			"https://staging.example.com",
			"https://staging-admin.example.com",
		}
	default: // development
		config.AllowOrigins = []string{
			"http://localhost:3000",
			"http://localhost:3001",
			"http://127.0.0.1:3000",
		}
	}
	return config
}

gin 服务启动后,每次请求的入口函数

go 复制代码
// gin 实现了http/net的ServeHTTP 接口
// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// 每次请求到来都从pool中获取一个context()来减少gc带来的压力是gin一个很好对资源申请和回收的一个示例
	c := engine.pool.Get().(*Context)
	c.writermem.reset(w)
	c.Request = req
	// 每次取到新的context都要重置也是非常关键的一步
	c.reset()

	engine.handleHTTPRequest(c)

	// 用完放回,等待下一次复用
	engine.pool.Put(c)
}

其他

gin 对于零拷贝的使用

首先需要明确的一点是:很多框架都是使用了零拷贝的技术,而不是自己实现了零拷贝技术,最终都是调用内核的sendfile()函数。

gin 使用零拷贝技术的地方:

  • 主要是用在了gin.Static()和 gin.StaticFS()
  • 响应发送 gin.Context.File()和 gin.Context.DataFromReader()

零拷贝技术

零拷贝之前的过程

以一个读取本地文件的例子来说明未使用零拷贝的过程

步骤 操作 cpu模式切换 复制类型 备注
1 进程调用 read() 用户态 → 内核态 - 用户程序执行系统调用
2 从磁盘读取到页缓存 内核态 硬件复制(DMA) 若缓存命中则跳过
3 从页缓存复制到用户缓冲区 内核态 需要CPU参与 copy_to_user
4 read()返回 内核态 → 用户态 - -
5 用户进程调用 send() 用户态 → 内核态 - -
6 从用户缓冲区复制到 socket 缓冲区 用户态 需要CPU参与 copy_from_user
7 从 socket 缓冲区到网卡 内核态 硬件复制(DMA)
8 send()返回 内核态 → 用户态 - -
tex 复制代码
DMA是Direct Memory Access,硬件直接访问内存,不消耗CPU周期
零拷贝并不是完全绕过CPU,不需要CPU的参与了,这期间仍然需要CPU的参与比如执行系统调用,管理文件描述符,网络协议栈的处理和发起DMA指令等,只是不参与CPU复制了

上面的简化步骤如下:
数据流:
1.磁盘 → 页缓存(DMA,0次CPU复制)
2.页缓存 → 用户缓冲区(CPU复制1次)
3.用户缓冲区 → socket缓冲区(CPU复制2次)
4.socket缓冲区 → 网卡(DMA,0次CPU复制)

上下文切换:4次(read调用+返回,send调用+返回)

零拷贝技术

复制代码
数据流:
不经过用户态的缓冲区
1.磁盘 → 页缓存(DMA,0次CPU复制)
2.页缓存 → socket缓冲区(DMA,0次CPU复制)
3.socket缓冲区 → 网卡(DMA,0次CPU复制)

上下文切换:2次(sendfile调用+返回)
零拷贝的技术核心
  • 内核内部的"管道"传递机制
  • 核心:传递page引用而不是复制数据
  • 这个管道是一个虚拟管道,用户态并不可见
  • 这个管道不存储数据,只传递引用,调用完成后立即销毁
  • 页缓存到socket缓冲区采用直接映射, 不复制数据,只是指针传递,共享同一文件的页缓存。舍弃原来的 memcpy(skb->data, user_buffer, size);
  • DMA硬件技术支持
    这里补充一个知识点: 内存的带宽
tex 复制代码
DDR4 3200MHz,单通道64位:3200 × 64 × 1 ÷ 8 = 25.6 GB/s
DDR5 4800MHz,双通道64位:4800 × 64 × 2 ÷ 8 = 76.8 GB/s

其实这个是冷知识,一般不需要知道,截止到2025年6月 dell的一款普通服务器760已经是8通道了。如果你选择服务器时需要的总内存不大,尽量用多个小内存条,而不是一个大内存条,可以提高速度 
零拷贝不适用的情况
  1. 你能明显的知道 零拷贝不经过用户态,如果你需要修改数据的情况下不能使用零拷贝技术
  2. 输入必须是文件
  • in_fd必须是普通文件或块设备
  • 不能是管道、socket或其他特殊文件
  1. 输出必须是socket
  • out_fd必须是socket,不能是普通文件或标准输出(sendfile()的设计初衷是优化网络文件传输, 文件到文件复制已有 copy_file_range()等专用接口)

gin 默认是开启长链接的

实际项目如果需要可以修改持续时长来调整

go 复制代码
package main

import (
    "net/http"
    "time"
    "github.com/gin-gonic/gin"
)

func main() {
    // 1. 创建 Gin 路由
    router := gin.Default()
    
    // 2. 定义路由
    router.GET("/ping", func(c *gin.Context) {
        c.String(http.StatusOK, "pong")
    })
    
    // 3. 配置并启动 HTTP 服务器
    server := &http.Server{
        Addr:         ":8080",           // 监听地址
        Handler:      router,            // 使用Gin作为处理器
        IdleTimeout:  60 * time.Second,  // 空闲连接超时,启用Keep-Alive
        ReadTimeout:  10 * time.Second,  // 读取请求超时
        WriteTimeout: 10 * time.Second,  // 写入响应超时
    }
    
    // 4. 启动服务器
    server.ListenAndServe()
}

gin 为什么为什么没有使用协程池,而是直接使用http/net的每个请求直接创建一个协程

以下是个人见解

  • go本身追求简单
  • 如果单个程序的并发不超过10w 没必要使用协程池,使用协程池有点过度优化了
  • 如果有突发流量,资源池的大小设定不合适可能这里会有瓶颈
  • 协程本身就不是通过系统创建/销毁/调度的,而是用户态创建,是一个轻量级的单位
  • 实际业务的协程大部分时间都是在处理io,都不在runQ里面(只有_Grunnable在runq里),而是分不到了各种等待队列,挂起的协程是不消耗CPU的
  • 网络io等待队列
  • channel 等等队列
  • 同步原语等待队列(互斥锁,读写锁,waitGroup等)
  • 定时器等待队列

知名公司的实际经验

案例1:Cloudflare DNS 服务

峰值:1300万 QPS

goroutine 数量:百万级

全局队列大小:很少超过 1000

原因:DNS 查询是纯 I/O 操作

案例2:Uber 的 Go 服务

10万+ QPS

全局队列竞争:< 1%

调度延迟:< 1ms

常见限流的实现方案

1.1.2 滑动窗口

  1. 先给一个简单容易理解的版本(基于切片)
go 复制代码
// 统计需要清理的时间戳进行清理
// 清理完成以后检查长度
// 实现简单,但是如果qps 设计成 10w/s 遍历消耗时间切可能出现频繁内存的申请和释放
package main

import (
	"fmt"
	"sync"
	"time"
)

// SlidingWindowLimiter 滑动窗口限流器
type SlidingWindowLimiter struct {
	windowSize time.Duration // 例如 1 秒
	limit      int           // 允许最大请求数
	mu         sync.Mutex
	timestamps []int64 // 记录通过请求的时间戳(毫秒)
}

// NewSlidingWindowLimiter 创建滑动窗口限流器
func NewSlidingWindowLimiter(window time.Duration, limit int) *SlidingWindowLimiter {
	return &SlidingWindowLimiter{
		windowSize: window,
		limit:      limit,
		timestamps: make([]int64, 0),
	}
}

// Allow 尝试请求,成功返回 true,否则 false
// 每次有请求进来才计算已经记录的是不是有过期的时间戳,需要更新
func (s *SlidingWindowLimiter) Allow() bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now().UnixMilli()

	// 1. 清除窗口之外的时间戳
	cutoff := now - s.windowSize.Milliseconds()
	idx := 0
	for idx < len(s.timestamps) && s.timestamps[idx] < cutoff {
		idx++
	}
	s.timestamps = s.timestamps[idx:] // 保留窗口内的请求

	// 2. 判断当前窗口内是否超过限制
	if len(s.timestamps) >= s.limit {
		return false
	}

	// 3. 加入当前请求
	s.timestamps = append(s.timestamps, now)
	return true
}

func main() {
	limiter := NewSlidingWindowLimiter(1*time.Second, 5)

	for i := 0; i < 10; i++ {
		if limiter.Allow() {
			fmt.Println(i, "passed")
		} else {
			fmt.Println(i, "blocked")
		}
		time.Sleep(200 * time.Millisecond)
	}
}
  1. 基于环形数组的实现
  • 减少内存的分配和垃圾回收
  • 减少内存切片拷贝
go 复制代码
package main

import (
   "fmt"
   "sync"
   "time"
)

type SimpleSlidingWindow struct {
   mu     sync.Mutex
   slots  int           // 窗口分成的槽位数
   limit  int           // 时间窗口内的请求限制数
   window time.Duration // 时间窗口长度
   counts []int         // 每个槽位的计数
   times  []int64       // 每个槽位的起始时间戳
}

func NewSimpleSlidingWindow(window time.Duration, slots, limit int) *SimpleSlidingWindow {
   if slots < 60{
   	slots = 60
   }
   return &SimpleSlidingWindow{
   	window: window,
   	limit:  limit,
   	slots:  slots,
   	counts: make([]int, slots),
   	times:  make([]int64, slots),
   }
}

func (s *SimpleSlidingWindow) Allow() bool {
   s.mu.Lock()
   defer s.mu.Unlock()

   now := time.Now().UnixNano()
   windowNanos := s.window.Nanoseconds()

   // 计算当前slot
   slotWidth := windowNanos / int64(s.slots)
   slotIndex := int((now / slotWidth) % int64(s.slots))
   slotStartTime := (now / slotWidth) * slotWidth

   // 清理过期数据并计算总数
   total := 0
   earliestValidTime := now - windowNanos

   for i := 0; i < s.slots; i++ {
   	// 如果这个slot的时间在有效期内,就累加计数
   	// 这里有个误差,因为每个slot 只记录了一个值,属于这个slot其他的时间戳都没有记录换句话说就是
   	// 所有在同一个槽位内的请求都被视为同一时间,精度只有槽位宽度。
   	if s.times[i] >= earliestValidTime {
   		total += s.counts[i]
   	} else {
   		// 过期了,清空
   		s.counts[i] = 0
   	}
   }

   // 检查是否超限
   if total >= s.limit {
   	return false
   }

   // 如果当前slot的时间戳变了,重置计数
   if s.times[slotIndex] != slotStartTime {
   	s.counts[slotIndex] = 0
   }

   // 记录请求
   s.times[slotIndex] = slotStartTime
   s.counts[slotIndex]++

   return true
}

func main() {
   limiter := NewSimpleSlidingWindow(1*time.Second, 100, 80)

   // 模拟测试
   start := time.Now()
   for i := 0; i < 1000; i++ {
   	elapsed := time.Since(start)
   	elapsedMs := elapsed.Milliseconds()

   	allowed := limiter.Allow()

   	if allowed {
   		//fmt.Printf("+%4dms: 请求%2d: ✓ 通过\n", elapsedMs, i+1)
   	} else {
   		fmt.Printf("+%4dms: 请求%2d: ✗ 拒绝\n", elapsedMs, i+1)
   	}

   	time.Sleep(10 * time.Millisecond)
   }

   // 重点测试:等待1秒后
   fmt.Println("\n等待1秒...")
   time.Sleep(1 * time.Second)

   fmt.Println("继续测试(应该重新开始计数)")
   for i := 0; i < 5; i++ {
   	allowed := limiter.Allow()
   	if allowed {
   		fmt.Printf("请求%2d: ✓ 通过\n", i+1)
   	} else {
   		fmt.Printf("请求%2d: ✗ 拒绝\n", i+1)

   	}
   	time.Sleep(100 * time.Millisecond)
   }
}

1.1.3 漏桶

以固定的速率处理请求,对于的流量只能在桶内等待。如果桶设计的不合理会导致延迟很大,不适合延迟敏感的业务。如果设计的很小就会有大量的请求被丢弃。

其实他的本质就是我们现在说的一个固定大小的队列,后面有个一个已固定速率进行消费的Consumer

go 复制代码
// 这个在工程上很少使用,面试算法也不怎么考
package main

import (
	"fmt"
	"sync"
	"time"
)

// LeakyBucketLimiter 漏桶限流器实现
type LeakyBucketLimiter struct {
	capacity float64   // 桶的容量
	rate     int       // 流出速率
	water    float64   // 桶里当前还有多少水,这个需要使用FLOAT64来保持精度,否则经常出现实际上漏出去水了,但是实际转换成int的时候并没有计算的情况
	lastLeak time.Time // 最后一次流出水的时间,主要用户计算通里还剩多少水
	mu       sync.Mutex
}

// NewLeakyBucketLimiter 创建漏桶限流器
func NewLeakyBucketLimiter(capacity float64, rate int) *LeakyBucketLimiter {
	return &LeakyBucketLimiter{
		capacity: capacity,
		rate:     rate,
		water:    0,
		lastLeak: time.Now(),
	}
}

// leak 漏水
func (l *LeakyBucketLimiter) leak() {
	now := time.Now()
	elapsed := now.Sub(l.lastLeak)

	if elapsed <= 0 {
		return
	}

	// 计算从上一次漏出到现在一共漏出来多少水
	leakAmount := elapsed.Seconds() * float64(l.rate)
	if leakAmount > 0 {
		l.water -= leakAmount
		if l.water < 0 {
			l.water = 0
		}
		l.lastLeak = now
	}
}

// Allow 尝试通过请求
func (l *LeakyBucketLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.leak()
	if l.water >= l.capacity {
		return false
	}

	l.water++
	return true
}

// Status 获取当前状态
func (l *LeakyBucketLimiter) Status() (current, capacity float64) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.leak()
	return l.water, l.capacity
}

// 示例使用
func main() {
	fmt.Println("=== 漏桶限流器演示 ===")

	// 创建限流器:容量5,速率2个/秒
	limiter := NewLeakyBucketLimiter(5.0, 2)

	fmt.Println("\n测试1:快速请求(每100ms一个)")
	for i := 0; i < 20; i++ {
		allowed := limiter.Allow()
		// 目前这里主要是 log 调试日志的需要
		water, capacity := limiter.Status()
		if allowed {
			fmt.Printf("请求%2d: ✅ 通过 (水量: %d/%d)\n", i+1, int(water), int(capacity))
		} else {
			fmt.Printf("请求%2d: ❌ 限流 (水量: %d/%d)\n", i+1, int(water), int(capacity))
		}
		time.Sleep(100 * time.Millisecond)
	}
}

1.1.4 令牌桶

其实go的源码包已经实现了令牌桶算法,源码位置在 golang.org/x/time/rate

  • 令牌桶当流量激增的时候,如果桶内有充裕的令牌,就能抗住并发(令牌桶对激增流量是非常友好的的)
  • 两种实现方式
  • 一个是定期自动填充的(异步新增令牌使得取令牌的逻辑更简单,缺点是异步增加令牌的频率会影响精度,令牌不是平滑增加的,二是有rate字段进行批量增加的<与方案二相比不够平滑>)
  • 惰性增加令牌(只有当流量进来的时候才会更新桶内的令牌数量,高并发的时候平滑的新增令牌)
go 复制代码
package main

import (
	"fmt"
	"sync"
	"time"
)

// TokenBucket 令牌桶结构
type TokenBucket struct {
	capacity   float64   // 桶容量(应该设计成int,但是会有很多的类型转换)
	rate       int       // 每秒生成 token 数
	tokens     float64   // 当前 token 数,高并发情景下,更新token的时候回损失精度,导致token加不上去
	lastFill   time.Time // 上次补充时间
	mu         sync.Mutex
	tickerStop chan struct{} // 停止补充 token
}

// NewTokenBucket 创建令牌桶
func NewTokenBucketWithAutoFill(capacity, rate int) *TokenBucket {
	tb := &TokenBucket{
		capacity:   float64(capacity),
		rate:       rate,
		tokens:     float64(capacity),
		lastFill:   time.Now(),
		tickerStop: make(chan struct{}),
	}

	// 启动定期补充令牌
	go tb.start()
	return tb
}

// NewTokenBucket 创建令牌桶
func NewTokenBucket(capacity, rate int) *TokenBucket {
	tb := &TokenBucket{
		capacity:   float64(capacity),
		rate:       rate,
		tokens:     float64(capacity),
		lastFill:   time.Now(),
		tickerStop: make(chan struct{}),
	}
	return tb
}

// start:每 100ms 补充令牌
func (tb *TokenBucket) start() {
	ticker := time.NewTicker(100 * time.Millisecond)
	for {
		select {
		case <-ticker.C:
			tb.fill()
		case <-tb.tickerStop:
			ticker.Stop()
			return
		}
	}
}

// fill:补充令牌
func (tb *TokenBucket) fill() {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	now := time.Now()
	elapsed := now.Sub(tb.lastFill).Seconds()
	tb.lastFill = now

	// 按经过的时间补充 token(平滑)
	add := elapsed * float64(tb.rate)
	if add > 0 {
		tb.tokens += add
		if tb.tokens > tb.capacity {
			tb.tokens = tb.capacity
		}
	}
	fmt.Println("tokens:", tb.tokens, "add:", add)
}

func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	now := time.Now()
	elapsed := now.Sub(tb.lastFill).Seconds()

	fmt.Println("token:", tb.tokens, "elapsed:", elapsed)

	// 根据时间补充令牌
	tb.tokens += elapsed * float64(tb.rate)
	if tb.tokens > tb.capacity {
		tb.tokens = tb.capacity
	}
	tb.lastFill = now

	// 判断是否还有 token
	if tb.tokens >= 1.0 {
		tb.tokens--
		return true
	}
	return false
}

// AllowWithAutoFill To Allow:请求是否允许
func (tb *TokenBucket) AllowWithAutoFill() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	if tb.tokens > 1 {
		tb.tokens--
		return true
	}
	return false
}

// Stop 停止补充令牌
func (tb *TokenBucket) Stop() {
	close(tb.tickerStop)
}

func main() {
	now := time.Now()
	var passCount = 0
	autoFile := true
	if autoFile {
		bucket := NewTokenBucketWithAutoFill(5, 5) // 容量 5,每秒补 5 个 token

		for i := 0; i < 1000; i++ {
			if bucket.AllowWithAutoFill() {
				passCount++
				fmt.Println(i+1, "passed")
			} else {
				fmt.Println(i+1, "blocked")
			}
			time.Sleep(15 * time.Millisecond)
		}

		bucket.Stop()
		fmt.Println("done", passCount)
		fmt.Println("time :", time.Now().Sub(now).Seconds())
	} else {
		bucket := NewTokenBucket(5, 5)

		for i := 0; i < 100; i++ {
			if bucket.Allow() {
				passCount++
				fmt.Println(i+1, "passed")
			} else {
				fmt.Println(i+1, "blocked")
			}
			time.Sleep(150 * time.Millisecond)
		}
		fmt.Println("done", passCount)
		fmt.Println("time :", time.Now().Sub(now).Seconds())
	}
}
相关推荐
星川皆无恙2 小时前
从“盲人摸象“到“全面感知“:多模态学习的进化之路
大数据·人工智能·python·深度学习·学习
_Kayo_2 小时前
node.js 学习笔记4
笔记·学习·node.js
汉堡包0012 小时前
【网安基础】--Spring/Spring Boot RCE 解析与 Shiro 反序列化漏洞的关联(包括简易加密方式梳理)
学习·安全·spring·信息安全
小龙2 小时前
【学习笔记】PyTorch 中.pth文件格式解析与可视化
人工智能·pytorch·笔记·学习
Gavin在路上2 小时前
AI学习之AI应用框架选型篇
人工智能·学习
暖阳之下2 小时前
学习周报二十八
学习
我命由我123452 小时前
Photoshop - Photoshop 工具栏(47)油漆桶工具
学习·ui·职场和发展·求职招聘·职场发展·学习方法·photoshop
蒙奇D索大2 小时前
【数据结构】排序算法精讲 | 快速排序全解:分治思想、核心步骤与示例演示
数据结构·笔记·学习·考研·算法·排序算法·改行学it
iconball3 小时前
个人用云计算学习笔记 --29 华为云网络云服务
运维·笔记·学习·华为云·云计算