Tensor自动微分

1. 自动微分基础概念

自动微分(Automatic Differentiation,简称AD)是现代深度学习框架的核心技术之一,它能够自动计算函数的导数。与符号微分和数值微分不同,自动微分结合了二者的优点:

  • 符号微分:精确但表达式复杂时会出现"表达式膨胀"

  • 数值微分:简单但存在截断误差和舍入误差

  • 自动微分:精确且高效,通过计算图实现链式法则

在PyTorch中,自动微分功能主要由autograd模块提供。每个Tensor都有一个requires_grad属性,当设置为True时,PyTorch会跟踪所有对该Tensor的操作。

python 复制代码
import torch

# 创建需要计算梯度的Tensor
x = torch.tensor([2.0], requires_grad=True)
print(x.requires_grad)  # 输出: True

2. 计算梯度

2.1 标量梯度计算

标量梯度计算是最基础的情况,即对标量函数关于标量或向量变量的求导。

python 复制代码
# 示例:计算y = x^2在x=2处的导数
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2

# 反向传播计算梯度
y.backward()

# 查看梯度
print(x.grad)  # 输出: tensor([4.]) 因为dy/dx=2x,当x=2时为4

2.2 向量梯度计算

当函数的输出是标量而输入是向量时,计算的是梯度向量(一阶导数)。

python 复制代码
# 示例:计算L = sum(x^2)关于向量x的梯度
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
L = torch.sum(x ** 2)

L.backward()

print(x.grad)  # 输出: tensor([2., 4., 6.]) 因为dL/dx_i = 2x_i

2.3 多标量梯度计算

当有多个标量输出时,需要为每个输出指定梯度权重。

python 复制代码
# 示例:计算y1 = x^2, y2 = x^3在x=2处的导数
x = torch.tensor([2.0], requires_grad=True)
y1 = x ** 2
y2 = x ** 3

# 创建梯度权重(与y1,y2形状相同)
grad_tensors = torch.tensor([1.0, 1.0])

# 反向传播
torch.autograd.backward([y1, y2], grad_tensors=grad_tensors)

print(x.grad)  # 输出: tensor([16.]) 因为dy1/dx=4, dy2/dx=12, 总和为16

2.4 多向量梯度计算

python 复制代码
import torch


def test01():
    # 创建两个张量,并设置 requires_grad=True
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)

    # 前向传播:计算 z = x * y
    z = x * y

    # 前向传播:计算 loss = z.sum()
    loss = z.sum()

    # 查看前向传播的结果
    print("z:", z)  # 输出: tensor([ 4., 10., 18.], grad_fn=<MulBackward0>)
    print("loss:", loss)  # 输出: tensor(32., grad_fn=<SumBackward0>)

    # 反向传播:计算梯度
    loss.backward()

    # 查看梯度
    print("x.grad:", x.grad)  # 输出: tensor([4., 5., 6.])
    print("y.grad:", y.grad)  # 输出: tensor([1., 2., 3.])


if __name__ == "__main__":
    test01()

3. 梯度上下文控制

3.1 控制梯度计算

PyTorch提供了几种控制梯度计算的上下文管理器:

  • torch.no_grad():禁用梯度计算,减少内存消耗

  • torch.enable_grad():启用梯度计算

  • torch.set_grad_enabled():根据布尔值设置梯度计算状态

python 复制代码
x = torch.tensor([1.0], requires_grad=True)

# 在no_grad上下文中不计算梯度
with torch.no_grad():
    y = x * 2
    print(y.requires_grad)  # 输出: False

# 临时启用梯度计算
with torch.enable_grad():
    z = x * 3
    print(z.requires_grad)  # 输出: True

3.2 累计梯度

在PyTorch中,梯度是累加的,这意味着每次调用.backward()时,梯度会加到原来的梯度上,而不是替换。

python 复制代码
x = torch.tensor([1.0], requires_grad=True)

for _ in range(3):
    y = x ** 2
    y.backward()
    print(x.grad)  # 第一次: tensor([2.]), 第二次: tensor([4.]), 第三次: tensor([6.])

3.3 梯度清零

由于梯度是累加的,在训练神经网络时,通常需要在每次参数更新前将梯度清零。

python 复制代码
# 创建参数
w = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([0.5], requires_grad=True)

epochs = 100

# 模拟训练循环
for epoch in range(epochs):
    # 前向传播
    y_pred = w * x + b
    loss = (y_pred - y_true)**2
    
    # 反向传播
    loss.backward()
    
    # 更新参数 (模拟)
    with torch.no_grad():
        w -= 0.01 * w.grad
        b -= 0.01 * b.grad
    
    # 梯度清零
    w.grad.zero_()
    b.grad.zero_()

