PyTorch源码系列(一)——Optimizer源码详解

目录

  • [1. Optimizer类](#1. Optimizer类)
  • [2. Optimizer概览](#2. Optimizer概览)
  • [3. 源码解析](#3. 源码解析)
    • [3.1 构造方法](#3.1 构造方法)
      • [3.1.1 全局设置情形](#3.1.1 全局设置情形)
      • [3.1.2 局部设置情形](#3.1.2 局部设置情形)
      • [3.1.3 覆盖测试](#3.1.3 覆盖测试)
      • [3.1.4 逐行讲解](#3.1.4 逐行讲解)
    • [3.2 add_param_group](#3.2 add_param_group)
    • [3.3 step](#3.3 step)
    • [3.4 zero_grad](#3.4 zero_grad)
    • [3.5 self.state](#3.5 self.state)
    • [3.6 state_dict](#3.6 state_dict)
    • [3.7 load_state_dict](#3.7 load_state_dict)
  • [4. SGD Optimizer](#4. SGD Optimizer)
  • [5. 极简版Optimizer源码](#5. 极简版Optimizer源码)
  • [6. 自定义你的Optimizer](#6. 自定义你的Optimizer)
  • Ref

1. Optimizer类

PyTorch的 Optimizer 类是深度学习模型中用于管理和更新模型参数的基类。它负责根据损失函数的梯度信息调整模型的参数,使模型逐步逼近最佳状态。Optimizer 类通过实现一些核心方法,如 step(),来执行参数更新过程,而 zero_grad() 方法则用于清除模型中所有参数的梯度。

每个优化器会存储参数组和相关的状态,例如学习率、动量等。不同的优化器(如SGD、Adam等)继承自 Optimizer 类,并根据各自的算法特点实现了不同的参数更新策略。此外,Optimizer 类还允许用户在初始化时指定超参数,如学习率等,这些超参数会影响参数的更新方式。

本文将详细讲解 Optimizer 类的源码,并以SGD优化器为例介绍如何自定义一个自己的优化器。

2. Optimizer概览

📝 本文在讲解源码时,只考虑源码的简化版本,而不考虑完整的源码。

除了构造方法外,Optimizer常用的几个方法如下:

python 复制代码
class Optimizer:
    def state_dict(self) -> Dict[str, Any]:
        """返回优化器的状态字典,保存当前的优化器状态以便之后恢复。"""
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """加载先前保存的状态字典,恢复优化器的状态。"""
        ...

    def zero_grad(self, set_to_none: bool = True) -> None:
        """将优化器中所有参数的梯度清零,通常在每次反向传播前调用。"""
        ...

    def step(self) -> None:
        """执行一步优化更新,用于根据梯度更新模型参数。"""
        raise NotImplementedError
    
    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        """向优化器中添加新的参数组,用于管理不同的参数组(如不同学习率等)。"""
        ...

在自定义优化器时,必须继承 Optimizer 类,并实现 step() 方法,否则将会报错。

optimizer.py 文件中还定义了若干类型别名,如下:

python 复制代码
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
  • Args:通常用来表示函数中不定数量的位置参数。
  • Kwargs:通常用来表示函数中不定数量的关键字参数,键为参数名,值为相应的参数值。
  • StateDict:通常用于保存模型参数和优化器的状态信息。
  • ParamsT:是一个 torch.Tensor 的可迭代对象,或者是包含键值对的字典的可迭代对象。

3. 源码解析

3.1 构造方法

Optimizer 的构造方法如下:

python 复制代码
class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError(
                "params argument given to the optimizer should be "
                "an iterable of Tensors or dicts, but got " + torch.typename(params)
            )

        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: List[Dict[str, Any]] = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

从源码可以看出,Optimizer两个形参和两个属性self.defaults 不算)。

形参:

  • params:由一系列Tensor组成的迭代器或是由一系列字典组成的迭代器。通常是模型的参数。例如 model.parameters()
  • defaults:一个键为字符串的字典。通常是和优化算法相关的全局 超参数。例如 lrmomentum 等。下文会解释为什么这里是「全局」。

属性:

  • state:一个键为Tensor,值为字典的字典。用来存储每个模型参数对应的临时状态,例如 momentum
  • param_groups:一个列表,其中的每一个元素都是一个键为字符串的字典。每一个元素对应了一个 param_group

看到这里,可能你仍然不明白 param_groups 是什么。既然它是复数形式,说明它是由一个个 param_group 组成的,每一个 param_group 的类型是 Dict[str, Any],因此 param_groups 的类型就是 List[Dict[str, Any]]

那么什么是 param_group 呢?我们知道Transformer模型通常由多个layer堆叠而成,绝大部分情况下,整个模型的训练会采用同一个学习率。但某些特殊场景下,我们可能希望不同的layer使用不同的学习率,此时就会涉及到一个个 param_group 了。

  • 对于前者,形参 params 的类型为 Iterable[Tensor],因为整个模型会共享同一套优化器参数,所以只需要指定全局优化器参数 defaults 即可,此时 param_groups 是一个长度为1的列表
  • 而对于后者,params 的类型为 Iterable[Dict[str, Any]],每一个字典包含了layer的参数和对应的局部优化器参数 。注意此时依然可以指定全局优化器参数,例如我们希望不同的layer使用不同的学习率,但希望所有的layer都使用同一个动量。此时 param_groups 是一个长度大于1的列表

我们来看更具体的例子。

3.1.1 全局设置情形

python 复制代码
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.fc2 = nn.Linear(3, 3, bias=False)
        self.fc3 = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = MLP()

# =============================
# 情况 1: 所有层使用同一套优化器参数
# =============================
# 此时,params 是 Iterable[Tensor],直接传递模型的所有参数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
print(optimizer.param_groups)

输出(做了简化处理):

python 复制代码
[
    {
        'params': [fc1_tensor, fc2_tensor, fc3_tensor],
        'lr': 0.01,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    }
]

可以看到此时 param_groups 是一个长度为1的列表。字典中的 params 对应了模型的参数,因为没有设置 bias,所以总共有三个张量。字典中剩余的键对应了优化器的超参数。

3.1.2 局部设置情形

python 复制代码
# ===============================
# 情况 2: 不同的层使用不同的优化器参数
# ===============================
# 此时,params 是 Iterable[Dict[str, Any]],可以为不同的层设置不同的学习率
optimizer = optim.SGD([
    {'params': model.fc1.parameters(), 'lr': 0.001},  # 第1层学习率为 0.001
    {'params': model.fc2.parameters(), 'lr': 0.01},   # 第2层学习率为 0.01
    {'params': model.fc3.parameters(), 'lr': 0.1}     # 第3层学习率为 0.1
])

print(optimizer.param_groups)

输出(做了简化处理):

python 复制代码
[
    {
        'params': [fc1_tensor],
        'lr': 0.001,
        'momentum': 0,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc2_tensor],
        'lr': 0.01,
        'momentum': 0,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc3_tensor],
        'lr': 0.1,
        'momentum': 0,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    }
]

此时 param_groups 是一个长度为3的列表。列表中的每一个字典就是一个 param_group,存储了相应的layer参数和局部优化器参数。

⚠️ 这里使用"局部优化器参数"这一术语并不准确。事实上,"局部"和"全局"的概念仅在优化器实例化之前存在。一旦优化器实例化,全局参数将会逐步写入到每一个 param_group 中,例如 momentumdampening 等未显式定义的参数,实质上属于全局优化器参数,并会通过 add_param_group 方法自动地写入到每一个 param_group 中。因此,在优化器实例化后,每个 param_group 都拥有一套完整且独立的参数配置。

到目前为止,我们可以做一个简单总结。param_groups 是一个元素为字典的列表。当传入的 params 为由Tensor构成的迭代器时,此时 param_groups 的长度为1,即只含有一个 param_group。当传入的 params 为由字典构成的迭代器时,此时 param_groups 的长度为 len(params)

param_groups 中的所有字典的完全相同,均形如:

python 复制代码
param_group = {
    'params': [tensor_1, tensor_2, ...],  # 待优化的模型参数
    **defaults  # 全局优化器参数
}

但所有字典的却不尽相同。

3.1.3 覆盖测试

之前我们只考虑了「仅全局」和「仅局部」的情形,如果我们手动设置全局优化器参数,并且它和某些 param_group 中的局部优化器参数冲突了,那么这个全局的会覆盖掉局部的吗?

python 复制代码
optimizer = optim.SGD([
    {'params': model.fc1.parameters(), 'lr': 0.001, 'momentum': 0.3},
    {'params': model.fc2.parameters(), 'lr': 0.01},
    {'params': model.fc3.parameters(), 'lr': 0.1, 'nesterov': True}
], momentum=0.9, nesterov=False)

print(optimizer.param_groups)

输出:

python 复制代码
[
    {
        'params': [fc1_tensor],
        'lr': 0.001,
        'momentum': 0.3,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc2_tensor],
        'lr': 0.01,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    },
    {
        'params': [fc3_tensor],
        'lr': 0.1,
        'nesterov': True,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None
    }
]

由此可以得出结论:全局优化器参数不会覆盖掉局部优化器参数

3.1.4 逐行讲解

现在我们已经对 paramsdefaultsparam_groups 这三个变量有了足够的了解(state 会放在下文讲解),接下来我们逐行剖析构造方法。

python 复制代码
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
	# 将defaults设置成属性以方便在后续的add_param_group方法中使用
    self.defaults = defaults
	
	# params必须是关于tensor或dict的可迭代对象
    if isinstance(params, torch.Tensor):
        raise TypeError(
            "params argument given to the optimizer should be "
            "an iterable of Tensors or dicts, but got " + torch.typename(params)
        )
	
	# 初始化两大重要属性:state和param_groups
    self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
    self.param_groups: List[Dict[str, Any]] = []
	
	# 将迭代器转化成列表,此时要么是List[Tensor]要么是List[Dict]
    param_groups = list(params)

	# 必须非空
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    
    # 如果不是List[Dict],说明进行的是全局设置,此时param_groups是List[Tensor]
    # 类型,代表模型的所有参数。然后将其转变成List[Dict]类型,以达到格式统一的目的
    if not isinstance(param_groups[0], dict):
        param_groups = [{"params": param_groups}]
	
	# 将每一个param_group经过相关处理后添加到self.param_groups中
	# 这里也会将全局优化器参数defaults注入到每一个param_group中
    for param_group in param_groups:
        self.add_param_group(param_group)

通常来讲,params 会接收 model.parameters() 作为输入,如果不是,那么 params 必须是由字典构成的列表,且每一个字典必须含有 params 这个键,对应的值是模型的部分参数。

绝大多数情况下,我们认为下式成立:

model.parameters() = ⋃ i = 1 k param_groups [ i ] [ " params " ] \text{model.parameters()}=\bigcup_{i=1}^k \text{param\_groups}[i]["\text{params}"] model.parameters()=i=1⋃kparam_groups[i]["params"]

k k k 是 param_group 的个数,且 param_groups[i]["params"] 两两互不相交。

如果涉及到冻结模型的一部分参数,仅训练剩余的参数,那么上式就不再成立了。

3.2 add_param_group

有了3.1小节的基础后,这里直接逐行讲解源码。

python 复制代码
def add_param_group(self, param_group: Dict[str, Any]) -> None:
    r"""Add a param group to the :class:`Optimizer`'s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers can be made
    trainable and added to the :class:`Optimizer` as training progresses.

    Args:
        param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
    """

    # 确保传入的param_group一定是一个字典
    if not isinstance(param_group, dict):
        raise TypeError(f"param_group must be a dict, but got {type(param_group)}")

    # 获取该group中的模型参数部分,然后将其转变为List[Tensor]类型
    params = param_group["params"]
    if isinstance(params, torch.Tensor):
        param_group["params"] = [params]
    elif isinstance(params, set):
        raise TypeError(
            "optimizer parameters need to be organized in ordered collections, but "
            "the ordering of tensors in sets will change between runs. Please use a list instead."
        )
    else:
        param_group["params"] = list(params)

    # 参数检查
    for param in param_group["params"]:
        # 确保所有的参数必须都是Tensor,否则无法优化
        if not isinstance(param, torch.Tensor):
            raise TypeError(
                "optimizer can only optimize Tensors, "
                "but one of the params is " + torch.typename(param)
            )

        # 确保所有的参数都是叶子节点
        if not (param.is_leaf or param.retains_grad):
            raise ValueError("can't optimize a non-leaf Tensor")

    # 将全局优化器参数注入到当前的group中,setdefault保证了这一过程不会覆盖掉局部优化器参数
    for name, default in self.defaults.items():
        param_group.setdefault(name, default)

    # 检查是否存在重复的参数
    # 目前出现重复的参数并不会报错
    params = param_group["params"]
    if len(params) != len(set(params)):
        warnings.warn(
            "optimizer contains a parameter group with duplicate parameters; "
            "in future, this will cause an error; "
            "see github.com/pytorch/pytorch/issues/40967 for more information",
            stacklevel=3,
        )

    # 判断当前的group中是否有参数和已经添加过的group中的参数重复
    # 两个集合交集为空,isdisjoint()返回True
    param_set: Set[torch.Tensor] = set()
    for group in self.param_groups:
        param_set.update(set(group["params"]))

    if not param_set.isdisjoint(set(param_group["params"])):
        raise ValueError("some parameters appear in more than one parameter group")

    # 将当前的group添加到self.param_groups中
    self.param_groups.append(param_group)

add_param_group 源码看似复杂,但归根结底也就那么几行代码是真正起到作用的,这里给出一个简化版本:

python 复制代码
def add_param_group(self, param_group: Dict[str, Any]) -> None:
    params = param_group["params"]
    param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)

    for name, default in self.defaults.items():
        param_group.setdefault(name, default)

    self.param_groups.append(param_group)

3.3 step

Optimizer 类并没有实现 step() 方法:

python 复制代码
def step(self) -> None:
    """Performs a single optimization step (parameter update)."""
    raise NotImplementedError

自定义优化器时,需要继承 Optimizer 类,并实现该方法,否则会报错。

因为 step() 不返回任何值,所以需要实现模型参数的原地更新

3.4 zero_grad

如下是简化版的源码:

python 复制代码
def zero_grad(self, set_to_none: bool = True) -> None:
    r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            This will in general have lower memory footprint, and can modestly improve performance.
            However, it changes certain behaviors. For example:
            1. When the user tries to access a gradient and perform manual ops on it,
            a None attribute or a Tensor full of 0s will behave differently.
            2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
            are guaranteed to be None for params that did not receive a gradient.
            3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
            (in one case it does the step with a gradient of 0 and in the other it skips
            the step altogether).
    """

	# 遍历每一个group,然后遍历group中的每一个参数
    for group in self.param_groups:
        for p in group["params"]:
        	
        	# 如果p的梯度不为None,说明需要清空
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                	# grad_fn 表示该张量是由某个操作生成的
                	# 因此将该张量从计算图中分离出来
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                    	# 不在计算图中,原地关闭梯度以防止后续追踪
                        p.grad.requires_grad_(False)
                    
                    # 清零梯度
                    p.grad.zero_()

从注释可以看出,当梯度被设置为 None 时,梯度张量会被释放,从而减少内存占用。而如果梯度被清零(即将其所有元素设置为 0),梯度张量的内存仍然会被保留。

由于 set_to_none 默认为 True,因此 zero_grad 源码可以进一步简化:

python 复制代码
def zero_grad(self) -> None:
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                p.grad = None

3.5 self.state

Optimizer 中有两大重要属性:stateparam_groups。先前我们已经了解了 param_groups,现在来看 state

我们已经知道构造方法中会通过多次调用 add_param_group 来初始化 self.param_groups,但截止目前,似乎并没有方法能够初始化 self.state,那它是怎么初始化的?以及它到底"长什么样"呢?

self.state 用来存储与模型参数相关的临时状态 。对于SGD with momentum而言,每个参数都需要维护一个动量。对于Adam而言,每个参数不仅需要维护一个动量(一阶矩),还需要维护一个平方梯度(二阶矩)。很显然,对于一个待优化的Tensor,它的临时状态和它的形状是相同的 ,并且对于该Tensor,可能有多个临时状态需要维护,每个临时状态都有一个自己的名字,由此推测 self.state 的类型应当是 Dict[Tensor, Dict[str, Tensor]],这与源码中声明的相同:

python 复制代码
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)

self.state 的键是一个Tensor,值是一个字典。字典存储了优化该Tensor的一些临时状态,键是临时状态的名称,值是相应的状态。

以SGD优化器为例,进行一次单步更新,然后查看它的 state 属性:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.fc2 = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = SimpleMLP()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

inputs = torch.randn(2)
target = torch.randn(1)

output = model(inputs)
loss = criterion(output, target)

loss.backward()
optimizer.step()

print(optimizer.state)

输出:

python 复制代码
defaultdict(<class 'dict'>, {
    Parameter containing:
    tensor([[-0.4128, -0.6015],
            [-0.3628,  0.6162],
            [ 0.4367,  0.3409]], requires_grad=True): {
        'momentum_buffer': tensor([[ 0.0000,  0.0000],
                                   [-0.0052,  0.0581],
                                   [-0.0130,  0.1446]])
    },
    Parameter containing:
    tensor([[0.5394, 0.1077, 0.2748]], requires_grad=True): {
        'momentum_buffer': tensor([[0.0000, 0.3398, 0.1585]])
    }
})

可以得知,对于SGD优化器而言,每个待优化的Tensor只需要维护一个临时状态:momentum_buffer,即当前的动量。且临时状态的形状与待优化的Tensor相同。

将SGD换成Adam,再来看看结果:

python 复制代码
defaultdict(<class 'dict'>, {
    Parameter containing:
    tensor([[ 0.4954, -0.0392],
            [ 0.0778, -0.5769],
            [-0.3332, -0.0659]], requires_grad=True): {
        'step': tensor(1.),
        'exp_avg': tensor([[-0.1632,  0.0252],
                           [-0.0513,  0.0079],
                           [ 0.0000,  0.0000]]),
        'exp_avg_sq': tensor([[2.6648e-03, 6.3460e-05],
                              [2.6274e-04, 6.2572e-06],
                              [0.0000e+00, 0.0000e+00]])
    },
    Parameter containing:
    tensor([[-0.4492, -0.1479,  0.4789]], requires_grad=True): {
        'step': tensor(1.),
        'exp_avg': tensor([[0.1821, 0.0577, 0.0000]]),
        'exp_avg_sq': tensor([[0.0033, 0.0003, 0.0000]])
    }
})

此时每个Tensor需要维护三个临时状态:stepexp_avgexp_avg_sqstep 是当前更新的步数,exp_avg 就是SGD中的动量(不完全相同),exp_avg_sq 是平方梯度。

知道了 self.state 长什么样后,我们需要了解一下它是如何初始化的。

事实上 Optimizer 类并没有实现 self.state 的初始化,因为它的初始化是在 step() 中完成的,所以我们需要关注 Optimizer 的子类,这里以SGD为例。

python 复制代码
class SGD(Optimizer):
    def step(self):
        for group in self.param_groups:
            # 用来存储模型参数,梯度,动量
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            # 填充params、grads、momentum_buffer_list
            self._init_group(group, params, grads, momentum_buffer_list)

            # 执行sgd优化算法
            sgd(params, grads, momentum_buffer_list, ...)

            if group["momentum"] != 0:
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    # 获取Tensor
                    state = self.state[p]
                    # 更新Tensor的临时状态
                    state["momentum_buffer"] = momentum_buffer

    def _init_group(self, group, params, grads, momentum_buffer_list):
        for p in group["params"]:
            if p.grad is not None:
                params.append(p)
                grads.append(p.grad)

                if group["momentum"] != 0:
                    # 因为self.state是defaultdict,所以初始时state会自动创建为一个字典
                    state = self.state[p]
                    # 因为初始时state没有momentum_buffer这个键
                    # 所以momentum_buffer_list的初始值为[None, None, None, ...]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

每一步更新,_init_group 会在SGD算法执行前被调用,它用来获取模型所有待更新的参数,对应的已经计算的梯度,以及对应的上一时刻的动量。self.state 会在 _init_group 中进行初始化 。我们可以通过在 self._init_groupsgd() 这两个语句之间加入以下代码来查看相应的信息:

python 复制代码
print(self.state)
print(momentum_buffer_list)

输出:

python 复制代码
defaultdict(<class 'dict'>, {
    Parameter containing:
    tensor([[ 0.4679, -0.6531],
            [-0.4707, -0.2854],
            [ 0.6846, -0.6576]], requires_grad=True): {},
    
    Parameter containing:
    tensor([[-0.4362,  0.4155,  0.1798]], requires_grad=True): {}
})

[None, None]

这说明初始时,每一个Tensor对应的临时状态为空字典,且 momentum_buffer_list 的初始值为 [None, None, ...]

3.6 state_dict

state_dict 的作用是保存优化器当前的状态,以便在之后的训练中恢复或继续使用。

显然 state_dict 应当保存优化器的两大重要属性

python 复制代码
state_dict = {
    "state": state,
    "param_groups": param_groups,
}

但根据之前的分析,state 中的键会涉及到模型参数,此外,param_groups 中的 params 也会涉及到模型参数,如果就这样直接保存,相当于我们在保存了优化器的同时还保存了两份模型参数(state一份,param_groups一份),这显然是不可行的。

一种直观的想法是将这些模型参数映射成唯一的数字ID,而这些ID是几乎不占空间的。假设 param_groups 含有 k k k 个 param_group,那么我们可以从第一个group开始,从0开始从前往后依次编号直至第 k k k 个group。

当然,我们还需要对 state 进行编号,由于 stateparam_groups 中的参数不一定一一对应(因为 param_groups 可能会含有重复的参数,但 state 不会,所以二者长度不一定相等,具体见 add_param_group 源码),因此我们不能从0开始从前往后依次编号。这启发我们可以构造一个参数的内存地址到ID的映射 mapping,这样在编号 state 的时候,我们就可以通过 mapping[id(tensor)] 来获取模型参数的ID了。

源码解析(已做简化):

python 复制代码
def state_dict(self) -> Dict[str, Any]:
    # 构建模型参数地址到ID的映射
    param_mappings: Dict[int, int] = {}
    start_index = 0

    def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
        nonlocal start_index
        # 处理优化器超参数部分
        packed = {k: v for k, v in group.items() if k != "params"}
        # 更新映射
        param_mappings.update(
            {
                id(p): i
                for i, p in enumerate(group["params"], start_index)
                if id(p) not in param_mappings
            }
        )
        # 处理模型参数部分,将具体的参数映射为ID
        packed["params"] = [param_mappings[id(p)] for p in group["params"]]
        start_index += len(packed["params"])
        return packed

    # 将所有group中的所有模型参数映射为ID
    param_groups = [pack_group(g) for g in self.param_groups]

    # 将state中的所有模型参数映射为ID
    packed_state = {
        param_mappings[id(k)]: v
        for k, v in self.state.items()
    }

    state_dict = {
        "state": packed_state,
        "param_groups": param_groups,
    }

    return state_dict

我们可以通过以下代码来查看 state_dict 的样子:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import itertools

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.fc2 = nn.Linear(3, 3, bias=False)
        self.fc3 = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = MLP()

optimizer = optim.SGD([
    {"params": itertools.chain(model.fc1.parameters(), model.fc2.parameters()), "lr": 0.01},
    {"params": model.fc3.parameters(), "lr": 0.1},
], momentum=0.9)

x = torch.randn(1, 2)
y = torch.randn(1, 1)

criterion = nn.MSELoss()

output = model(x)
loss = criterion(output, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(optimizer.state_dict().keys())

print(optimizer.state_dict()['state'])
print(optimizer.state_dict()['param_groups'])

输出:

python 复制代码
dict_keys(['state', 'param_groups'])

# state部分
{
    0: {
        'momentum_buffer': tensor([[-0.0270,  0.1874],
                                   [ 0.0000,  0.0000],
                                   [ 0.0000,  0.0000]])
    },
    1: {
        'momentum_buffer': tensor([[ 0.0000,  0.0000,  0.0000],
                                   [-0.2232,  0.0000,  0.0000],
                                   [ 0.0000,  0.0000,  0.0000]])
    },
    2: {
        'momentum_buffer': tensor([[ 0.0000, -0.3215,  0.0000]])
    }
}

# param_groups部分
[
    {
        'lr': 0.01,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None,
        'params': [0, 1]
    },
    {
        'lr': 0.1,
        'momentum': 0.9,
        'dampening': 0,
        'weight_decay': 0,
        'nesterov': False,
        'maximize': False,
        'foreach': None,
        'differentiable': False,
        'fused': None,
        'params': [2]
    }
]

可以看到 stateparam_groups 中的模型参数全被映射成了ID。

3.7 load_state_dict

因为之前的ID映射是根据 param_groups 构造的,所以在load的时候,我们也要根据 param_groups 去建立一一对应关系,此时需要构造ID到模型参数的映射。

这里有一个细节,在还原 params 的时候,我们可以直接通过ID进行还原,但是在还原 state 的时候,我们要确保各个临时状态和模型参数是处于同一设备上。

源码解析(已做简化):

python 复制代码
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    # 获取当前优化器的param_groups和之前保存的param_groups
    groups = self.param_groups
    saved_groups = deepcopy(state_dict["param_groups"])

    # 确保group的数量相等
    if len(groups) != len(saved_groups):
        raise ValueError(
            "loaded state dict has a different number of " "parameter groups"
        )
    
    # 确保每个group中的模型参数个数相等
    param_lens = (len(g["params"]) for g in groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group "
            "that doesn't match the size of optimizer's group"
        )

    # 之前是根据enumerate构造的正向映射
    # 现在直接用zip构造反向映射
    id_map = dict(
        zip(
            chain.from_iterable(g["params"] for g in saved_groups),
            chain.from_iterable(g["params"] for g in groups),
        )
    )

    # 用于递归地将Tensor对应的临时状态移动到和Tensor相同的设备上,并确保数据类型相同
    def _cast(param, value, key=None):
        if isinstance(value, torch.Tensor):
            if key == 'step':
                return value
            else:
                if param.is_floating_point():
                    return value.to(dtype=param.dtype, device=param.device)
                else:
                    # 例如模型是一个量化模型,此时不必转化value的数据类型
                    return value.to(device=param.device)

        elif isinstance(value, dict):
            return {
                k: _cast(param, v, key=k)
                for k, v in value.items()
            }
        elif isinstance(value, Iterable):
            return type(value)(_cast(param, v) for v in value)
        else:
            return value

    # 还原state
    # 注意要将临时状态转移到和param相同的设备上,有些时候还需要确保数据类型相同
    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        param = id_map[k]
        state[param] = _cast(
            param, v, param_id=k, param_groups=state_dict["param_groups"]
        )

    # 还原param_groups,只有params需要修改
    def update_group(
        group: Dict[str, Any], new_group: Dict[str, Any]
    ) -> Dict[str, Any]:
        new_group["params"] = group["params"]
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]

    # 更新
    self.__dict__.update({"state": state, "param_groups": param_groups})

4. SGD Optimizer

在了解 Optimizer 源码后,我们开看SGD优化器是如何实现的。

构造函数(简化版):

python 复制代码
class SGD(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov=False,
    ):
        # 将独属于SGD优化器的超参数打包成defaults
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        # 调用父类的构造函数初始化
        super().__init__(params, defaults)

无非就做了两件事:

  1. 将独属于该优化器的超参数打包成 defaults
  2. 调用父类的构造函数进行初始化(不然没有 stateparam_groups 这两个属性)

step_init_group 在3.5节中已经介绍过,这里关注sgd函数如何实现。

⚠️ 对SGD算法不熟悉的读者可以看这篇博客:深入解析SGD、Momentum与Nesterov:优化算法的对比与应用述

python 复制代码
def sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]

        # 权重衰减
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]  # 存储的是每个参数上一时刻的动量

            if buf is None:
                buf = torch.clone(grad).detach()  # 初始时动量就是梯度,因为m_0 = 0
                momentum_buffer_list[i] = buf
            else:
                # buf = momentum * buf + (1 - dampening) * grad
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            if nesterov:
                grad = grad.add(buf, alpha=momentum)
            else:
                grad = buf

        # 更新权重:param = param - lr * grad
        param.add_(grad, alpha=-lr)

5. 极简版Optimizer源码

截止目前,我们可以对 Optimizer 基类的源码进行汇总,给出一个极简版的实现:

python 复制代码
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, Iterable

import torch


class Optimizer:
    def __init__(self, params, defaults: Dict[str, Any]) -> None:
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        params = param_group["params"]
        param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)

        for name, default in self.defaults.items():
            param_group.setdefault(name, default)

        self.param_groups.append(param_group)

    def step(self) -> None:
        raise NotImplementedError

    def zero_grad(self) -> None:
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad = None

    def state_dict(self) -> Dict[str, Any]:
        param_mappings = {}
        start_index = 0

        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        packed_state = {param_mappings[id(k)]: v for k, v in self.state.items()}

        return {"state": packed_state, "param_groups": param_groups}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        groups = self.param_groups
        saved_groups = deepcopy(state_dict["param_groups"])
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def _cast(param, value, key=None):
            if isinstance(value, torch.Tensor):
                if key == 'step':
                    return value
                else:
                    return value.to(dtype=param.dtype, device=param.device) if param.is_floating_point() else value.to(device=param.device)
            elif isinstance(value, dict):
                return {k: _cast(param, v, key=k) for k, v in value.items()}
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v) for v in value)
            else:
                return value

        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            param = id_map[k]
            state[param] = _cast(param, v, param_id=k)

        def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__dict__.update({"state": state, "param_groups": param_groups})

6. 自定义你的Optimizer

以SGD为例,我们来实现一个只有学习率的极简版优化器。

步骤如下:

  • 继承 Optimizer,声明构造函数,构造函数的形参必须含有 params,随后的一系列形参都是和该优化器相关的超参数。
  • 在构造函数中将相关超参数打包成 defaults,然后调用父类的构造函数。
  • 重写 step 方法,并用 @torch.no_grad() 装饰。
  • step 中遍历 self.param_groups,每一次遍历,声明 paramsgrads 列表,然后用 group 的数据进行填充。如果涉及到临时状态,还需要额外声明和临时状态相关的列表。
python 复制代码
class SimpleSGD(Optimizer):
    def __init__(self, params, lr=0.01):
        assert lr > 0.0

        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            params = []
            grads = []

            for p in group['params']:
                if p.grad is not None:
                    params.append(p)
                    grads.append(p.grad)
            
            lr = group['lr']
            for param, grad in zip(params, grads):
                param.add_(grad, alpha=-lr)

仅设置学习率时,它所产生的结果和官方的SGD相同,读者可自行验证。


Ref

[1] https://blog.csdn.net/zzxxxaa1/article/details/121144570?spm=1001.2014.3001.5502

[2] https://www.hjhgjghhg.com/archives/119/

[3] https://pytorch.org/docs/stable/generated/torch.optim.SGD.html

相关推荐
m0_748232921 分钟前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
袁袁袁袁满4 分钟前
100天精通Python(爬虫篇)——第113天:‌爬虫基础模块之urllib详细教程大全
开发语言·爬虫·python·网络爬虫·爬虫实战·urllib·urllib模块教程
szxinmai主板定制专家7 分钟前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室10 分钟前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习21 分钟前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测
老大白菜26 分钟前
Python 爬虫技术指南
python
QQ同步助手36 分钟前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代39 分钟前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc
流浪的小新43 分钟前
【AI】人工智能、LLM学习资源汇总
人工智能·学习
古希腊掌管学习的神2 小时前
[搜广推]王树森推荐系统——矩阵补充&最近邻查找
python·算法·机器学习·矩阵