代码
go
// Copyright 2021 ByteDance Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package asynccache
import (
"fmt"
"log"
"sync"
"sync/atomic"
"time"
sf "golang.org/x/sync/singleflight"
)
// Options controls the behavior of AsyncCache.
type Options struct {
// RefreshDuration is the period of the shared refresh ticker. On every tick
// Fetcher is called again for each key currently in the cache.
RefreshDuration time.Duration
// Fetcher loads the value for key. It is called on a cache miss and on
// every refresh tick.
Fetcher func(key string) (interface{}, error)
// If EnableExpire is true, ExpireDuration MUST be set.
EnableExpire bool
// ExpireDuration is the period of the shared expire ticker. An entry that
// is not touched between two consecutive expire ticks is deleted.
ExpireDuration time.Duration
// ErrorHandler, if set, is invoked in a new goroutine whenever a refresh
// fetch returns an error.
ErrorHandler func(key string, err error)
// ChangeHandler, if set, is invoked in a new goroutine when a refresh
// produced a value for which IsSame reported false. It is never invoked
// when IsSame is nil.
ChangeHandler func(key string, oldData, newData interface{})
// DeleteHandler, if set, is invoked in a new goroutine when an entry is
// removed by DeleteIf or by expiration.
DeleteHandler func(key string, oldData interface{})
// IsSame reports whether the refreshed value equals the cached one. It is
// only consulted during refresh, to decide whether to call ChangeHandler.
IsSame func(key string, oldData, newData interface{}) bool
// ErrLogFunc receives internal error messages. NewAsyncCache replaces a nil
// ErrLogFunc with one that calls log.Println.
ErrLogFunc func(str string)
}
// AsyncCache is a key/value cache whose entries are loaded through
// Options.Fetcher, periodically re-fetched in the background and, when
// Options.EnableExpire is set, evicted once they stop being accessed.
type AsyncCache interface {
// SetDefault sets the default value of given key if it is new to the cache.
// It is useful for cache warming up.
// Param val should not be nil.
// The result reports whether the key was already cached.
SetDefault(key string, val interface{}) (exist bool)
// Get tries to fetch a value corresponding to the given key from the cache.
// If an error occurs during the first fetch, the error is cached and
// returned until a later fetch triggered by the refresh goroutine succeeds.
Get(key string) (val interface{}, err error)
// GetOrSet tries to fetch a value corresponding to the given key from the cache.
// If the cached entry holds an error, or the key is not cached yet and
// fetching it fails, the default value is stored and returned instead.
GetOrSet(key string, defaultVal interface{}) (val interface{})
// Dump dumps all cache entries.
// Dumping does not touch the entries, so it does not postpone their expiration.
Dump() map[string]interface{}
// DeleteIf deletes cached entries that match the `shouldDelete` predicate.
DeleteIf(shouldDelete func(key string) bool)
// Close closes the async cache.
// It should be called when the cache is no longer needed; otherwise the
// background resources may leak.
Close()
}
// asyncCache is the AsyncCache implementation returned by NewAsyncCache.
type asyncCache struct {
// sfg collapses concurrent fetches of the same missing key into one Fetcher call.
sfg sf.Group
// opt holds the options the cache was created with.
opt Options
// data maps a string key to its *entry.
data sync.Map
}
// tickerType tells a sharedTicker goroutine which operation to run on each tick.
type tickerType int
const (
// refreshTicker selects asyncCache.refresh.
refreshTicker tickerType = iota
// expireTicker selects asyncCache.expire.
expireTicker
)
// sharedTicker drives the refresh or expire operation for every cache that
// uses the same duration, so those caches share one ticker and one goroutine.
type sharedTicker struct {
// The embedded mutex guards started and caches; the tick goroutine also
// holds it while a tick is being processed.
sync.Mutex
// started reports whether a tick goroutine has been launched and not yet asked to stop.
started bool
// stopChan carries the stop request from Close to the tick goroutine.
stopChan chan bool
// ticker is the ticker created for the currently running tick goroutine.
ticker *time.Ticker
// caches is the set of caches served by this ticker.
caches map[*asyncCache]struct{}
}
var (
// Shared tickers. Each map is keyed by a time.Duration and holds the
// *sharedTicker serving every cache created with that refresh / expire duration.
refreshTickerMap, expireTickerMap sync.Map
)
// entry is the value stored in asyncCache.data for each key.
type entry struct {
// val holds the cached data; it is left empty when the data is nil.
val atomic.Value
expire int32 // 0: accessed since the last expire tick; 1: deleted on the next expire tick unless touched
// err holds the error of the last Store. NOTE(review): the Error type is
// declared outside this view; it is used here only through Load and Store.
err Error
}
// Store replaces both the cached data and the cached error of the entry.
//
// atomic.Value refuses to store nil, so a nil x resets val to an empty
// atomic.Value instead of going through val.Store.
// NOTE(review): that reset is a plain assignment, not an atomic operation, and
// atomic.Value also panics when x has a different concrete type than the
// previously stored value; confirm callers never hit either case concurrently.
func (e *entry) Store(x interface{}, err error) {
	if x == nil {
		e.val = atomic.Value{}
	} else {
		e.val.Store(x)
	}
	e.err.Store(err)
}
// Touch clears the expire mark, so the next expire tick keeps the entry
// instead of deleting it.
func (e *entry) Touch() {
	const inUse int32 = 0
	atomic.StoreInt32(&e.expire, inUse)
}
// NewAsyncCache creates an AsyncCache configured by opt and registers it with
// the package-level shared tickers: always with the refresh ticker for
// opt.RefreshDuration, and with the expire ticker for opt.ExpireDuration when
// opt.EnableExpire is set.
//
// A nil opt.ErrLogFunc is replaced by one that calls log.Println.
//
// It panics if opt.RefreshDuration is not positive, or if opt.EnableExpire is
// set and opt.ExpireDuration is not positive.
func NewAsyncCache(opt Options) AsyncCache {
	// Validate before touching any shared ticker: time.NewTicker panics on a
	// non-positive duration, and panicking while the shared ticker is locked
	// and already marked as started would leave it unusable for later caches.
	if opt.RefreshDuration <= 0 {
		panic("asynccache: invalid RefreshDuration")
	}
	if opt.EnableExpire && opt.ExpireDuration <= 0 {
		panic("asynccache: invalid ExpireDuration")
	}
	if opt.ErrLogFunc == nil {
		opt.ErrLogFunc = func(str string) {
			log.Println(str)
		}
	}

	c := &asyncCache{
		sfg: sf.Group{},
		opt: opt,
	}
	if opt.EnableExpire {
		attachToTicker(&expireTickerMap, opt.ExpireDuration, expireTicker, c)
	}
	attachToTicker(&refreshTickerMap, opt.RefreshDuration, refreshTicker, c)
	return c
}

