pytorch中的梯度裁剪

神经网络是通过梯度下降来学习的,在进行反向传播时,进行每一层的梯度计算,假设梯度都是比较大的值,计算到第一层的梯度时,会呈指数级增长,那么更新完的参数值会越来越大,产生梯度爆炸现象。一个比较常见的表现就是损失变成non。

梯度裁剪(Gradient Clipping)是一种防止梯度爆炸或梯度消失的技术,它可以在反向传播过程中对梯度进行缩放或截断,使其保持在一个合理的范围内。梯度裁剪有两种常见的方法:按梯度的绝对值截断或者按梯度的范数进行截断。pytorch给定了相应方法实现,这一步应该在更新参数前进行。

按值截断

c 复制代码
torch.nn.utils.clip_grad_value_(model.parameters(), value)

对一个参数的梯度进行裁剪,使其不超过一个指定的值,它接受两个参数:一个是模型的参数,一个是裁剪的值。它会对每个参数的梯度进行裁剪,使其在 [-value,value]的范围内。这样可以避免梯度过大或过小,影响模型的收敛。例子:

c 复制代码
import torch
import torch.nn as nn
 
# 定义一个简单的线性模型
model = nn.Linear(2, 1)
# 定义一个优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# 定义一个损失函数
criterion = nn.MSELoss()
 
# 生成一些随机的输入和目标
x = torch.randn(4, 2)
y = torch.randn(4, 1)
 
# 前向传播
output = model(x)
# 计算损失
loss = criterion(output, y)
# 反向传播
loss.backward()
 
# 在更新权重之前,对梯度进行裁剪,使其不超过0.5
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
 
# 更新权重
optimizer.step()

按范数截断

c 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), threshold)

对一个参数的梯度进行裁剪,首先计算出梯度的范数,然后将其限制在一个最大值之内。这样可以防止在反向传播过程中梯度过大导致的数值不稳定问题。例子:

c 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
 
# 假设我们有一个简单的全连接网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 1)
 
    def forward(self, x):
        return self.fc(x)
 
# 创建网络、优化器和损失函数
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
 
# 假设我们有一些随机输入数据和目标
data = torch.randn(5, 10)
target = torch.randn(5, 1)
 
# 训练步骤
outputs = model(data)  # 前向传播
loss = loss_fn(outputs, target)  # 计算损失
optimizer.zero_grad()  # 清零梯度
loss.backward()  # 反向传播,计算梯度
 
# 在优化器步骤之前,我们使用梯度裁剪
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
 
optimizer.step()  # 更新模型参数
相关推荐
万邦科技Lafite1 小时前
京东按图搜索京东商品(拍立淘) API (.jd.item_search_img)快速抓取数据
开发语言·前端·数据库·python·电商开放平台·京东开放平台
Giser探索家1 小时前
无人机桥梁巡检:以“空天地”智慧之力守护交通生命线
大数据·人工智能·算法·安全·架构·无人机
不会学习的小白O^O1 小时前
双通道深度学习框架可实现从无人机激光雷达点云中提取橡胶树冠
人工智能·深度学习·无人机
恒点虚拟仿真1 小时前
虚拟仿真实训破局革新:打造无人机飞行专业实践教学新范式
人工智能·无人机·ai教学·虚拟仿真实训·无人机飞行·无人机专业虚拟仿真·无人机飞行虚拟仿真
丁浩6662 小时前
Python机器学习---6.集成学习与随机森林
python·随机森林·机器学习
鲜枣课堂2 小时前
华为最新光通信架构AI-OTN,如何应对AI浪潮?
人工智能·华为·架构
charlie1145141913 小时前
现代 Python 学习笔记:Statements & Syntax
笔记·python·学习·教程·基础·现代python·python3.13
格林威3 小时前
AOI在新能源电池制造领域的应用
人工智能·数码相机·计算机视觉·视觉检测·制造·工业相机
dxnb223 小时前
Datawhale25年10月组队学习:math for AI+Task5解析几何
人工智能·学习
DooTask官方号3 小时前
DooTask 1.3.38 版本更新:MCP 服务器与 AI 工具深度融合,开启任务管理新体验
运维·服务器·人工智能·开源软件·dootask