pytorch小记(五):pytorch中的求导操作:backward()

pytorch小记(五):pytorch中的求导操作:backward()


完整代码

python 复制代码
x = torch.tensor([[2., -1.], [1., 1.]], requires_grad=True)
print(x)
out = x.pow(2).sum()
print(out)
out.backward()
print(x.grad)
python 复制代码
>>>
tensor([[ 2., -1.],
        [ 1.,  1.]], requires_grad=True)
tensor(9., grad_fn=<SumBackward0>)
tensor([[12.,  3.],
        [ 3.,  3.]])

让我们逐行详细分析代码的执行逻辑及其背后的原理。


代码 1

python 复制代码
x = torch.tensor([[2., -1.], [1., 1.]], requires_grad=True)
print(x)

解释

  1. torch.tensor

    • 创建一个二维张量 x,内容为 [[2., -1.], [1., 1.]]
    • 数据类型为浮点数(float32),因为传入的是小数。
  2. requires_grad=True

    • 启用自动求导功能。
    • 任何以 x 为输入的张量计算图都会记录下来,用于后续的梯度计算。
  3. 结果

    打印张量 x

    tensor([[ 2., -1.],
            [ 1.,  1.]], requires_grad=True)
    

代码 2

python 复制代码
out = x.pow(2).sum()
print(out)

解释

  1. x.pow(2)

    • 计算 x 的每个元素的平方。
    • 张量的每个元素被逐元素计算: x i j 2 x_{ij}^2 xij2。

    计算结果:

    tensor([[4., 1.],
            [1., 1.]], requires_grad=True)
    
  2. .sum()

    • 对张量中所有元素求和,计算结果为:
      4 + 1 + 1 + 1 = 7 4 + 1 + 1 + 1 = 7 4+1+1+1=7
    • requires_grad=True 表示 out 依赖于 x,可以反向传播。
  3. 结果

    打印 out

    tensor(7., grad_fn=<SumBackward0>)
    
    • grad_fn=<SumBackward0> 表示这是通过 sum() 操作计算得出的,反向传播时会回溯到这个操作。

代码 3

python 复制代码
out.backward()

解释

  1. out.backward()

    • out 调用 .backward() 会计算 x 的梯度。
    • 梯度的计算目标 :计算 outx 的偏导数,即 ∂ out ∂ x \frac{\partial \text{out}}{\partial x} ∂x∂out。
  2. 计算过程

    • out = x.pow(2).sum(),展开公式为:
      out = x 11 2 + x 12 2 + x 21 2 + x 22 2 \text{out} = x_{11}^2 + x_{12}^2 + x_{21}^2 + x_{22}^2 out=x112+x122+x212+x222

    • x_{ij} 求偏导:
      ∂ out ∂ x i j = 2 ⋅ x i j \frac{\partial \text{out}}{\partial x_{ij}} = 2 \cdot x_{ij} ∂xij∂out=2⋅xij

    • 每个元素的梯度计算结果:

      [[2 * 2., 2 * -1.],
       [2 * 1., 2 * 1.]]
      
    • 结果为:

      [[ 4., -2.],
       [ 2.,  2.]]
      
  3. 梯度存储

    • 计算出的梯度会存储在 x.grad 中。

代码 4

python 复制代码
print(x.grad)

解释

  1. 打印 x 的梯度,即 x.grad
  2. x.grad 包含了之前通过 .backward() 计算的梯度值。

结果

tensor([[ 4., -2.],
        [ 2.,  2.]])

代码的整体逻辑总结

  1. 初始化张量

    python 复制代码
    x = torch.tensor([[2., -1.], [1., 1.]], requires_grad=True)

    创建一个二维张量,并启用梯度跟踪。

  2. 前向计算

    python 复制代码
    out = x.pow(2).sum()

    计算 x 的平方和,得到标量 out

  3. 反向传播

    python 复制代码
    out.backward()

    自动计算 outx 的梯度,结果存储在 x.grad 中。

  4. 打印梯度

    python 复制代码
    print(x.grad)

    查看 x 的梯度,其值为:

    [[ 4., -2.],
     [ 2.,  2.]]
    

    这表明 outx 的变化率是该梯度值。


