Pytorch:torch.nn.utils.clip_grad_norm_梯度截断_解读

torch.nn.utils.clip_grad_norm_函数主要作用:

神经网络深度逐渐增加,网络参数量增多的时候,容易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁torch.nn.utils.clip_grad_norm_(),即设置一个梯度大小的上限

注:旧版为torch.nn.utils.clip_grad_norm()

函数参数:

官网链接:https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html

torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None)

"Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place."

"对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。"

Parameters

  • parameters (Iterable[Tensor] or Tensor) -- 实施梯度裁剪的可迭代网络参数

    an iterable of Tensors or a single Tensor that will have gradients normalized(一个由张量或单个张量组成的可迭代对象(模型参数),将梯度归一化)

  • max_norm (float) -- 该组网络参数梯度的范数上限

    max norm of the gradients(梯度的最大值)

  • norm_type (float) --范数类型

    type of the used p-norm. Can be 'inf' for infinity norm.(所使用的范数类型。默认为L2范数,可以是无穷大范数('inf'))

  • error_if_nonfinite (bool) --

    if True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf. Default: False (will switch to True in the future)

  • foreach (bool) --

    use the faster foreach-based implementation. If None, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types. Default: None

源码解读:

参考:https://blog.csdn.net/Mikeyboi/article/details/119522689

(建议大家看看源码,更好理解函数意义,有注释)

python 复制代码
def clip_grad_norm_(parameters, max_norm, norm_type=2):
	# 处理传入的三个参数。
	# 首先将parameters中的非空网络参数存入一个列表,
	# 然后将max_norm和norm_type类型强制为浮点数。
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    
	#对无穷范数进行了单独计算,即取所有网络参数梯度范数中的最大值,定义为total_norm
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)

	# 对于其他范数,计算所有网络参数梯度范数之和,再归一化,
	# 即等价于把所有网络参数放入一个向量,再对向量计算范数。将结果定义为total_norm
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type # norm_type=2 求平方(二范数)
        total_norm = total_norm ** (1. / norm_type) # norm_type=2 等价于 开根号
        
    # 最后定义了一个"裁剪系数"变量clip_coef,为传入参数max_norm和total_norm的比值(+1e-6防止分母为0的情况)。
    # 如果max_norm > total_norm,即没有溢出预设上限,则不对梯度进行修改。
    # 反之则以clip_coef为系数对全部梯度进行惩罚,使最后的全部梯度范数归一化至max_norm的值。
    # 注意该方法返回了一个 total_norm,实际应用时可以通过该方法得到网络参数梯度的范数,以便确定合理的max_norm值。
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm

使用方法及分析:

应用逻辑为:

  1. 先计算梯度;
  2. 裁剪梯度(在函数内部会判断是否需要裁剪,具体看源码解读);
  3. 最后更新网络参数。

因此 torch.nn.utils.clip_grad_norm_() 的使用应该在loss.backward() 之后,optimizer.step() 之前,

在U-Net中如下:

python 复制代码
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
grad_scaler.step(optimizer)
grad_scaler.update()

参考:https://blog.csdn.net/zhaohongfei_358/article/details/122820992

注意:

  • 从上面文章可以看到,clip_grad_norm 最后就是对所有的梯度乘以一个 clip_coefp.grad.data.mul_(clip_coef) ),而且乘的前提是clip_coef 一定是小于1的,所以,clip_grad_norm 只解决梯度爆炸问题,不解决梯度消失问题
  • clip_coef 的定义**clip_coef = max_norm / (total_norm + 1e-6)** 可以知道:max_norm越大,对于梯度爆炸的解决越柔和,max_norm越小,对梯度爆炸的解决越狠
相关推荐
喵手5 分钟前
Python爬虫实战:构建招聘会数据采集系统 - requests+lxml 实战企业名单爬取与智能分析!
爬虫·python·爬虫实战·requests·lxml·零基础python爬虫教学·招聘会数据采集
星幻元宇VR10 分钟前
5D动感影院,科技与沉浸式体验的完美融合
人工智能·科技·虚拟现实
WZGL123014 分钟前
“十五五”发展展望:以社区为底座构建智慧康养服务
大数据·人工智能·物联网
阿杰学AI22 分钟前
AI核心知识86——大语言模型之 Superalignment(简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·超级对齐·superalignment·#ai安全
CV@CV26 分钟前
拆解自动驾驶核心架构——感知、决策、控制三层逻辑详解
人工智能·机器学习·自动驾驶
专注VB编程开发20年26 分钟前
python图片验证码识别selenium爬虫--超级鹰实现自动登录,滑块,点击
数据库·python·mysql
海心焱30 分钟前
从零开始构建 AI 插件生态:深挖 MCP 如何打破 LLM 与本地数据的连接壁垒
jvm·人工智能·oracle
阿杰学AI31 分钟前
AI核心知识85——大语言模型之 RLAIF(简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·aigc·rlaihf·基于ai反馈的强化学习
Coco恺撒32 分钟前
【脑机接口】难在哪里,【人工智能】如何破局(2.研发篇)
人工智能·深度学习·开源·人机交互·脑机接口
iFeng的小屋34 分钟前
【2026最新当当网爬虫分享】用Python爬取千本日本相关图书,自动分析价格分布!
开发语言·爬虫·python