3.4 案例1-求函数最小值

使用梯度下降法求函数最小值是自动微分的典型应用。

python 复制代码
import torch
from torch import optim
import numpy as np
import matplotlib.pyplot as plt

# 定义函数 f(x) = x^2 + 5sin(x)
def f(x):
    return x**2 + 5*np.sin(x)

# 使用PyTorch自动微分求最小值
x = torch.tensor([3.0], requires_grad=True)  # 初始值
optimizer = optim.SGD([x], lr=0.1)

loss_history = []
x_history = []

for i in range(100):
    optimizer.zero_grad()
    y = x**2 + 5*torch.sin(x)
    y.backward()
    optimizer.step()
    
    loss_history.append(y.item())
    x_history.append(x.item())

# 绘制结果
plt.plot(loss_history)
plt.title("Loss during optimization")
plt.xlabel("Iteration")
plt.ylabel("f(x)")
plt.show()

print(f"Minimum at x = {x.item()}, f(x) = {f(x.item())}")

3.5 案例2-函数参数求解

使用自动微分可以求解函数的参数,例如拟合线性回归模型。

python 复制代码
# 生成模拟数据
torch.manual_seed(42)
X = torch.rand(100, 1) * 10
true_w = 2.0
true_b = 3.0
y = true_w * X + true_b + torch.randn(100, 1)  # 添加噪声

# 定义模型参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

# 定义优化器
optimizer = torch.optim.SGD([w, b], lr=0.01)

# 训练循环
for epoch in range(1000):
    # 前向传播
    y_pred = w * X + b
    
    # 计算损失
    loss = torch.mean((y_pred - y)**2)
    
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    
    # 更新参数
    optimizer.step()
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss {loss.item()}")

print(f"Estimated parameters: w = {w.item()}, b = {b.item()}")
print(f"True parameters: w = {true_w}, b = {true_b}")

4. 高级主题

4.1 自定义自动微分函数

PyTorch允许通过继承torch.autograd.Function来定义自定义的自动微分函数。

python 复制代码
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# 使用自定义函数
x = torch.randn(5, requires_grad=True)
y = MyReLU.apply(x)
y.backward(torch.ones_like(y))
print(x.grad)

4.2 高阶导数

PyTorch支持高阶导数计算,通过在backward()时设置create_graph=True

python 复制代码
x = torch.tensor([2.0], requires_grad=True)
y = x ** 3

# 计算一阶导数
first_deriv = torch.autograd.grad(y, x, create_graph=True)[0]
print(first_deriv)  # 3*x^2 = 12

# 计算二阶导数
second_deriv = torch.autograd.grad(first_deriv, x)[0]
print(second_deriv)  # 6*x = 12

5. 总结

本文详细介绍了Tensor自动微分的基础概念和各种应用场景,包括:

  1. 不同情况下的梯度计算(标量、向量、多输出等)

  2. 梯度上下文的控制方法

  3. 实际应用案例(函数优化、参数求解)

  4. 高级主题(自定义函数、高阶导数)

自动微分是深度学习框架的核心功能,掌握其原理和使用方法对于理解和实现各种机器学习算法至关重要。PyTorch的自动微分系统设计灵活且高效,能够满足从研究到生产的各种需求。

相关推荐
我爱一条柴ya18 分钟前
【AI大模型】神经网络反向传播:核心原理与完整实现
人工智能·深度学习·神经网络·ai·ai编程
万米商云23 分钟前
企业物资集采平台解决方案:跨地域、多仓库、百部门——大型企业如何用一套系统管好百万级物资?
大数据·运维·人工智能
新加坡内哥谈技术26 分钟前
Google AI 刚刚开源 MCP 数据库工具箱,让 AI 代理安全高效地查询数据库
人工智能
慕婉030728 分钟前
深度学习概述
人工智能·深度学习
大模型真好玩29 分钟前
准确率飙升!GraphRAG如何利用知识图谱提升RAG答案质量(额外篇)——大规模文本数据下GraphRAG实战
人工智能·python·mcp
198930 分钟前
【零基础学AI】第30讲:生成对抗网络(GAN)实战 - 手写数字生成
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·近邻算法
6confim30 分钟前
AI原生软件工程师
人工智能·ai编程·cursor
阿里云大数据AI技术30 分钟前
Flink Forward Asia 2025 主旨演讲精彩回顾
大数据·人工智能·flink
i小溪31 分钟前
在使用 Docker 时,如果容器挂载的数据目录(如 `/var/moments`)位于数据盘,只要服务没有读写,数据盘是否就不会被唤醒?
人工智能·docker
程序员NEO34 分钟前
Spring AI 对话记忆大揭秘:服务器重启,聊天记录不再丢失!
人工智能·后端