Pytorch:optimizer.zero_grad(), loss.backward(), optimizer.step()

在训练过程中常看到如下代码:

python 复制代码
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()

这三个函数的作用是:

  • 在训练过程中先调用 optimizer.zero_grad() 清空梯度
  • 再调用 loss.backward() 反向传播
  • 最后调用optimizer.step()更新模型参数:

optimizer.zero_grad():

Optimizer类在实例化时会在构造函数中创建一个param_groups列表

param_group'params':由传入的模型参数组成的列表,即实例化Optimizer类时传入该group的参数,如果参数没有分组,则为整个模型的参数model.parameters(),每个参数是一个torch.nn.parameter.Parameter对象。

optimizer.zero_grad()函数会遍历模型的所有参数,通过:

  • p.grad.detach_()方法截断反向传播的梯度流;
  • 再通过p.grad.zero_()函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。

具体来说,optimizer.zero_grad() 会将优化器中所有可学习参数的梯度设为 0。这样,在下一次前向传递计算和反向传播计算时,之前的梯度就不会对当前的梯度产生影响。

在反向传播计算梯度之前对上一次迭代时记录的梯度清零,参数 set_to_none 设置为 True 时会直接将参数梯度设置为 None,从而减小内存使用,但通常情况下不建议设置这个参数,因为梯度设置为 None 和 0 在 PyTorch 中处理逻辑会不一样。

python 复制代码
def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Arguments:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This is will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)
                        p.grad.zero_()

因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。

更多详细内容,参考:Pytorch:torch.optim模块

loss.backward():

PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。

具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。

具体参考:Pytorch:backward()函数详解

optimizer.step()

optimizer.step():此方法主要完成一次模型参数的更新。

当使用 backward() 计算网络参数的梯度后,需要使用 optimizer.step() 来根据梯度更新网络参数的值。

相关推荐
邵宇然4 分钟前
轻量级推理引擎开发:从模型加载到推理执行的 Rust 实战
人工智能
装不满的克莱因瓶5 分钟前
掌握语义分割经典模型 FCN——从像素分类到端到端分割的奠基之作
人工智能·python·深度学习·算法·机器学习·分类·数据挖掘
ACP广源盛139246256735 分钟前
GSV5600@ACP#多接口协议转换芯片,物理 AI 便携终端的互联核心
大数据·人工智能·分布式·嵌入式硬件·spark
لا معنى له6 分钟前
NeoVerse: Enhancing 4D World Model with in-the-wild Monocular Videos
人工智能·笔记·机器学习·语言模型
147API7 分钟前
Fable 5访问暂停后,模型接入层不能再只写死一个模型名
大数据·人工智能·api·claude
KaMeidebaby9 分钟前
卡梅德生物技术快报 | 噬菌体展示 12 肽文库在蛋白表位定位中的应用与实验数据
大数据·人工智能·架构·spark·新浪微博
JIAXIN_culture14 分钟前
甘肃景观工程定制服务FAQ:企业如何选对合作方?
大数据·人工智能
青绿蓝LCA低碳研究院15 分钟前
环保的本质:从“末端修补”到“系统重构”的生存范式转移 - 蓝色星球
大数据·人工智能·经验分享·重构
xwz小王子16 分钟前
ICRA 2026深度观察:全栈闭环成标配,中国具身智能势力显著崛起
大数据·人工智能·算法
逻辑探险家17 分钟前
2026 中国 GEO 服务商综合实力评测
大数据·人工智能·产品运营