PyTorch源码系列(一)——Optimizer源码详解

目录

  • [1. Optimizer类](#1. Optimizer类)
  • [2. Optimizer概览](#2. Optimizer概览)
  • [3. 源码解析](#3. 源码解析)
    • [3.1 构造方法](#3.1 构造方法)
      • [3.1.1 全局设置情形](#3.1.1 全局设置情形)
      • [3.1.2 局部设置情形](#3.1.2 局部设置情形)
      • [3.1.3 覆盖测试](#3.1.3 覆盖测试)
      • [3.1.4 逐行讲解](#3.1.4 逐行讲解)
    • [3.2 add_param_group](#3.2 add_param_group)
    • [3.3 step](#3.3 step)
    • [3.4 zero_grad](#3.4 zero_grad)
    • [3.5 self.state](#3.5 self.state)
    • [3.6 state_dict](#3.6 state_dict)
    • [3.7 load_state_dict](#3.7 load_state_dict)
  • [4. SGD Optimizer](#4. SGD Optimizer)
  • [5. 极简版Optimizer源码](#5. 极简版Optimizer源码)
  • [6. 自定义你的Optimizer](#6. 自定义你的Optimizer)
  • Ref

1. Optimizer类

PyTorch的 Optimizer 类是深度学习模型中用于管理和更新模型参数的基类。它负责根据损失函数的梯度信息调整模型的参数,使模型逐步逼近最佳状态。Optimizer 类通过实现一些核心方法,如 step(),来执行参数更新过程,而 zero_grad() 方法则用于清除模型中所有参数的梯度。

每个优化器会存储参数组和相关的状态,例如学习率、动量等。不同的优化器(如SGD、Adam等)继承自 Optimizer 类,并根据各自的算法特点实现了不同的参数更新策略。此外,Optimizer 类还允许用户在初始化时指定超参数,如学习率等,这些超参数会影响参数的更新方式。

本文将详细讲解 Optimizer 类的源码,并以SGD优化器为例介绍如何自定义一个自己的优化器。

2. Optimizer概览

📝 本文在讲解源码时,只考虑源码的简化版本,而不考虑完整的源码。

除了构造方法外,Optimizer常用的几个方法如下:

python 复制代码
class Optimizer:
    def state_dict(self) -> Dict[str, Any]:
        """返回优化器的状态字典,保存当前的优化器状态以便之后恢复。"""
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """加载先前保存的状态字典,恢复优化器的状态。"""
        ...

    def zero_grad(self, set_to_none: bool = True) -> None:
        """将优化器中所有参数的梯度清零,通常在每次反向传播前调用。"""
        ...

    def step(self) -> None:
        """执行一步优化更新,用于根据梯度更新模型参数。"""
        raise NotImplementedError
    
    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        """向优化器中添加新的参数组,用于管理不同的参数组(如不同学习率等)。"""
        ...

在自定义优化器时,必须继承 Optimizer 类,并实现 step() 方法,否则将会报错。

optimizer.py 文件中还定义了若干类型别名,如下:

python 复制代码
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
  • Args:通常用来表示函数中不定数量的位置参数。
  • Kwargs:通常用来表示函数中不定数量的关键字参数,键为参数名,值为相应的参数值。
  • StateDict:通常用于保存模型参数和优化器的状态信息。
  • ParamsT:是一个 torch.Tensor 的可迭代对象,或者是包含键值对的字典的可迭代对象。

3. 源码解析

3.1 构造方法

Optimizer 的构造方法如下:

python 复制代码
class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError(
                "params argument given to the optimizer should be "
                "an iterable of Tensors or dicts, but got " + torch.typename(params)
            )

        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: List[Dict[str, Any]] = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

从源码可以看出,Optimizer两个形参和两个属性self.defaults 不算)。

形参:

  • params:由一系列Tensor组成的迭代器或是由一系列字典组成的迭代器。通常是模型的参数。例如 model.parameters()
  • defaults:一个键为字符串的字典。通常是和优化算法相关的全局 超参数。例如 lrmomentum 等。下文会解释为什么这里是「全局」。

属性:

  • state:一个键为Tensor,值为字典的字典。用来存储每个模型参数对应的临时状态,例如 momentum
  • param_groups:一个列表,其中的每一个元素都是一个键为字符串的字典。每一个元素对应了一个 param_group

看到这里,可能你仍然不明白 param_groups 是什么。既然它是复数形式,说明它是由一个个 param_group 组成的,每一个 param_group 的类型是 Dict[str, Any],因此 param_groups 的类型就是 List[Dict[str, Any]]

那么什么是 param_group 呢?我们知道Transformer模型通常由多个layer堆叠而成,绝大部分情况下,整个模型的训练会采用同一个学习率。但某些特殊场景下,我们可能希望不同的layer使用不同的学习率,此时就会涉及到一个个 param_group 了。

  • 对于前者,形参 params 的类型为 Iterable[Tensor],因为整个模型会共享同一套优化器参数,所以只需要指定全局优化器参数 defaults 即可,此时 param_groups 是一个长度为1的列表
  • 而对于后者,params 的类型为 Iterable[Dict[str, Any]],每一个字典包含了layer的参数和对应的局部优化器参数 。注意此时依然可以指定全局优化器参数,例如我们希望不同的layer使用不同的学习率,但希望所有的layer都使用同一个动量。此时 param_groups 是一个长度大于1的列表

我们来看更具体的例子。

3.1.1 全局设置情形

python 复制代码
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.fc2 = nn.Linear(3, 3, bias=False)
        self.fc3 = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = MLP()

# =============================
# 情况 1: 所有层使用同一套优化器参数
# =============================
# 此时,params 是 Iterable[Tensor],直接传递模型的所有参数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
print(optimizer.param_groups)

输出(做了简化处理):

python 复制代码
[
    {
        'params': [fc1_tensor, fc2_tensor, fc3_tensor],
        'lr': 0.01,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    }
]

可以看到此时 param_groups 是一个长度为1的列表。字典中的 params 对应了模型的参数,因为没有设置 bias,所以总共有三个张量。字典中剩余的键对应了优化器的超参数。

3.1.2 局部设置情形

python 复制代码
# ===============================
# 情况 2: 不同的层使用不同的优化器参数
# ===============================
# 此时,params 是 Iterable[Dict[str, Any]],可以为不同的层设置不同的学习率
optimizer = optim.SGD([
    {'params': model.fc1.parameters(), 'lr': 0.001},  # 第1层学习率为 0.001
    {'params': model.fc2.parameters(), 'lr': 0.01},   # 第2层学习率为 0.01
    {'params': model.fc3.parameters(), 'lr': 0.1}     # 第3层学习率为 0.1
])

print(optimizer.param_groups)

输出(做了简化处理):

python 复制代码
[
    {
        'params': [fc1_tensor],
        'lr': 0.001,
        'momentum': 0,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc2_tensor],
        'lr': 0.01,
        'momentum': 0,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc3_tensor],
        'lr': 0.1,
        'momentum': 0,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    }
]

此时 param_groups 是一个长度为3的列表。列表中的每一个字典就是一个 param_group,存储了相应的layer参数和局部优化器参数。

⚠️ 这里使用"局部优化器参数"这一术语并不准确。事实上,"局部"和"全局"的概念仅在优化器实例化之前存在。一旦优化器实例化,全局参数将会逐步写入到每一个 param_group 中,例如 momentumdampening 等未显式定义的参数,实质上属于全局优化器参数,并会通过 add_param_group 方法自动地写入到每一个 param_group 中。因此,在优化器实例化后,每个 param_group 都拥有一套完整且独立的参数配置。

到目前为止,我们可以做一个简单总结。param_groups 是一个元素为字典的列表。当传入的 params 为由Tensor构成的迭代器时,此时 param_groups 的长度为1,即只含有一个 param_group。当传入的 params 为由字典构成的迭代器时,此时 param_groups 的长度为 len(params)

param_groups 中的所有字典的完全相同,均形如:

python 复制代码
param_group = {
    'params': [tensor_1, tensor_2, ...],  # 待优化的模型参数
    **defaults  # 全局优化器参数
}

但所有字典的却不尽相同。

3.1.3 覆盖测试

之前我们只考虑了「仅全局」和「仅局部」的情形,如果我们手动设置全局优化器参数,并且它和某些 param_group 中的局部优化器参数冲突了,那么这个全局的会覆盖掉局部的吗?

python 复制代码
optimizer = optim.SGD([
    {'params': model.fc1.parameters(), 'lr': 0.001, 'momentum': 0.3},
    {'params': model.fc2.parameters(), 'lr': 0.01},
    {'params': model.fc3.parameters(), 'lr': 0.1, 'nesterov': True}
], momentum=0.9, nesterov=False)

print(optimizer.param_groups)

输出:

python 复制代码
[
    {
        'params': [fc1_tensor],
        'lr': 0.001,
        'momentum': 0.3,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc2_tensor],
        'lr': 0.01,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc3_tensor],
        'lr': 0.1,
        'nesterov': True,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    }
]

由此可以得出结论:全局优化器参数不会覆盖掉局部优化器参数

3.1.4 逐行讲解

现在我们已经对 paramsdefaultsparam_groups 这三个变量有了足够的了解(state 会放在下文讲解),接下来我们逐行剖析构造方法。

python 复制代码
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
	# 将defaults设置成属性以方便在后续的add_param_group方法中使用
    self.defaults = defaults
	
	# params必须是关于tensor或dict的可迭代对象
    if isinstance(params, torch.Tensor):
        raise TypeError(
            "params argument given to the optimizer should be "
            "an iterable of Tensors or dicts, but got " + torch.typename(params)
        )
	
	# 初始化两大重要属性:state和param_groups
    self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
    self.param_groups: List[Dict[str, Any]] = []
	
	# 将迭代器转化成列表,此时要么是List[Tensor]要么是List[Dict]
    param_groups = list(params)

	# 必须非空
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    
    # 如果不是List[Dict],说明进行的是全局设置,此时param_groups是List[Tensor]
    # 类型,代表模型的所有参数。然后将其转变成List[Dict]类型,以达到格式统一的目的
    if not isinstance(param_groups[0], dict):
        param_groups = [{"params": param_groups}]
	
	# 将每一个param_group经过相关处理后添加到self.param_groups中
	# 这里也会将全局优化器参数defaults注入到每一个param_group中
    for param_group in param_groups:
        self.add_param_group(param_group)

通常来讲,params 会接收 model.parameters() 作为输入,如果不是,那么 params 必须是由字典构成的列表,且每一个字典必须含有 params 这个键,对应的值是模型的部分参数。

绝大多数情况下,我们认为下式成立:

model.parameters() = ⋃ i = 1 k param_groups [ i ] [ " params " ] \text{model.parameters()}=\bigcup_{i=1}^k \text{param\_groups}[i]["\text{params}"] model.parameters()=i=1⋃kparam_groups[i]["params"]

k k k 是 param_group 的个数,且 param_groups[i]["params"] 两两互不相交。

如果涉及到冻结模型的一部分参数,仅训练剩余的参数,那么上式就不再成立了。

3.2 add_param_group

有了3.1小节的基础后,这里直接逐行讲解源码。

python 复制代码
def add_param_group(self, param_group: Dict[str, Any]) -> None:
    r"""Add a param group to the :class:`Optimizer`'s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers can be made
    trainable and added to the :class:`Optimizer` as training progresses.

    Args:
        param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
    """

    # 确保传入的param_group一定是一个字典
    if not isinstance(param_group, dict):
        raise TypeError(f"param_group must be a dict, but got {type(param_group)}")

    # 获取该group中的模型参数部分,然后将其转变为List[Tensor]类型
    params = param_group["params"]
    if isinstance(params, torch.Tensor):
        param_group["params"] = [params]
    elif isinstance(params, set):
        raise TypeError(
            "optimizer parameters need to be organized in ordered collections, but "
            "the ordering of tensors in sets will change between runs. Please use a list instead."
        )
    else:
        param_group["params"] = list(params)

    # 参数检查
    for param in param_group["params"]:
        # 确保所有的参数必须都是Tensor,否则无法优化
        if not isinstance(param, torch.Tensor):
            raise TypeError(
                "optimizer can only optimize Tensors, "
                "but one of the params is " + torch.typename(param)
            )

        # 确保所有的参数都是叶子节点
        if not (param.is_leaf or param.retains_grad):
            raise ValueError("can't optimize a non-leaf Tensor")

    # 将全局优化器参数注入到当前的group中,setdefault保证了这一过程不会覆盖掉局部优化器参数
    for name, default in self.defaults.items():
        param_group.setdefault(name, default)

    # 检查是否存在重复的参数
    # 目前出现重复的参数并不会报错
    params = param_group["params"]
    if len(params) != len(set(params)):
        warnings.warn(
            "optimizer contains a parameter group with duplicate parameters; "
            "in future, this will cause an error; "
            "see github.com/pytorch/pytorch/issues/40967 for more information",
            stacklevel=3,
        )

    # 判断当前的group中是否有参数和已经添加过的group中的参数重复
    # 两个集合交集为空,isdisjoint()返回True
    param_set: Set[torch.Tensor] = set()
    for group in self.param_groups:
        param_set.update(set(group["params"]))

    if not param_set.isdisjoint(set(param_group["params"])):
        raise ValueError("some parameters appear in more than one parameter group")

    # 将当前的group添加到self.param_groups中
    self.param_groups.append(param_group)

add_param_group 源码看似复杂,但归根结底也就那么几行代码是真正起到作用的,这里给出一个简化版本:

python 复制代码
def add_param_group(self, param_group: Dict[str, Any]) -> None:
    params = param_group["params"]
    param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)

    for name, default in self.defaults.items():
        param_group.setdefault(name, default)

    self.param_groups.append(param_group)

3.3 step

Optimizer 类并没有实现 step() 方法:

python 复制代码
def step(self) -> None:
    """Performs a single optimization step (parameter update)."""
    raise NotImplementedError

自定义优化器时,需要继承 Optimizer 类,并实现该方法,否则会报错。

因为 step() 不返回任何值,所以需要实现模型参数的原地更新

3.4 zero_grad

如下是简化版的源码:

python 复制代码
def zero_grad(self, set_to_none: bool = True) -> None:
    r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            This will in general have lower memory footprint, and can modestly improve performance.
            However, it changes certain behaviors. For example:
            1. When the user tries to access a gradient and perform manual ops on it,
            a None attribute or a Tensor full of 0s will behave differently.
            2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
            are guaranteed to be None for params that did not receive a gradient.
            3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
            (in one case it does the step with a gradient of 0 and in the other it skips
            the step altogether).
    """

	# 遍历每一个group,然后遍历group中的每一个参数
    for group in self.param_groups:
        for p in group["params"]:
        	
        	# 如果p的梯度不为None,说明需要清空
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                	# grad_fn 表示该张量是由某个操作生成的
                	# 因此将该张量从计算图中分离出来
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                    	# 不在计算图中,原地关闭梯度以防止后续追踪
                        p.grad.requires_grad_(False)
                    
                    # 清零梯度
                    p.grad.zero_()

从注释可以看出,当梯度被设置为 None 时,梯度张量会被释放,从而减少内存占用。而如果梯度被清零(即将其所有元素设置为 0),梯度张量的内存仍然会被保留。

由于 set_to_none 默认为 True,因此 zero_grad 源码可以进一步简化:

python 复制代码
def zero_grad(self) -> None:
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                p.grad = None

3.5 self.state

Optimizer 中有两大重要属性:stateparam_groups。先前我们已经了解了 param_groups,现在来看 state

我们已经知道构造方法中会通过多次调用 add_param_group 来初始化 self.param_groups,但截止目前,似乎并没有方法能够初始化 self.state,那它是怎么初始化的?以及它到底"长什么样"呢?

self.state 用来存储与模型参数相关的临时状态 。对于SGD with momentum而言,每个参数都需要维护一个动量。对于Adam而言,每个参数不仅需要维护一个动量(一阶矩),还需要维护一个平方梯度(二阶矩)。很显然,对于一个待优化的Tensor,它的临时状态和它的形状是相同的 ,并且对于该Tensor,可能有多个临时状态需要维护,每个临时状态都有一个自己的名字,由此推测 self.state 的类型应当是 Dict[Tensor, Dict[str, Tensor]],这与源码中声明的相同:

python 复制代码
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)

self.state 的键是一个Tensor,值是一个字典。字典存储了优化该Tensor的一些临时状态,键是临时状态的名称,值是相应的状态。

以SGD优化器为例,进行一次单步更新,然后查看它的 state 属性:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.fc2 = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = SimpleMLP()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

inputs = torch.randn(2)
target = torch.randn(1)

output = model(inputs)
loss = criterion(output, target)

loss.backward()
optimizer.step()

print(optimizer.state)

输出:

python 复制代码
defaultdict(<class 'dict'>, {
    Parameter containing:
    tensor([[-0.4128, -0.6015],
            [-0.3628,  0.6162],
            [ 0.4367,  0.3409]], requires_grad=True): {
        'momentum_buffer': tensor([[ 0.0000,  0.0000],
                                   [-0.0052,  0.0581],
                                   [-0.0130,  0.1446]])
    },
    Parameter containing:
    tensor([[0.5394, 0.1077, 0.2748]], requires_grad=True): {
        'momentum_buffer': tensor([[0.0000, 0.3398, 0.1585]])
    }
})

可以得知,对于SGD优化器而言,每个待优化的Tensor只需要维护一个临时状态:momentum_buffer,即当前的动量。且临时状态的形状与待优化的Tensor相同。

将SGD换成Adam,再来看看结果:

python 复制代码
defaultdict(<class 'dict'>, {
    Parameter containing:
    tensor([[ 0.4954, -0.0392],
            [ 0.0778, -0.5769],
            [-0.3332, -0.0659]], requires_grad=True): {
        'step': tensor(1.),
        'exp_avg': tensor([[-0.1632,  0.0252],
                           [-0.0513,  0.0079],
                           [ 0.0000,  0.0000]]),
        'exp_avg_sq': tensor([[2.6648e-03, 6.3460e-05],
                              [2.6274e-04, 6.2572e-06],
                              [0.0000e+00, 0.0000e+00]])
    },
    Parameter containing:
    tensor([[-0.4492, -0.1479,  0.4789]], requires_grad=True): {
        'step': tensor(1.),
        'exp_avg': tensor([[0.1821, 0.0577, 0.0000]]),
        'exp_avg_sq': tensor([[0.0033, 0.0003, 0.0000]])
    }
})

此时每个Tensor需要维护三个临时状态:stepexp_avgexp_avg_sqstep 是当前更新的步数,exp_avg 就是SGD中的动量(不完全相同),exp_avg_sq 是平方梯度。

知道了 self.state 长什么样后,我们需要了解一下它是如何初始化的。

事实上 Optimizer 类并没有实现 self.state 的初始化,因为它的初始化是在 step() 中完成的,所以我们需要关注 Optimizer 的子类,这里以SGD为例。

python 复制代码
class SGD(Optimizer):
    def step(self):
        for group in self.param_groups:
            # 用来存储模型参数,梯度,动量
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            # 填充params、grads、momentum_buffer_list
            self._init_group(group, params, grads, momentum_buffer_list)

            # 执行sgd优化算法
            sgd(params, grads, momentum_buffer_list, ...)

            if group["momentum"] != 0:
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    # 获取Tensor
                    state = self.state[p]
                    # 更新Tensor的临时状态
                    state["momentum_buffer"] = momentum_buffer

    def _init_group(self, group, params, grads, momentum_buffer_list):
        for p in group["params"]:
            if p.grad is not None:
                params.append(p)
                grads.append(p.grad)

                if group["momentum"] != 0:
                    # 因为self.state是defaultdict,所以初始时state会自动创建为一个字典
                    state = self.state[p]
                    # 因为初始时state没有momentum_buffer这个键
                    # 所以momentum_buffer_list的初始值为[None, None, None, ...]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

每一步更新,_init_group 会在SGD算法执行前被调用,它用来获取模型所有待更新的参数,对应的已经计算的梯度,以及对应的上一时刻的动量。self.state 会在 _init_group 中进行初始化 。我们可以通过在 self._init_groupsgd() 这两个语句之间加入以下代码来查看相应的信息:

python 复制代码
print(self.state)
print(momentum_buffer_list)

输出:

python 复制代码
defaultdict(<class 'dict'>, {
    Parameter containing:
    tensor([[ 0.4679, -0.6531],
            [-0.4707, -0.2854],
            [ 0.6846, -0.6576]], requires_grad=True): {},
    
    Parameter containing:
    tensor([[-0.4362,  0.4155,  0.1798]], requires_grad=True): {}
})

[None, None]

这说明初始时,每一个Tensor对应的临时状态为空字典,且 momentum_buffer_list 的初始值为 [None, None, ...]

3.6 state_dict

state_dict 的作用是保存优化器当前的状态,以便在之后的训练中恢复或继续使用。

显然 state_dict 应当保存优化器的两大重要属性

python 复制代码
state_dict = {
    "state": state,
    "param_groups": param_groups,
}

但根据之前的分析,state 中的键会涉及到模型参数,此外,param_groups 中的 params 也会涉及到模型参数,如果就这样直接保存,相当于我们在保存了优化器的同时还保存了两份模型参数(state一份,param_groups一份),这显然是不可行的。

一种直观的想法是将这些模型参数映射成唯一的数字ID,而这些ID是几乎不占空间的。假设 param_groups 含有 k k k 个 param_group,那么我们可以从第一个group开始,从0开始从前往后依次编号直至第 k k k 个group。

当然,我们还需要对 state 进行编号,由于 stateparam_groups 中的参数不一定一一对应(因为 param_groups 可能会含有重复的参数,但 state 不会,所以二者长度不一定相等,具体见 add_param_group 源码),因此我们不能从0开始从前往后依次编号。这启发我们可以构造一个参数的内存地址到ID的映射 mapping,这样在编号 state 的时候,我们就可以通过 mapping[id(tensor)] 来获取模型参数的ID了。

源码解析(已做简化):

python 复制代码
def state_dict(self) -> Dict[str, Any]:
    # 构建模型参数地址到ID的映射
    param_mappings: Dict[int, int] = {}
    start_index = 0

    def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
        nonlocal start_index
        # 处理优化器超参数部分
        packed = {k: v for k, v in group.items() if k != "params"}
        # 更新映射
        param_mappings.update(
            {
                id(p): i
                for i, p in enumerate(group["params"], start_index)
                if id(p) not in param_mappings
            }
        )
        # 处理模型参数部分,将具体的参数映射为ID
        packed["params"] = [param_mappings[id(p)] for p in group["params"]]
        start_index += len(packed["params"])
        return packed

    # 将所有group中的所有模型参数映射为ID
    param_groups = [pack_group(g) for g in self.param_groups]

    # 将state中的所有模型参数映射为ID
    packed_state = {
        param_mappings[id(k)]: v
        for k, v in self.state.items()
    }

    state_dict = {
        "state": packed_state,
        "param_groups": param_groups,
    }

    return state_dict

我们可以通过以下代码来查看 state_dict 的样子:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import itertools

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.fc2 = nn.Linear(3, 3, bias=False)
        self.fc3 = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = MLP()

optimizer = optim.SGD([
    {"params": itertools.chain(model.fc1.parameters(), model.fc2.parameters()), "lr": 0.01},
    {"params": model.fc3.parameters(), "lr": 0.1},
], momentum=0.9)

x = torch.randn(1, 2)
y = torch.randn(1, 1)

criterion = nn.MSELoss()

output = model(x)
loss = criterion(output, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(optimizer.state_dict().keys())

print(optimizer.state_dict()['state'])
print(optimizer.state_dict()['param_groups'])

输出:

python 复制代码
dict_keys(['state', 'param_groups'])

# state部分
{
    0: {
        'momentum_buffer': tensor([[-0.0270,  0.1874],
                                   [ 0.0000,  0.0000],
                                   [ 0.0000,  0.0000]])
    },
    1: {
        'momentum_buffer': tensor([[ 0.0000,  0.0000,  0.0000],
                                   [-0.2232,  0.0000,  0.0000],
                                   [ 0.0000,  0.0000,  0.0000]])
    },
    2: {
        'momentum_buffer': tensor([[ 0.0000, -0.3215,  0.0000]])
    }
}

# param_groups部分
[
    {
        'lr': 0.01,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None,
        'params': [0, 1]
    },
    {
        'lr': 0.1,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None,
        'params': [2]
    }
]

可以看到 stateparam_groups 中的模型参数全被映射成了ID。

3.7 load_state_dict

因为之前的ID映射是根据 param_groups 构造的,所以在load的时候,我们也要根据 param_groups 去建立一一对应关系,此时需要构造ID到模型参数的映射。

这里有一个细节,在还原 params 的时候,我们可以直接通过ID进行还原,但是在还原 state 的时候,我们要确保各个临时状态和模型参数是处于同一设备上。

源码解析(已做简化):

python 复制代码
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    # 获取当前优化器的param_groups和之前保存的param_groups
    groups = self.param_groups
    saved_groups = deepcopy(state_dict["param_groups"])

    # 确保group的数量相等
    if len(groups) != len(saved_groups):
        raise ValueError(
            "loaded state dict has a different number of " "parameter groups"
        )
    
    # 确保每个group中的模型参数个数相等
    param_lens = (len(g["params"]) for g in groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group "
            "that doesn't match the size of optimizer's group"
        )

    # 之前是根据enumerate构造的正向映射
    # 现在直接用zip构造反向映射
    id_map = dict(
        zip(
            chain.from_iterable(g["params"] for g in saved_groups),
            chain.from_iterable(g["params"] for g in groups),
        )
    )

    # 用于递归地将Tensor对应的临时状态移动到和Tensor相同的设备上,并确保数据类型相同
    def _cast(param, value, key=None):
        if isinstance(value, torch.Tensor):
            if key == 'step':
                return value
            else:
                if param.is_floating_point():
                    return value.to(dtype=param.dtype, device=param.device)
                else:
                    # 例如模型是一个量化模型,此时不必转化value的数据类型
                    return value.to(device=param.device)

        elif isinstance(value, dict):
            return {
                k: _cast(param, v, key=k)
                for k, v in value.items()
            }
        elif isinstance(value, Iterable):
            return type(value)(_cast(param, v) for v in value)
        else:
            return value

    # 还原state
    # 注意要将临时状态转移到和param相同的设备上,有些时候还需要确保数据类型相同
    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        param = id_map[k]
        state[param] = _cast(
            param, v, param_id=k, param_groups=state_dict["param_groups"]
        )

    # 还原param_groups,只有params需要修改
    def update_group(
        group: Dict[str, Any], new_group: Dict[str, Any]
    ) -> Dict[str, Any]:
        new_group["params"] = group["params"]
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]

    # 更新
    self.__dict__.update({"state": state, "param_groups": param_groups})

4. SGD Optimizer

在了解 Optimizer 源码后,我们开看SGD优化器是如何实现的。

构造函数(简化版):

python 复制代码
class SGD(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov=False,
    ):
        # 将独属于SGD优化器的超参数打包成defaults
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        # 调用父类的构造函数初始化
        super().__init__(params, defaults)

无非就做了两件事:

  1. 将独属于该优化器的超参数打包成 defaults
  2. 调用父类的构造函数进行初始化(不然没有 stateparam_groups 这两个属性)

step_init_group 在3.5节中已经介绍过,这里关注sgd函数如何实现。

⚠️ 对SGD算法不熟悉的读者可以看这篇博客:深入解析SGD、Momentum与Nesterov:优化算法的对比与应用述

python 复制代码
def sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]

        # 权重衰减
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]  # 存储的是每个参数上一时刻的动量

            if buf is None:
                buf = torch.clone(grad).detach()  # 初始时动量就是梯度,因为m_0 = 0
                momentum_buffer_list[i] = buf
            else:
                # buf = momentum * buf + (1 - dampening) * grad
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            if nesterov:
                grad = grad.add(buf, alpha=momentum)
            else:
                grad = buf

        # 更新权重:param = param - lr * grad
        param.add_(grad, alpha=-lr)

5. 极简版Optimizer源码

截止目前,我们可以对 Optimizer 基类的源码进行汇总,给出一个极简版的实现:

python 复制代码
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, Iterable

import torch


class Optimizer:
    def __init__(self, params, defaults: Dict[str, Any]) -> None:
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        params = param_group["params"]
        param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)

        for name, default in self.defaults.items():
            param_group.setdefault(name, default)

        self.param_groups.append(param_group)

    def step(self) -> None:
        raise NotImplementedError

    def zero_grad(self) -> None:
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad = None

    def state_dict(self) -> Dict[str, Any]:
        param_mappings = {}
        start_index = 0

        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        packed_state = {param_mappings[id(k)]: v for k, v in self.state.items()}

        return {"state": packed_state, "param_groups": param_groups}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        groups = self.param_groups
        saved_groups = deepcopy(state_dict["param_groups"])
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def _cast(param, value, key=None):
            if isinstance(value, torch.Tensor):
                if key == 'step':
                    return value
                else:
                    return value.to(dtype=param.dtype, device=param.device) if param.is_floating_point() else value.to(device=param.device)
            elif isinstance(value, dict):
                return {k: _cast(param, v, key=k) for k, v in value.items()}
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v) for v in value)
            else:
                return value

        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            param = id_map[k]
            state[param] = _cast(param, v, param_id=k)

        def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__dict__.update({"state": state, "param_groups": param_groups})

6. 自定义你的Optimizer

以SGD为例,我们来实现一个只有学习率的极简版优化器。

步骤如下:

  • 继承 Optimizer,声明构造函数,构造函数的形参必须含有 params,随后的一系列形参都是和该优化器相关的超参数。
  • 在构造函数中将相关超参数打包成 defaults,然后调用父类的构造函数。
  • 重写 step 方法,并用 @torch.no_grad() 装饰。
  • step 中遍历 self.param_groups,每一次遍历,声明 paramsgrads 列表,然后用 group 的数据进行填充。如果涉及到临时状态,还需要额外声明和临时状态相关的列表。
python 复制代码
class SimpleSGD(Optimizer):
    def __init__(self, params, lr=0.01):
        assert lr > 0.0

        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            params = []
            grads = []

            for p in group['params']:
                if p.grad is not None:
                    params.append(p)
                    grads.append(p.grad)
            
            lr = group['lr']
            for param, grad in zip(params, grads):
                param.add_(grad, alpha=-lr)

仅设置学习率时,它所产生的结果和官方的SGD相同,读者可自行验证。


Ref

1\] \[2\] \[3\]

相关推荐
ONEYAC唯样3 分钟前
“在中国,为中国” 英飞凌汽车业务正式发布中国本土化战略
大数据·人工智能
mozun20209 分钟前
产业观察:哈工大机器人公司2025.4.22
大数据·人工智能·机器人·创业创新·哈尔滨·名校
-一杯为品-12 分钟前
【深度学习】#9 现代循环神经网络
人工智能·rnn·深度学习
硅谷秋水14 分钟前
ORION:通过视觉-语言指令动作生成的一个整体端到端自动驾驶框架
人工智能·深度学习·机器学习·计算机视觉·语言模型·自动驾驶
Java中文社群36 分钟前
最火向量数据库Milvus安装使用一条龙!
java·人工智能·后端
豆芽81944 分钟前
强化学习(Reinforcement Learning, RL)和深度学习(Deep Learning, DL)
人工智能·深度学习·机器学习·强化学习
山北雨夜漫步1 小时前
机器学习 Day14 XGboost(极端梯度提升树)算法
人工智能·算法·机器学习
董可伦1 小时前
Flink 源码编译
大数据·flink·源码
basketball6161 小时前
Python torchvision.transforms 下常用图像处理方法
开发语言·图像处理·python
兔子蟹子1 小时前
Java集合框架解析
java·windows·python