Pytorch: loss.backward()背后的原理

文章目录

  • [0. 前言](#0. 前言)
  • [1. 概述](#1. 概述)
  • [2. 张量的 `requires_grad` 属性](#2. 张量的 requires_grad 属性)
  • [3. 前向传播(Forward Pass)](#3. 前向传播(Forward Pass))
  • [4. 反向传播(Backward Pass)](#4. 反向传播(Backward Pass))
  • [5. 示例代码](#5. 示例代码)

0. 前言

loss.backward() 是 PyTorch 中用于计算梯度的函数。它在训练神经网络时发挥着关键作用。理解 loss.backward() 的内部原理有助于深入了解 PyTorch 的自动微分机制。

1. 概述

在PyTorch中,loss.backward() 是反向传播的核心函数,它负责计算模型参数相对于损失函数的梯度。这个过程是基于自动微分(Automatic Differentiation,简称autodiff)技术实现的,具体来说是采用了反向模式的自动微分(Reverse-Mode AD)。下面是对这一过程内部原理的简要说明:

自动微分原理:

(1) 前向传播 :首先,模型通过前向传播计算输出值。在这个过程中,PyTorch 会记录计算图(Computation Graph),这个计算图记录了从输入到输出的每一步运算及其依赖关系。每个张量(Tensor)都有一个.grad_fn属性,指向一个函数,这个函数描述了如何计算这个张量关于其输入的梯度。

(2) 反向传播 :当调用 loss.backward() 时,PyTorch 开始反向遍历计算图。这个过程从损失函数开始,沿着图反向传播误差,计算每一个参与运算的张量关于损失的梯度。这是通过链式法则(Chain Rule)完成的,即将损失对某个中间变量的导数分解为其后续操作导数的乘积。

(3)梯度计算 :在反向传播过程中,每个运算都会计算其输出关于输入的梯度,并将这个梯度累积到输入张量的.grad属性中(如果是标量损失,它没有.grad属性)。这意味着如果一个张量被多个路径使用,它的.grad属性会累积从所有路径来的梯度。

(4) 梯度累加与同步 :在分布式训练中,如果启用了梯度同步(例如使用DataParallelDistributedDataParallel),PyTorch还会在所有设备之间同步计算出的梯度,确保每个参数的梯度是所有设备上相应梯度的平均值。

(5)梯度裁剪与优化 :在反向传播完成后,用户通常会执行梯度裁剪以避免梯度爆炸问题,随后使用优化器(如SGD, Adam等)来更新模型参数,即执行optimizer.step()。这一步实际上根据计算出的梯度和优化算法更新参数。

2. 张量的 requires_grad 属性

要使张量参与梯度计算,需要将其 requires_grad 属性设置为 True

python 复制代码
import torch

# 创建一个张量,并设置 requires_grad=True
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

3. 前向传播(Forward Pass)

在前向传播过程中,PyTorch 记录所有操作以构建计算图。例如:

python 复制代码
y = x * 2
z = y.mean()

这里,yz 都是通过对 x 的操作得到的张量,计算图会记录这些操作。

4. 反向传播(Backward Pass)

当我们调用 loss.backward() 时,PyTorch 会从计算图中的输出节点开始,沿着图的边缘向后遍历,并计算梯度。这一过程包括以下步骤:

(1)计算梯度

PyTorch 会计算每个张量相对于最终标量输出(如损失)的梯度。

(2)链式法则(Chain Rule)

通过链式法则,PyTorch 会将局部梯度乘积从输出层向输入层传播。

python 复制代码
# 计算损失的梯度
z.backward()

在这个例子中,z 是一个标量。z.backward() 会计算 z 相对于 x 的梯度,并将这些梯度存储在 x.grad 中。

5. 示例代码

以下是一个完整的示例代码,展示了 loss.backward() 的整个过程:

python 复制代码
import torch

# 创建一个张量,并设置 requires_grad=True
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# 前向传播
y = x * 2
z = y.mean()

# 反向传播
z.backward()

# 输出梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])

上述代码中:
z ⃗ = 2 3 ∗ x ⃗ = 2 3 [ x 0 , x 1 , x 2 ] \vec z =\frac {2}{3}*\vec{x}=\frac {2}{3}[x_0,x_1,x_2] z =32∗x =32[x0,x1,x2]

梯度计算如下:
∇ x Z = [ 2 3 , 2 3 , 2 3 ] \nabla_xZ = [\frac {2}{3}, \frac {2}{3}, \frac {2}{3}] ∇xZ=[32,32,32]


欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL;

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

相关推荐
机器之心4 分钟前
100美元、8000行代码手搓ChatGPT,Karpathy最新开源项目爆火,一夜近5k star
人工智能·openai
RTC老炮5 分钟前
webrtc弱网-BitrateEstimator类源码分析与算法原理
网络·人工智能·算法·机器学习·webrtc
星期天要睡觉9 分钟前
计算机视觉(opencv)——基于 MediaPipe 的手势识别系统
人工智能·opencv·计算机视觉
三年呀14 分钟前
指纹技术深度剖析:从原理到实践的全方位探索
图像处理·人工智能·计算机视觉·指纹识别·生物识别技术·安全算法
学习的周周啊1 小时前
一人AI自动化开发体系(Cursor 驱动):从需求到上线的全流程闭环与实战清单
运维·人工智能·自动化·ai编程·全栈·devops·cursor
后端小肥肠1 小时前
明星漫画总画不像?用 Coze +即梦 4 工作流,素描风漫画3分钟搞定,小白也能上手
人工智能·aigc·coze
flay2 小时前
5个Claude实战项目从0到1:自动化、客服机器人、代码审查
人工智能
flay2 小时前
Claude API完全指南:从入门到实战
人工智能
用户5191495848452 小时前
OAuth/OpenID Connect安全测试全指南
人工智能·aigc
初级炼丹师(爱说实话版)2 小时前
PGLRNet论文笔记
人工智能·深度学习·计算机视觉