补充:requires_gradbackward 的作用

  • requires_grad=True
    • 启用张量的自动求导功能,参与计算的操作会记录到计算图中。
  • backward()
    • 执行反向传播,沿着计算图回溯并计算梯度。
    • 梯度存储在 x.grad 中,用于优化或调试。

疑问

在 PyTorch 中,out.backward() 是用来计算张量的梯度的操作。然而,当调用 backward() 时,输出张量必须是标量(即只有一个数值)。如果输出张量不是标量,就会引发以下错误:

RuntimeError: grad can be implicitly created only for scalar outputs

原因分析

  1. 标量输出的要求

    • backward() 会计算输入张量(如 x)对输出张量(如 out)的梯度。
    • 如果输出张量是标量(0 维张量),梯度是输入张量的每个元素对这个标量的偏导数,梯度张量的形状与输入张量一致。

    例子:

    python 复制代码
    out = x.pow(2).sum()
    out.backward()
    • 这里 out 是一个标量,因此可以计算梯度。
  2. 非标量输出的情况

    • 如果输出张量是多维的,例如 x.pow(2) 的结果:

      python 复制代码
      out = x.pow(2)

      out 的形状是 (2, 2)

      tensor([[4., 1.],
              [1., 1.]], requires_grad=True)
      
    • 这种情况下,backward() 不知道如何将梯度传播到输入张量 x,因为输出张量的每个元素都可能有不同的梯度。


如何解决

如果想对多维张量调用 backward(),需要将其 转换为标量,例如通过求和或取均值。

方法 1:通过 .sum() 转为标量

python 复制代码
out = x.pow(2).sum()  # 转为标量
out.backward()

方法 2:通过 .mean() 转为标量

python 复制代码
out = x.pow(2).mean()  # 转为标量
out.backward()

高级用法:为非标量指定梯度权重

如果你需要对非标量张量调用 backward(),可以通过 out.backward(gradient=...) 指定梯度权重,告诉 PyTorch 如何将梯度聚合到输入张量。

示例

python 复制代码
out = x.pow(2)  # 非标量张量,形状为 (2, 2)
gradient = torch.ones_like(out)  # 权重张量,形状必须与 out 相同
out.backward(gradient=gradient)

解释

  • gradient=... 指定了 out 每个元素在反向传播中的权重。
  • 例如,对于 o u t = x 2 out = x^2 out=x2,若 gradient=1,梯度仍然是 ∂ o u t ∂ x = 2 x \frac{\partial out}{\partial x} = 2x ∂x∂out=2x。

总结

  • backward() 要求输出必须是标量,否则会报错。
  • 可以通过 .sum().mean() 将多维张量转换为标量。
  • 对于特殊情况,可以使用 backward(gradient=...) 指定非标量输出的梯度权重。
相关推荐
带娃的IT创业者25 分钟前
《Python实战进阶》专栏 No.3:Django 项目结构解析与入门DEMO
数据库·python·django
AL.千灯学长41 分钟前
DeepSeek接入Siri(已升级支持苹果手表)完整版硅基流动DeepSeek-R1部署
人工智能·gpt·ios·ai·苹果vision pro
HealthScience1 小时前
【异常错误】pycharm debug view变量的时候显示不全,中间会以...显示
ide·python·pycharm
LCG元1 小时前
大模型驱动的围术期质控系统全面解析与应用探索
人工智能
lihuayong1 小时前
计算机视觉:主流数据集整理
人工智能·计算机视觉·mnist数据集·coco数据集·图像数据集·cifar-10数据集·imagenet数据集
政安晨2 小时前
政安晨【零基础玩转各类开源AI项目】DeepSeek 多模态大模型Janus-Pro-7B,本地部署!支持图像识别和图像生成
人工智能·大模型·多模态·deepseek·janus-pro-7b
一ge科研小菜鸡2 小时前
DeepSeek 与后端开发:AI 赋能云端架构与智能化服务
人工智能·云原生
冰 河2 小时前
‌最新版DeepSeek保姆级安装教程:本地部署+避坑指南
人工智能·程序员·openai·deepseek·冰河大模型
维维180-3121-14552 小时前
AI赋能生态学暨“ChatGPT+”多技术融合在生态系统服务中的实践技术应用与论文撰写
人工智能·chatgpt
豌豆花下猫2 小时前
Python 潮流周刊#90:uv 一周岁了,优缺点分析(摘要)
后端·python·ai