PyTorch2 Python深度学习 - 自动微分(Autograd)与梯度优化

锋哥原创的PyTorch2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1eqxNzXEYc

课程介绍

​基于前面的机器学习Scikit-learn,深度学习Tensorflow2课程,我们继续讲解深度学习PyTorch2,所以有些机器学习,深度学习基本概念就不再重复讲解,大家务必学习好前面两个课程。本课程主要讲解基于PyTorch2的深度学习核心知识,主要讲解包括PyTorch2框架入门知识,环境搭建,张量,自动微分,数据加载与预处理,模型训练与优化,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

PyTorch2 Python深度学习 - 自动微分(Autograd)与梯度优化

在PyTorch2中, 自动微分(Autograd)机制, 是 PyTorch 的核心功能之一,用于自动计算张量的导数(梯度)。

它的主要用途是:在神经网络反向传播过程中自动计算参数的梯度

在 PyTorch 中,只要一个张量的属性 requires_grad=True,系统就会跟踪它的所有运算,从而可以在反向传播时自动求出梯度。

基本原理

  • 计算图(Computational Graph): PyTorch 会动态构建一张有向无环图(DAG),图的节点是张量,边是函数(如加法、乘法等)。 反向传播时,PyTorch 会沿着这张图从输出向输入依次计算梯度。

  • 反向传播(Backpropagation) : 调用 loss.backward() 时,PyTorch 会自动计算所有参与计算的 requires_grad=True 张量的梯度。

  • 梯度存储 : 计算出的梯度会存放在每个张量的 .grad 属性中。

简单示例

复制代码
import torch

# 创建一个张量并启用自动求导
x = torch.tensor(3.0, requires_grad=True)

# 构建一个函数 y = x^2
y = x ** 2

# 自动求导(反向传播)
y.backward()

# 查看梯度 dy/dx
print(x.grad)  # 输出:tensor(6.)
print(x.grad.item())

运行输出:

复制代码
tensor(6.)
6.0

神经网络训练中使用 Autograd

复制代码
import torch
from torch import nn, optim

# 1,构造训练数据:y=2x+1
x = torch.linspace(-5, 5, 100).unsqueeze(1)  # 100的样本,维度[100,1]
print(x, x.shape)
y = 2 * x + 1 + torch.randn(x.size())  # 添加噪声

# 2,定义简单的线性模型
model = nn.Linear(1, 1)

# 3, 定义损失函数与优化器
criterion = nn.MSELoss()  # 均方误差
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 4,训练模型
epochs = 2000
for epoch in range(epochs):
    y_pred = model(x)  # 前向传播
    loss = criterion(y_pred, y)  # 计算损失
    optimizer.zero_grad()  # 清空梯度
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数

    print(f'epoch: {epoch}, loss: {loss.item()}')

# 5,查看结果
[w, b] = model.parameters()
print(f'训练结果:w: {w}, b: {b}')

流程说明:

  1. forward() 前向传播,构建计算图

  2. loss.backward() 反向传播,自动求出参数梯度

  3. optimizer.step() 更新模型参数

相关推荐
冷雨夜中漫步2 小时前
Python快速入门(6)——for/if/while语句
开发语言·经验分享·笔记·python
郝学胜-神的一滴2 小时前
深入解析Python字典的继承关系:从abc模块看设计之美
网络·数据结构·python·程序人生
百锦再2 小时前
Reactive编程入门:Project Reactor 深度指南
前端·javascript·python·react.js·django·前端框架·reactjs
yLDeveloper4 小时前
从模型评估、梯度难题到科学初始化:一步步解析深度学习的训练问题
深度学习
m0_736919104 小时前
C++代码风格检查工具
开发语言·c++·算法
喵手4 小时前
Python爬虫实战:旅游数据采集实战 - 携程&去哪儿酒店机票价格监控完整方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集结果csv导出·旅游数据采集·携程/去哪儿酒店机票价格监控
Coder_Boy_4 小时前
技术让开发更轻松的底层矛盾
java·大数据·数据库·人工智能·深度学习
2501_944934734 小时前
高职大数据技术专业,CDA和Python认证优先考哪个?
大数据·开发语言·python
helloworldandy4 小时前
使用Pandas进行数据分析:从数据清洗到可视化
jvm·数据库·python
2401_836235864 小时前
中安未来SDK15:以AI之眼,解锁企业档案的数字化基因
人工智能·科技·深度学习·ocr·生活