// attachToTicker adds c to the shared ticker stored in tickers under d,
// creating the shared ticker and starting its goroutine when necessary.
func attachToTicker(tickers *sync.Map, d time.Duration, tt tickerType, c *asyncCache) {
	ti, _ := tickers.LoadOrStore(d, &sharedTicker{
		caches:   make(map[*asyncCache]struct{}),
		stopChan: make(chan bool, 1),
	})
	st := ti.(*sharedTicker)

	st.Lock()
	defer st.Unlock()
	st.caches[c] = struct{}{}
	if !st.started {
		st.started = true
		st.ticker = time.NewTicker(d)
		// The ticker is handed over as an argument so the goroutine never
		// reads st.ticker, which a later restart may overwrite.
		go st.tick(st.ticker, tt)
	}
}
// SetDefault stores val under key when the key is not cached yet.
// When the key is already cached, the existing entry is kept and only touched.
// The result reports whether the key was already cached.
func (c *asyncCache) SetDefault(key string, val interface{}) bool {
	fresh := new(entry)
	fresh.Store(val, nil)
	if prev, loaded := c.data.LoadOrStore(key, fresh); loaded {
		prev.(*entry).Touch()
		return true
	}
	return false
}
// Get returns the cached value and error for key, touching the entry on a hit.
// On a miss the value is loaded through Fetcher (concurrent misses for the same
// key share one call) and the result, including any error, is cached. A cached
// error is returned until a later refresh of the key succeeds.
func (c *asyncCache) Get(key string) (val interface{}, err error) {
	if cached, hit := c.data.Load(key); hit {
		e := cached.(*entry)
		e.Touch()
		return e.val.Load(), e.err.Load()
	}

	val, err, _ = c.sfg.Do(key, func() (interface{}, error) {
		fetched, fetchErr := c.opt.Fetcher(key)
		fresh := new(entry)
		fresh.Store(fetched, fetchErr)
		c.data.Store(key, fresh)
		return fetched, fetchErr
	})
	return val, err
}
// GetOrSet returns the cached value for key, falling back to def.
//
// On a hit without a cached error the entry is touched and its value returned.
// On a hit with a cached error the entry is replaced by def, which is returned.
// On a miss the value is loaded through Fetcher (concurrent misses for the same
// key share one call); if that fails, def is cached and returned instead.
func (c *asyncCache) GetOrSet(key string, def interface{}) (val interface{}) {
	cached, hit := c.data.Load(key)
	if !hit {
		val, _, _ = c.sfg.Do(key, func() (interface{}, error) {
			fetched, fetchErr := c.opt.Fetcher(key)
			if fetchErr != nil {
				fetched = def
			}
			fresh := new(entry)
			fresh.Store(fetched, nil)
			c.data.Store(key, fresh)
			return fetched, nil
		})
		return val
	}

	e := cached.(*entry)
	if e.err.Load() == nil {
		e.Touch()
		return e.val.Load()
	}

	// The cached entry carries an error: overwrite it with the default.
	replacement := new(entry)
	replacement.Store(def, nil)
	c.data.Store(key, replacement)
	return def
}
// Dump returns a snapshot of all cached entries as a key -> value map.
// Entries are not touched, so dumping does not postpone their expiration.
// A key that is not a string is reported through ErrLogFunc and removed.
func (c *asyncCache) Dump() map[string]interface{} {
	data := make(map[string]interface{})
	c.data.Range(func(key, val interface{}) bool {
		k, ok := key.(string)
		if !ok {
			// Log the offending key itself: after a failed assertion k is
			// just the empty string.
			c.opt.ErrLogFunc(fmt.Sprintf("invalid key: %v, type: %T is not string", key, key))
			c.data.Delete(key)
			return true
		}
		data[k] = val.(*entry).val.Load()
		return true
	})
	return data
}
// DeleteIf deletes every cached entry whose key satisfies shouldDelete.
// When Options.DeleteHandler is set, it is invoked in a new goroutine for each
// deleted entry with the key and the data that was cached for it.
func (c *asyncCache) DeleteIf(shouldDelete func(key string) bool) {
	c.data.Range(func(key, value interface{}) bool {
		s := key.(string)
		if !shouldDelete(s) {
			return true
		}
		if c.opt.DeleteHandler != nil {
			// Hand the handler the cached data, as ChangeHandler gets, rather
			// than the internal *entry wrapper. The data is read before the
			// goroutine starts so it reflects the state at deletion time.
			var oldData interface{} = value
			if e, ok := value.(*entry); ok {
				oldData = e.val.Load()
			}
			go c.opt.DeleteHandler(s, oldData)
		}
		c.data.Delete(key)
		return true
	})
}
// Close detaches the cache from its shared refresh ticker and, when
// EnableExpire is set, from its shared expire ticker. A shared ticker's
// goroutine is asked to stop once no cache uses it any more.
//
// Close does not clear the cached entries. Calling it more than once is safe.
func (c *asyncCache) Close() {
	detachFromTicker(&refreshTickerMap, c.opt.RefreshDuration, c)
	if c.opt.EnableExpire {
		detachFromTicker(&expireTickerMap, c.opt.ExpireDuration, c)
	}
}

