目录
- [1. Optimizer类](#1. Optimizer类)
- [2. Optimizer概览](#2. Optimizer概览)
- [3. 源码解析](#3. 源码解析)
-
- [3.1 构造方法](#3.1 构造方法)
-
- [3.1.1 全局设置情形](#3.1.1 全局设置情形)
- [3.1.2 局部设置情形](#3.1.2 局部设置情形)
- [3.1.3 覆盖测试](#3.1.3 覆盖测试)
- [3.1.4 逐行讲解](#3.1.4 逐行讲解)
- [3.2 add_param_group](#3.2 add_param_group)
- [3.3 step](#3.3 step)
- [3.4 zero_grad](#3.4 zero_grad)
- [3.5 self.state](#3.5 self.state)
- [3.6 state_dict](#3.6 state_dict)
- [3.7 load_state_dict](#3.7 load_state_dict)
- [4. SGD Optimizer](#4. SGD Optimizer)
- [5. 极简版Optimizer源码](#5. 极简版Optimizer源码)
- [6. 自定义你的Optimizer](#6. 自定义你的Optimizer)
- Ref
1. Optimizer类
PyTorch的 Optimizer
类是深度学习模型中用于管理和更新模型参数的基类。它负责根据损失函数的梯度信息调整模型的参数,使模型逐步逼近最佳状态。Optimizer
类通过实现一些核心方法,如 step()
,来执行参数更新过程,而 zero_grad()
方法则用于清除模型中所有参数的梯度。
每个优化器会存储参数组和相关的状态,例如学习率、动量等。不同的优化器(如SGD、Adam等)继承自 Optimizer
类,并根据各自的算法特点实现了不同的参数更新策略。此外,Optimizer
类还允许用户在初始化时指定超参数,如学习率等,这些超参数会影响参数的更新方式。
本文将详细讲解 Optimizer
类的源码,并以SGD优化器为例介绍如何自定义一个自己的优化器。
2. Optimizer概览
📝 本文在讲解源码时,只考虑源码的简化版本,而不考虑完整的源码。
除了构造方法外,Optimizer常用的几个方法如下:
python
class Optimizer:
def state_dict(self) -> Dict[str, Any]:
"""返回优化器的状态字典,保存当前的优化器状态以便之后恢复。"""
...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""加载先前保存的状态字典,恢复优化器的状态。"""
...
def zero_grad(self, set_to_none: bool = True) -> None:
"""将优化器中所有参数的梯度清零,通常在每次反向传播前调用。"""
...
def step(self) -> None:
"""执行一步优化更新,用于根据梯度更新模型参数。"""
raise NotImplementedError
def add_param_group(self, param_group: Dict[str, Any]) -> None:
"""向优化器中添加新的参数组,用于管理不同的参数组(如不同学习率等)。"""
...
在自定义优化器时,必须继承 Optimizer
类,并实现 step()
方法,否则将会报错。
optimizer.py
文件中还定义了若干类型别名,如下:
python
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
Args
:通常用来表示函数中不定数量的位置参数。Kwargs
:通常用来表示函数中不定数量的关键字参数,键为参数名,值为相应的参数值。StateDict
:通常用于保存模型参数和优化器的状态信息。ParamsT
:是一个torch.Tensor
的可迭代对象,或者是包含键值对的字典的可迭代对象。
3. 源码解析
3.1 构造方法
Optimizer
的构造方法如下:
python
class Optimizer:
r"""Base class for all optimizers.
.. warning::
Parameters need to be specified as collections that have a deterministic
ordering that is consistent between runs. Examples of objects that don't
satisfy those properties are sets and iterators over values of dictionaries.
Args:
params (iterable): an iterable of :class:`torch.Tensor` s or
:class:`dict` s. Specifies what Tensors should be optimized.
defaults: (dict): a dict containing default values of optimization
options (used when a parameter group doesn't specify them).
"""
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
self.defaults = defaults
if isinstance(params, torch.Tensor):
raise TypeError(
"params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " + torch.typename(params)
)
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: List[Dict[str, Any]] = []
param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
for param_group in param_groups:
self.add_param_group(param_group)
从源码可以看出,Optimizer
有两个形参和两个属性 (self.defaults
不算)。
形参:
params
:由一系列Tensor组成的迭代器或是由一系列字典组成的迭代器。通常是模型的参数。例如model.parameters()
。defaults
:一个键为字符串的字典。通常是和优化算法相关的全局 超参数。例如lr
、momentum
等。下文会解释为什么这里是「全局」。
属性:
state
:一个键为Tensor,值为字典的字典。用来存储每个模型参数对应的临时状态,例如momentum
。param_groups
:一个列表,其中的每一个元素都是一个键为字符串的字典。每一个元素对应了一个param_group
。
看到这里,可能你仍然不明白 param_groups
是什么。既然它是复数形式,说明它是由一个个 param_group
组成的,每一个 param_group
的类型是 Dict[str, Any]
,因此 param_groups
的类型就是 List[Dict[str, Any]]
。
那么什么是 param_group
呢?我们知道Transformer模型通常由多个layer堆叠而成,绝大部分情况下,整个模型的训练会采用同一个学习率。但某些特殊场景下,我们可能希望不同的layer使用不同的学习率,此时就会涉及到一个个 param_group
了。
- 对于前者,形参
params
的类型为Iterable[Tensor]
,因为整个模型会共享同一套优化器参数,所以只需要指定全局优化器参数defaults
即可,此时param_groups
是一个长度为1的列表。 - 而对于后者,
params
的类型为Iterable[Dict[str, Any]]
,每一个字典包含了layer的参数和对应的局部优化器参数 。注意此时依然可以指定全局优化器参数,例如我们希望不同的layer使用不同的学习率,但希望所有的layer都使用同一个动量。此时param_groups
是一个长度大于1的列表。
我们来看更具体的例子。
3.1.1 全局设置情形
python
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 3, bias=False)
self.fc2 = nn.Linear(3, 3, bias=False)
self.fc3 = nn.Linear(3, 1, bias=False)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = MLP()
# =============================
# 情况 1: 所有层使用同一套优化器参数
# =============================
# 此时,params 是 Iterable[Tensor],直接传递模型的所有参数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
print(optimizer.param_groups)
输出(做了简化处理):
python
[
{
'params': [fc1_tensor, fc2_tensor, fc3_tensor],
'lr': 0.01,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
}
]
可以看到此时 param_groups
是一个长度为1的列表。字典中的 params
对应了模型的参数,因为没有设置 bias
,所以总共有三个张量。字典中剩余的键对应了优化器的超参数。
3.1.2 局部设置情形
python
# ===============================
# 情况 2: 不同的层使用不同的优化器参数
# ===============================
# 此时,params 是 Iterable[Dict[str, Any]],可以为不同的层设置不同的学习率
optimizer = optim.SGD([
{'params': model.fc1.parameters(), 'lr': 0.001}, # 第1层学习率为 0.001
{'params': model.fc2.parameters(), 'lr': 0.01}, # 第2层学习率为 0.01
{'params': model.fc3.parameters(), 'lr': 0.1} # 第3层学习率为 0.1
])
print(optimizer.param_groups)
输出(做了简化处理):
python
[
{
'params': [fc1_tensor],
'lr': 0.001,
'momentum': 0,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
},
{
'params': [fc2_tensor],
'lr': 0.01,
'momentum': 0,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
},
{
'params': [fc3_tensor],
'lr': 0.1,
'momentum': 0,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
}
]
此时 param_groups
是一个长度为3的列表。列表中的每一个字典就是一个 param_group
,存储了相应的layer参数和局部优化器参数。
⚠️ 这里使用"局部优化器参数"这一术语并不准确。事实上,"局部"和"全局"的概念仅在优化器实例化之前存在。一旦优化器实例化,全局参数将会逐步写入到每一个
param_group
中,例如momentum
和dampening
等未显式定义的参数,实质上属于全局优化器参数,并会通过add_param_group
方法自动地写入到每一个param_group
中。因此,在优化器实例化后,每个param_group
都拥有一套完整且独立的参数配置。
到目前为止,我们可以做一个简单总结。param_groups
是一个元素为字典的列表。当传入的 params
为由Tensor构成的迭代器时,此时 param_groups
的长度为1,即只含有一个 param_group
。当传入的 params
为由字典构成的迭代器时,此时 param_groups
的长度为 len(params)
。
param_groups
中的所有字典的键完全相同,均形如:
python
param_group = {
'params': [tensor_1, tensor_2, ...], # 待优化的模型参数
**defaults # 全局优化器参数
}
但所有字典的值却不尽相同。
3.1.3 覆盖测试
之前我们只考虑了「仅全局」和「仅局部」的情形,如果我们手动设置全局优化器参数,并且它和某些 param_group
中的局部优化器参数冲突了,那么这个全局的会覆盖掉局部的吗?
python
optimizer = optim.SGD([
{'params': model.fc1.parameters(), 'lr': 0.001, 'momentum': 0.3},
{'params': model.fc2.parameters(), 'lr': 0.01},
{'params': model.fc3.parameters(), 'lr': 0.1, 'nesterov': True}
], momentum=0.9, nesterov=False)
print(optimizer.param_groups)
输出:
python
[
{
'params': [fc1_tensor],
'lr': 0.001,
'momentum': 0.3,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
},
{
'params': [fc2_tensor],
'lr': 0.01,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
},
{
'params': [fc3_tensor],
'lr': 0.1,
'nesterov': True,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None
}
]
由此可以得出结论:全局优化器参数不会覆盖掉局部优化器参数。
3.1.4 逐行讲解
现在我们已经对 params
、defaults
、param_groups
这三个变量有了足够的了解(state
会放在下文讲解),接下来我们逐行剖析构造方法。
python
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
# 将defaults设置成属性以方便在后续的add_param_group方法中使用
self.defaults = defaults
# params必须是关于tensor或dict的可迭代对象
if isinstance(params, torch.Tensor):
raise TypeError(
"params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " + torch.typename(params)
)
# 初始化两大重要属性:state和param_groups
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: List[Dict[str, Any]] = []
# 将迭代器转化成列表,此时要么是List[Tensor]要么是List[Dict]
param_groups = list(params)
# 必须非空
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
# 如果不是List[Dict],说明进行的是全局设置,此时param_groups是List[Tensor]
# 类型,代表模型的所有参数。然后将其转变成List[Dict]类型,以达到格式统一的目的
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
# 将每一个param_group经过相关处理后添加到self.param_groups中
# 这里也会将全局优化器参数defaults注入到每一个param_group中
for param_group in param_groups:
self.add_param_group(param_group)
通常来讲,params
会接收 model.parameters()
作为输入,如果不是,那么 params
必须是由字典构成的列表,且每一个字典必须含有 params
这个键,对应的值是模型的部分参数。
绝大多数情况下,我们认为下式成立:
model.parameters() = ⋃ i = 1 k param_groups [ i ] [ " params " ] \text{model.parameters()}=\bigcup_{i=1}^k \text{param\_groups}[i]["\text{params}"] model.parameters()=i=1⋃kparam_groups[i]["params"]
k k k 是 param_group
的个数,且 param_groups[i]["params"]
两两互不相交。
如果涉及到冻结模型的一部分参数,仅训练剩余的参数,那么上式就不再成立了。
3.2 add_param_group
有了3.1小节的基础后,这里直接逐行讲解源码。
python
def add_param_group(self, param_group: Dict[str, Any]) -> None:
r"""Add a param group to the :class:`Optimizer`'s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.
Args:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options.
"""
# 确保传入的param_group一定是一个字典
if not isinstance(param_group, dict):
raise TypeError(f"param_group must be a dict, but got {type(param_group)}")
# 获取该group中的模型参数部分,然后将其转变为List[Tensor]类型
params = param_group["params"]
if isinstance(params, torch.Tensor):
param_group["params"] = [params]
elif isinstance(params, set):
raise TypeError(
"optimizer parameters need to be organized in ordered collections, but "
"the ordering of tensors in sets will change between runs. Please use a list instead."
)
else:
param_group["params"] = list(params)
# 参数检查
for param in param_group["params"]:
# 确保所有的参数必须都是Tensor,否则无法优化
if not isinstance(param, torch.Tensor):
raise TypeError(
"optimizer can only optimize Tensors, "
"but one of the params is " + torch.typename(param)
)
# 确保所有的参数都是叶子节点
if not (param.is_leaf or param.retains_grad):
raise ValueError("can't optimize a non-leaf Tensor")
# 将全局优化器参数注入到当前的group中,setdefault保证了这一过程不会覆盖掉局部优化器参数
for name, default in self.defaults.items():
param_group.setdefault(name, default)
# 检查是否存在重复的参数
# 目前出现重复的参数并不会报错
params = param_group["params"]
if len(params) != len(set(params)):
warnings.warn(
"optimizer contains a parameter group with duplicate parameters; "
"in future, this will cause an error; "
"see github.com/pytorch/pytorch/issues/40967 for more information",
stacklevel=3,
)
# 判断当前的group中是否有参数和已经添加过的group中的参数重复
# 两个集合交集为空,isdisjoint()返回True
param_set: Set[torch.Tensor] = set()
for group in self.param_groups:
param_set.update(set(group["params"]))
if not param_set.isdisjoint(set(param_group["params"])):
raise ValueError("some parameters appear in more than one parameter group")
# 将当前的group添加到self.param_groups中
self.param_groups.append(param_group)
add_param_group
源码看似复杂,但归根结底也就那么几行代码是真正起到作用的,这里给出一个简化版本:
python
def add_param_group(self, param_group: Dict[str, Any]) -> None:
params = param_group["params"]
param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)
for name, default in self.defaults.items():
param_group.setdefault(name, default)
self.param_groups.append(param_group)
3.3 step
Optimizer
类并没有实现 step()
方法:
python
def step(self) -> None:
"""Performs a single optimization step (parameter update)."""
raise NotImplementedError
自定义优化器时,需要继承 Optimizer
类,并实现该方法,否则会报错。
因为 step()
不返回任何值,所以需要实现模型参数的原地更新。
3.4 zero_grad
如下是简化版的源码:
python
def zero_grad(self, set_to_none: bool = True) -> None:
r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
Args:
set_to_none (bool): instead of setting to zero, set the grads to None.
This will in general have lower memory footprint, and can modestly improve performance.
However, it changes certain behaviors. For example:
1. When the user tries to access a gradient and perform manual ops on it,
a None attribute or a Tensor full of 0s will behave differently.
2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
are guaranteed to be None for params that did not receive a gradient.
3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
(in one case it does the step with a gradient of 0 and in the other it skips
the step altogether).
"""
# 遍历每一个group,然后遍历group中的每一个参数
for group in self.param_groups:
for p in group["params"]:
# 如果p的梯度不为None,说明需要清空
if p.grad is not None:
if set_to_none:
p.grad = None
else:
# grad_fn 表示该张量是由某个操作生成的
# 因此将该张量从计算图中分离出来
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
# 不在计算图中,原地关闭梯度以防止后续追踪
p.grad.requires_grad_(False)
# 清零梯度
p.grad.zero_()
从注释可以看出,当梯度被设置为 None
时,梯度张量会被释放,从而减少内存占用。而如果梯度被清零(即将其所有元素设置为 0),梯度张量的内存仍然会被保留。
由于 set_to_none
默认为 True
,因此 zero_grad
源码可以进一步简化:
python
def zero_grad(self) -> None:
for group in self.param_groups:
for p in group["params"]:
if p.grad is not None:
p.grad = None
3.5 self.state
Optimizer
中有两大重要属性:state
、param_groups
。先前我们已经了解了 param_groups
,现在来看 state
。
我们已经知道构造方法中会通过多次调用 add_param_group
来初始化 self.param_groups
,但截止目前,似乎并没有方法能够初始化 self.state
,那它是怎么初始化的?以及它到底"长什么样"呢?
self.state
用来存储与模型参数相关的临时状态 。对于SGD with momentum而言,每个参数都需要维护一个动量。对于Adam而言,每个参数不仅需要维护一个动量(一阶矩),还需要维护一个平方梯度(二阶矩)。很显然,对于一个待优化的Tensor,它的临时状态和它的形状是相同的 ,并且对于该Tensor,可能有多个临时状态需要维护,每个临时状态都有一个自己的名字,由此推测 self.state
的类型应当是 Dict[Tensor, Dict[str, Tensor]]
,这与源码中声明的相同:
python
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
即 self.state
的键是一个Tensor,值是一个字典。字典存储了优化该Tensor的一些临时状态,键是临时状态的名称,值是相应的状态。
以SGD优化器为例,进行一次单步更新,然后查看它的 state
属性:
python
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(2, 3, bias=False)
self.fc2 = nn.Linear(3, 1, bias=False)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleMLP()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
inputs = torch.randn(2)
target = torch.randn(1)
output = model(inputs)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(optimizer.state)
输出:
python
defaultdict(<class 'dict'>, {
Parameter containing:
tensor([[-0.4128, -0.6015],
[-0.3628, 0.6162],
[ 0.4367, 0.3409]], requires_grad=True): {
'momentum_buffer': tensor([[ 0.0000, 0.0000],
[-0.0052, 0.0581],
[-0.0130, 0.1446]])
},
Parameter containing:
tensor([[0.5394, 0.1077, 0.2748]], requires_grad=True): {
'momentum_buffer': tensor([[0.0000, 0.3398, 0.1585]])
}
})
可以得知,对于SGD优化器而言,每个待优化的Tensor只需要维护一个临时状态:momentum_buffer
,即当前的动量。且临时状态的形状与待优化的Tensor相同。
将SGD换成Adam,再来看看结果:
python
defaultdict(<class 'dict'>, {
Parameter containing:
tensor([[ 0.4954, -0.0392],
[ 0.0778, -0.5769],
[-0.3332, -0.0659]], requires_grad=True): {
'step': tensor(1.),
'exp_avg': tensor([[-0.1632, 0.0252],
[-0.0513, 0.0079],
[ 0.0000, 0.0000]]),
'exp_avg_sq': tensor([[2.6648e-03, 6.3460e-05],
[2.6274e-04, 6.2572e-06],
[0.0000e+00, 0.0000e+00]])
},
Parameter containing:
tensor([[-0.4492, -0.1479, 0.4789]], requires_grad=True): {
'step': tensor(1.),
'exp_avg': tensor([[0.1821, 0.0577, 0.0000]]),
'exp_avg_sq': tensor([[0.0033, 0.0003, 0.0000]])
}
})
此时每个Tensor需要维护三个临时状态:step
、exp_avg
和 exp_avg_sq
。step
是当前更新的步数,exp_avg
就是SGD中的动量(不完全相同),exp_avg_sq
是平方梯度。
知道了 self.state
长什么样后,我们需要了解一下它是如何初始化的。
事实上 Optimizer
类并没有实现 self.state
的初始化,因为它的初始化是在 step()
中完成的,所以我们需要关注 Optimizer
的子类,这里以SGD为例。
python
class SGD(Optimizer):
def step(self):
for group in self.param_groups:
# 用来存储模型参数,梯度,动量
params: List[Tensor] = []
grads: List[Tensor] = []
momentum_buffer_list: List[Optional[Tensor]] = []
# 填充params、grads、momentum_buffer_list
self._init_group(group, params, grads, momentum_buffer_list)
# 执行sgd优化算法
sgd(params, grads, momentum_buffer_list, ...)
if group["momentum"] != 0:
for p, momentum_buffer in zip(params, momentum_buffer_list):
# 获取Tensor
state = self.state[p]
# 更新Tensor的临时状态
state["momentum_buffer"] = momentum_buffer
def _init_group(self, group, params, grads, momentum_buffer_list):
for p in group["params"]:
if p.grad is not None:
params.append(p)
grads.append(p.grad)
if group["momentum"] != 0:
# 因为self.state是defaultdict,所以初始时state会自动创建为一个字典
state = self.state[p]
# 因为初始时state没有momentum_buffer这个键
# 所以momentum_buffer_list的初始值为[None, None, None, ...]
momentum_buffer_list.append(state.get("momentum_buffer"))
每一步更新,_init_group
会在SGD算法执行前被调用,它用来获取模型所有待更新的参数,对应的已经计算的梯度,以及对应的上一时刻的动量。self.state
会在 _init_group
中进行初始化 。我们可以通过在 self._init_group
和 sgd()
这两个语句之间加入以下代码来查看相应的信息:
python
print(self.state)
print(momentum_buffer_list)
输出:
python
defaultdict(<class 'dict'>, {
Parameter containing:
tensor([[ 0.4679, -0.6531],
[-0.4707, -0.2854],
[ 0.6846, -0.6576]], requires_grad=True): {},
Parameter containing:
tensor([[-0.4362, 0.4155, 0.1798]], requires_grad=True): {}
})
[None, None]
这说明初始时,每一个Tensor对应的临时状态为空字典,且 momentum_buffer_list
的初始值为 [None, None, ...]
。
3.6 state_dict
state_dict
的作用是保存优化器当前的状态,以便在之后的训练中恢复或继续使用。
显然 state_dict
应当保存优化器的两大重要属性:
python
state_dict = {
"state": state,
"param_groups": param_groups,
}
但根据之前的分析,state
中的键会涉及到模型参数,此外,param_groups
中的 params
也会涉及到模型参数,如果就这样直接保存,相当于我们在保存了优化器的同时还保存了两份模型参数(state一份,param_groups一份),这显然是不可行的。
一种直观的想法是将这些模型参数映射成唯一的数字ID,而这些ID是几乎不占空间的。假设 param_groups
含有 k k k 个 param_group
,那么我们可以从第一个group开始,从0开始从前往后依次编号直至第 k k k 个group。
当然,我们还需要对 state
进行编号,由于 state
和 param_groups
中的参数不一定一一对应(因为 param_groups
可能会含有重复的参数,但 state
不会,所以二者长度不一定相等,具体见 add_param_group
源码),因此我们不能从0开始从前往后依次编号。这启发我们可以构造一个参数的内存地址到ID的映射 mapping
,这样在编号 state
的时候,我们就可以通过 mapping[id(tensor)]
来获取模型参数的ID了。
源码解析(已做简化):
python
def state_dict(self) -> Dict[str, Any]:
# 构建模型参数地址到ID的映射
param_mappings: Dict[int, int] = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
# 处理优化器超参数部分
packed = {k: v for k, v in group.items() if k != "params"}
# 更新映射
param_mappings.update(
{
id(p): i
for i, p in enumerate(group["params"], start_index)
if id(p) not in param_mappings
}
)
# 处理模型参数部分,将具体的参数映射为ID
packed["params"] = [param_mappings[id(p)] for p in group["params"]]
start_index += len(packed["params"])
return packed
# 将所有group中的所有模型参数映射为ID
param_groups = [pack_group(g) for g in self.param_groups]
# 将state中的所有模型参数映射为ID
packed_state = {
param_mappings[id(k)]: v
for k, v in self.state.items()
}
state_dict = {
"state": packed_state,
"param_groups": param_groups,
}
return state_dict
我们可以通过以下代码来查看 state_dict
的样子:
python
import torch
import torch.nn as nn
import torch.optim as optim
import itertools
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 3, bias=False)
self.fc2 = nn.Linear(3, 3, bias=False)
self.fc3 = nn.Linear(3, 1, bias=False)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = MLP()
optimizer = optim.SGD([
{"params": itertools.chain(model.fc1.parameters(), model.fc2.parameters()), "lr": 0.01},
{"params": model.fc3.parameters(), "lr": 0.1},
], momentum=0.9)
x = torch.randn(1, 2)
y = torch.randn(1, 1)
criterion = nn.MSELoss()
output = model(x)
loss = criterion(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(optimizer.state_dict().keys())
print(optimizer.state_dict()['state'])
print(optimizer.state_dict()['param_groups'])
输出:
python
dict_keys(['state', 'param_groups'])
# state部分
{
0: {
'momentum_buffer': tensor([[-0.0270, 0.1874],
[ 0.0000, 0.0000],
[ 0.0000, 0.0000]])
},
1: {
'momentum_buffer': tensor([[ 0.0000, 0.0000, 0.0000],
[-0.2232, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]])
},
2: {
'momentum_buffer': tensor([[ 0.0000, -0.3215, 0.0000]])
}
}
# param_groups部分
[
{
'lr': 0.01,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None,
'params': [0, 1]
},
{
'lr': 0.1,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None,
'params': [2]
}
]
可以看到 state
和 param_groups
中的模型参数全被映射成了ID。
3.7 load_state_dict
因为之前的ID映射是根据 param_groups
构造的,所以在load的时候,我们也要根据 param_groups
去建立一一对应关系,此时需要构造ID到模型参数的映射。
这里有一个细节,在还原 params
的时候,我们可以直接通过ID进行还原,但是在还原 state
的时候,我们要确保各个临时状态和模型参数是处于同一设备上。
源码解析(已做简化):
python
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# 获取当前优化器的param_groups和之前保存的param_groups
groups = self.param_groups
saved_groups = deepcopy(state_dict["param_groups"])
# 确保group的数量相等
if len(groups) != len(saved_groups):
raise ValueError(
"loaded state dict has a different number of " "parameter groups"
)
# 确保每个group中的模型参数个数相等
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# 之前是根据enumerate构造的正向映射
# 现在直接用zip构造反向映射
id_map = dict(
zip(
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
)
)
# 用于递归地将Tensor对应的临时状态移动到和Tensor相同的设备上,并确保数据类型相同
def _cast(param, value, key=None):
if isinstance(value, torch.Tensor):
if key == 'step':
return value
else:
if param.is_floating_point():
return value.to(dtype=param.dtype, device=param.device)
else:
# 例如模型是一个量化模型,此时不必转化value的数据类型
return value.to(device=param.device)
elif isinstance(value, dict):
return {
k: _cast(param, v, key=k)
for k, v in value.items()
}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v) for v in value)
else:
return value
# 还原state
# 注意要将临时状态转移到和param相同的设备上,有些时候还需要确保数据类型相同
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict["state"].items():
param = id_map[k]
state[param] = _cast(
param, v, param_id=k, param_groups=state_dict["param_groups"]
)
# 还原param_groups,只有params需要修改
def update_group(
group: Dict[str, Any], new_group: Dict[str, Any]
) -> Dict[str, Any]:
new_group["params"] = group["params"]
return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
# 更新
self.__dict__.update({"state": state, "param_groups": param_groups})
4. SGD Optimizer
在了解 Optimizer
源码后,我们开看SGD优化器是如何实现的。
构造函数(简化版):
python
class SGD(Optimizer):
def __init__(
self,
params,
lr: float = 1e-3,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
nesterov=False,
):
# 将独属于SGD优化器的超参数打包成defaults
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
)
# 调用父类的构造函数初始化
super().__init__(params, defaults)
无非就做了两件事:
- 将独属于该优化器的超参数打包成
defaults
- 调用父类的构造函数进行初始化(不然没有
state
和param_groups
这两个属性)
step
和 _init_group
在3.5节中已经介绍过,这里关注sgd函数如何实现。
⚠️ 对SGD算法不熟悉的读者可以看这篇博客:深入解析SGD、Momentum与Nesterov:优化算法的对比与应用述
python
def sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
):
for i, param in enumerate(params):
grad = grads[i]
# 权重衰减
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if momentum != 0:
buf = momentum_buffer_list[i] # 存储的是每个参数上一时刻的动量
if buf is None:
buf = torch.clone(grad).detach() # 初始时动量就是梯度,因为m_0 = 0
momentum_buffer_list[i] = buf
else:
# buf = momentum * buf + (1 - dampening) * grad
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
# 更新权重:param = param - lr * grad
param.add_(grad, alpha=-lr)
5. 极简版Optimizer源码
截止目前,我们可以对 Optimizer
基类的源码进行汇总,给出一个极简版的实现:
python
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, Iterable
import torch
class Optimizer:
def __init__(self, params, defaults: Dict[str, Any]) -> None:
self.defaults = defaults
self.state = defaultdict(dict)
self.param_groups = []
param_groups = list(params)
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
for param_group in param_groups:
self.add_param_group(param_group)
def add_param_group(self, param_group: Dict[str, Any]) -> None:
params = param_group["params"]
param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)
for name, default in self.defaults.items():
param_group.setdefault(name, default)
self.param_groups.append(param_group)
def step(self) -> None:
raise NotImplementedError
def zero_grad(self) -> None:
for group in self.param_groups:
for p in group["params"]:
if p.grad is not None:
p.grad = None
def state_dict(self) -> Dict[str, Any]:
param_mappings = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
param_mappings.update(
{id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
)
packed["params"] = [param_mappings[id(p)] for p in group["params"]]
start_index += len(packed["params"])
return packed
param_groups = [pack_group(g) for g in self.param_groups]
packed_state = {param_mappings[id(k)]: v for k, v in self.state.items()}
return {"state": packed_state, "param_groups": param_groups}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
groups = self.param_groups
saved_groups = deepcopy(state_dict["param_groups"])
id_map = dict(
zip(
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
)
)
def _cast(param, value, key=None):
if isinstance(value, torch.Tensor):
if key == 'step':
return value
else:
return value.to(dtype=param.dtype, device=param.device) if param.is_floating_point() else value.to(device=param.device)
elif isinstance(value, dict):
return {k: _cast(param, v, key=k) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v) for v in value)
else:
return value
state = defaultdict(dict)
for k, v in state_dict["state"].items():
param = id_map[k]
state[param] = _cast(param, v, param_id=k)
def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
new_group["params"] = group["params"]
return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__dict__.update({"state": state, "param_groups": param_groups})
6. 自定义你的Optimizer
以SGD为例,我们来实现一个只有学习率的极简版优化器。
步骤如下:
- 继承
Optimizer
,声明构造函数,构造函数的形参必须含有params
,随后的一系列形参都是和该优化器相关的超参数。 - 在构造函数中将相关超参数打包成
defaults
,然后调用父类的构造函数。 - 重写
step
方法,并用@torch.no_grad()
装饰。 - 在
step
中遍历self.param_groups
,每一次遍历,声明params
、grads
列表,然后用group
的数据进行填充。如果涉及到临时状态,还需要额外声明和临时状态相关的列表。
python
class SimpleSGD(Optimizer):
def __init__(self, params, lr=0.01):
assert lr > 0.0
defaults = dict(lr=lr)
super().__init__(params, defaults)
@torch.no_grad()
def step(self):
for group in self.param_groups:
params = []
grads = []
for p in group['params']:
if p.grad is not None:
params.append(p)
grads.append(p.grad)
lr = group['lr']
for param, grad in zip(params, grads):
param.add_(grad, alpha=-lr)
仅设置学习率时,它所产生的结果和官方的SGD相同,读者可自行验证。
Ref
[1] https://blog.csdn.net/zzxxxaa1/article/details/121144570?spm=1001.2014.3001.5502
[2] https://www.hjhgjghhg.com/archives/119/
[3] https://pytorch.org/docs/stable/generated/torch.optim.SGD.html