[AI编程从入门到入土] 装饰器decorator
个人导航
知乎:https://www.zhihu.com/people/byzh_rc
CSDN:https://blog.csdn.net/qq_54636039
注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码
参考文章:各方资料
文章目录
- [[AI编程从入门到入土] 装饰器decorator](#[AI编程从入门到入土] 装饰器decorator)
- 个人导航
- 装饰器decorator
-
-
-
- [1. registration decorator - 情况1](#1. registration decorator - 情况1)
- [2. registration decorator - 情况2](#2. registration decorator - 情况2)
-
-
- 注册表registry
- 装饰器工厂
- [decorator 分类](#decorator 分类)
- [标准 decorator 模板](#标准 decorator 模板)
-
-
-
- [1. 类注册decorator](#1. 类注册decorator)
- [2. 计时器decorator](#2. 计时器decorator)
- [3. 参数检查decorator](#3. 参数检查decorator)
- [4. 类decorator](#4. 类decorator)
-
-
- AI训练常用decorator
装饰器decorator
- registration decorator: 只注册, 不改行为
- wrapper decorator: 添加行为 (使用wraps)
- ...
1. registration decorator - 情况1
py
@xxx
def test():
pass
等价于
py
def test():
pass
test = xxx(test)
test 函数先被创建 -> 然后作为参数传给 xxx -> xxx 返回一个新对象 -> 再覆盖原来的 test
此时xxx接收到的参数是func
2. registration decorator - 情况2
py
@xxx(abc)
def test():
pass
等价于
py
def test():
pass
test = xxx(abc)(test)
test 函数先被创建 -> 然后作为参数传给 xxx(abc) -> xxx(abc) 返回一个新对象 -> 再覆盖原来的 test
此时xxx接收到的参数是abc
注册表registry
注册表本质是字典, 装饰器的第一个参数是func:
py
# 创建注册表
registry = {}
# 创建装饰器
def register(func):
registry[func.__name__] = func
return func
使用:
py
@register
def hello():
print("hello")
@register
def world():
print("world")
执行后, registry就变成:
py
{
"hello": <function hello>,
"world": <function world>,
}
就可以如此调用:
py
registry["hello"]()
装饰器工厂
py
registry = {}
def register(name):
def decorator(func):
registry[name] = func
return func
return decorator
@register("add")
def func1():
print("111")
# 等价于: func1 = register("add")(func1)
@register("sub")
def func2():
print("222")
# 等价于: func2 = register("sub")(func2)
decorator 分类
| 类型 | 作用 |
|---|---|
| registration | 注册 |
| wrapper | 包装行为 |
| cache | 缓存 |
| retry | 重试 |
| permission | 权限 |
| validation | 参数校验 |
| singleton | 单例 |
| async | 异步 |
| transaction | 事务 |
| logging | 日志 |
| injection | 依赖注入 |
标准 decorator 模板
py
from functools import wraps
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# before
result = func(*args, **kwargs)
# after
return result
return wrapper
1. 类注册decorator
py
MODELS = {}
def register(name):
def decorator(cls):
MODELS[name] = cls
return cls
return decorator
@register("resnet")
class resnet:
...
@register("lstm")
class lstm:
...
2. 计时器decorator
py
import time
from functools import wraps
def timer(func):
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
print(f"{func.__name__} 耗时: {end - start:.4f}s")
return result
return wrapper
3. 参数检查decorator
py
from functools import wraps
def check_non_negative(func):
@wraps(func)
def wrapper(x):
if x < 0:
raise ValueError("不能小于0")
return func(x)
return wrapper
4. 类decorator
py
from functools import wraps
def enhance(cls):
# 动态增加类属性
cls.version = "1.0"
cls.author = "byzh"
cls.category = "AI"
# 动态增加实例方法
def info(self):
print("========== INFO ==========")
print("class :", cls.__name__)
print("name :", self.name)
print("version :", cls.version)
print("author :", cls.author)
print("category :", cls.category)
cls.info = info
return cls
AI训练常用decorator
py
def benchmark(func):
@wraps(func)
def wrapper(*args, **kwargs):
import tracemalloc
import time
tracemalloc.start()
start = time.time()
result = func(*args, **kwargs)
current, peak = tracemalloc.get_traced_memory()
end = time.time()
print(f"time={end-start}")
print(f"peak={peak/1024/1024:.2f}MB")
return result
return wrapper