// detachFromTicker removes c from the shared ticker stored in tickers under d
// and stops that ticker's goroutine when c was its last cache.
func detachFromTicker(tickers *sync.Map, d time.Duration, c *asyncCache) {
	ti, ok := tickers.Load(d)
	if !ok {
		return
	}
	st := ti.(*sharedTicker)

	st.Lock()
	delete(st.caches, c)
	// Only a running goroutine may be stopped. Without the started check a
	// repeated Close would queue a stale stop signal that kills the next
	// goroutine right away, or block forever once the channel buffer is full.
	stop := st.started && len(st.caches) == 0
	if stop {
		st.started = false
	}
	st.Unlock()

	// Send after releasing the lock: the tick goroutine holds the same lock
	// while it processes a tick, so blocking here with the lock held could
	// deadlock against it.
	if stop {
		st.stopChan <- true
	}
}
// tick is the body of the shared ticker goroutine. On every tick it runs, for
// each registered cache in its own goroutine, expire (tt == expireTicker) or
// refresh (otherwise), and waits for all of them while holding the lock. It
// returns, stopping ticker, once a true value arrives on stopChan.
//
// The ticker is passed in as an argument rather than read from t.ticker to
// avoid racing on that field.
func (t *sharedTicker) tick(ticker *time.Ticker, tt tickerType) {
	defer ticker.Stop()

	var pending sync.WaitGroup
	run := func(c *asyncCache) {
		defer pending.Done()
		switch tt {
		case expireTicker:
			c.expire()
		default:
			c.refresh()
		}
	}

	for {
		select {
		case stop := <-t.stopChan:
			if stop {
				return
			}
		case <-ticker.C:
			t.Lock()
			for c := range t.caches {
				pending.Add(1)
				go run(c)
			}
			pending.Wait()
			t.Unlock()
		}
	}
}
// expire runs one expiration pass over the cache.
//
// An entry whose expire mark is 0 (touched since the previous pass) only gets
// marked. An entry that is still marked from the previous pass is deleted, and
// Options.DeleteHandler, when set, is invoked in a new goroutine with the key
// and the data that was cached for it.
// Keys that are not strings and values that are not *entry are reported through
// ErrLogFunc and removed.
func (c *asyncCache) expire() {
	c.data.Range(func(key, value interface{}) bool {
		k, ok := key.(string)
		if !ok {
			// Log the offending key itself: after a failed assertion k is
			// just the empty string.
			c.opt.ErrLogFunc(fmt.Sprintf("invalid key: %v, type: %T is not string", key, key))
			c.data.Delete(key)
			return true
		}
		e, ok := value.(*entry)
		if !ok {
			c.opt.ErrLogFunc(fmt.Sprintf("invalid key: %v, type: %T is not entry", k, value))
			c.data.Delete(key)
			return true
		}
		if atomic.CompareAndSwapInt32(&e.expire, 0, 1) {
			// First pass since the last touch: mark only.
			return true
		}
		if c.opt.DeleteHandler != nil {
			// Hand the handler the cached data, as ChangeHandler gets, rather
			// than the internal *entry wrapper.
			go c.opt.DeleteHandler(k, e.val.Load())
		}
		c.data.Delete(key)
		return true
	})
}
// refresh runs one refresh pass: it calls Fetcher for every cached key.
//
// On failure ErrorHandler (when set) is invoked in a new goroutine. The cached
// error is replaced only if the entry already held one, so a previously good
// value keeps being served without an error.
// On success, when IsSame is set and reports a difference, ChangeHandler (when
// set) is invoked in a new goroutine; then the new value is stored and the
// cached error cleared.
// Keys that are not strings and values that are not *entry are reported through
// ErrLogFunc and removed.
func (c *asyncCache) refresh() {
	c.data.Range(func(key, value interface{}) bool {
		k, ok := key.(string)
		if !ok {
			// Log the offending key itself: after a failed assertion k is
			// just the empty string.
			c.opt.ErrLogFunc(fmt.Sprintf("invalid key: %v, type: %T is not string", key, key))
			c.data.Delete(key)
			return true
		}
		e, ok := value.(*entry)
		if !ok {
			c.opt.ErrLogFunc(fmt.Sprintf("invalid key: %v, type: %T is not entry", k, value))
			c.data.Delete(key)
			return true
		}

		newVal, err := c.opt.Fetcher(k)
		if err != nil {
			if c.opt.ErrorHandler != nil {
				go c.opt.ErrorHandler(k, err)
			}
			if e.err.Load() != nil {
				e.err.Store(err)
			}
			return true
		}

		if c.opt.IsSame != nil {
			oldVal := e.val.Load()
			if !c.opt.IsSame(k, oldVal, newVal) && c.opt.ChangeHandler != nil {
				go c.opt.ChangeHandler(k, oldVal, newVal)
			}
		}
		e.Store(newVal, nil)
		return true
	})
}
流程图
设置过期定时器 设置刷新定时器 是 否 是 有错误 无错误 否 是 否 是 开始 NewAsyncCache 初始化 AsyncCache 实例 设置过期定时器 设置刷新定时器 返回 AsyncCache 实例 SetDefault 设置默认值 为给定的键值对中的值设定默认值 如果在缓存池中就刷新键值对过期时间 Get 获取缓存 GetOrSet 获取或设置缓存 Dump 转储缓存 DeleteIf 删除缓存 关闭 AsyncCache 缓存存在? 加载缓存值 触发 Fetcher 获取值 singleflight 处理请求 存储或更新缓存 缓存存在? 检查缓存错误 设置默认值 加载缓存值 设置默认值 Range 遍历缓存 将缓存存到另一个图中 Range 遍历缓存 满足删除条件? 删除缓存项 EnableExpire? 停止过期定时器 停止刷新定时器 资源清理
其中,refreshTickerMap 和 expireTickerMap 分别以刷新时间/过期时间为键,存放每个时间间隔对应的 sharedTicker。
每个 sharedTicker 负责所有使用相同刷新时间/过期时间的缓存池的刷新/过期操作。
测试代码
go
package main
import (
"asynccache/asynccache"
"fmt"
"log"
"time"
)
// simpleFetcher simulates a slow data source: it logs the request, waits
// 500ms and returns "value_for_<key>" with a nil error.
func simpleFetcher(key string) (interface{}, error) {
	const latency = 500 * time.Millisecond // simulated fetch delay

	log.Printf("Fetching data for key: %s\n", key)
	time.Sleep(latency)

	value := fmt.Sprintf("value_for_%s", key)
	return value, nil
}
// showAllCacheData logs every key/value pair currently held by cache.
// Dump returns a map[string]interface{}, so values are printed with %+v.
func showAllCacheData(cache asynccache.AsyncCache) {
	for key, value := range cache.Dump() {
		log.Printf("Fetching data for key: %s, value: %+v", key, value)
	}
}
// main walks through the AsyncCache API: warm-up, Get, GetOrSet, waiting for
// the background refresh/expire ticks, DeleteIf and Close.
func main() {
	opts := asynccache.Options{
		RefreshDuration: 2 * time.Second, // refresh tick every 2 seconds
		Fetcher:         simpleFetcher,
		EnableExpire:    true,
		ExpireDuration:  5 * time.Second, // expire tick every 5 seconds
		ErrorHandler: func(key string, err error) {
			log.Printf("Error fetching key %s: %v\n", key, err)
		},
		ChangeHandler: func(key string, oldData, newData interface{}) {
			log.Printf("Key %s changed from %v to %v\n", key, oldData, newData)
		},
		DeleteHandler: func(key string, oldData interface{}) {
			log.Printf("Key %s expired with value %v\n", key, oldData)
		},
	}
	cache := asynccache.NewAsyncCache(opts)

	// Warm the cache with a default value and look at its content.
	cache.SetDefault("key1", "default_value_for_key1")
	showAllCacheData(cache)

	// Read a cached key.
	val, err := cache.Get("key1")
	if err == nil {
		log.Printf("Got value for key1: %v\n", val)
	} else {
		log.Printf("Error getting key1: %v\n", err)
	}

	// Read a key that is not cached yet, with a fallback value.
	val = cache.GetOrSet("key2", "default_value_for_key2")
	log.Printf("Got value for key2: %v\n", val)

	// Give the refresh and expire tickers time to fire.
	time.Sleep(6 * time.Second)

	// Read key1 again after the background work.
	val, err = cache.Get("key1")
	if err == nil {
		log.Printf("Got value for key1 after refresh: %v\n", val)
	} else {
		log.Printf("Error getting key1 after refresh: %v\n", err)
	}

	// Remove one specific entry.
	cache.DeleteIf(func(key string) bool {
		return key == "key2"
	})

	// Shut the cache down.
	cache.Close()

	// Read once more after Close. Judging by the asynccache code, Close only
	// detaches the cache from the shared tickers and does not clear the stored
	// entries, so this Get is expected to still return the cached value.
	val, err = cache.Get("key1")
	if err == nil {
		log.Printf("Got value for key1 after close: %v\n", val)
	} else {
		log.Printf("Error getting key1 after close: %v\n", err)
	}
}