网络训练中使用上下文管理器 ctx 控制计算精度

先看代码:

python 复制代码
with ctx: # 使用上下文管理器 ctx,如果在 GPU 上训练且启用了自动混合精度(AMP),则使用 torch.cuda.amp.autocast();否则,使用 nullcontext()
    res = model(X) # 将输入数据 X 传递给模型,获取模型输出 res
    loss = loss_fct(
        res.logits.view(-1, res.logits.size(-1)),
        Y.view(-1)
    ).view(Y.size()) # 计算损失函数。这里将模型输出的 logits 和目标 Y 传递给损失函数

    loss = (loss * loss_mask).sum() / loss_mask.sum() # 使用损失掩码 loss_mask 对损失进行加权而后计算加权损失的平均值

    loss += res.aux_loss # 如果模型有辅助损失(如 MoE 模型中的专家损失),将其加到总损失中

    loss = loss / args.accumulation_steps # 如果使用梯度累积,将损失除以累积步骤数

代码解析

python 复制代码
with ctx:  # 使用上下文管理器 ctx,如果在 GPU 上训练且启用了自动混合精度(AMP),则使用 torch.cuda.amp.autocast();否则,使用 nullcontext()
  • ctx :上下文管理器,用于控制计算的精度模式。
    • 如果启用了自动混合精度(AMP),ctx 通常是 torch.cuda.amp.autocast()。AMP 是一种 PyTorch 提供的技术,用于在 GPU 上自动切换浮点精度(FP32 和 FP16),以加快训练速度并减少内存占用。
    • 如果没有启用 AMP,ctx 可能是 contextlib.nullcontext(),这是一个空的上下文管理器,不会对计算过程产生任何影响。

模型前向传播

python 复制代码
res = model(X)  # 将输入数据 X 传递给模型,获取模型输出 res
  • model(X)
    • X 是输入数据,通常是一个张量,形状为 (batch_size, sequence_length, feature_dim) 或其他根据模型设计的形状。
    • model 是一个 PyTorch 模型实例,调用 model(X) 会执行模型的前向传播,返回模型的输出 res
  • res
    • res 是模型的输出,通常是一个包含多个属性的对象(如 logitsaux_loss)。logits 是模型的预测结果,形状通常为 (batch_size, sequence_length, num_classes)

损失计算

python 复制代码
loss = loss_fct(
    res.logits.view(-1, res.logits.size(-1)),
    Y.view(-1)
).view(Y.size())  # 计算损失函数。这里将模型输出的 logits 和目标 Y 传递给损失函数
  • res.logits.view(-1, res.logits.size(-1))
    • res.logits 是模型的预测结果,形状为 (batch_size, sequence_length, num_classes)
    • view(-1, res.logits.size(-1))logits 展平为二维张量,形状为 (batch_size * sequence_length, num_classes)。这是为了适配 nn.CrossEntropyLoss 的输入要求。
  • Y.view(-1)
    • Y 是目标标签,形状通常为 (batch_size, sequence_length)
    • view(-1)Y 展平为一维张量,形状为 (batch_size * sequence_length,)
  • loss_fct(...)
    • loss_fct 是损失函数,通常是一个 nn.CrossEntropyLoss 实例。
    • 调用 loss_fct 计算每个样本的损失值,返回的损失值是一个一维张量,形状为 (batch_size * sequence_length,)
  • .view(Y.size())
    • 将计算得到的损失值重新调整为与 Y 相同的形状,即 (batch_size, sequence_length)

损失加权

python 复制代码
loss = (loss * loss_mask).sum() / loss_mask.sum()  # 使用损失掩码 loss_mask 对损失进行加权而后计算加权损失的平均值
  • loss * loss_mask
    • loss_mask 是一个掩码张量,形状与 loss 相同,通常用于指示哪些位置的损失需要被计算。
    • loss * loss_mask 对损失值进行加权,掩码值为 0 的位置对应的损失值会被忽略。
  • .sum()
    • 对加权后的损失值求和。
  • / loss_mask.sum()
    • 将总损失值除以掩码中非零值的数量,计算加权损失的平均值。

辅助损失

python 复制代码
loss += res.aux_loss  # 如果模型有辅助损失(如 MoE 模型中的专家损失),将其加到总损失中
  • res.aux_loss
    • aux_loss 是模型的辅助损失,例如在 MoE(Mixture of Experts)模型中,每个专家的损失可以作为辅助损失。
    • 如果模型有辅助损失,将其加到总损失中。

梯度累积

python 复制代码
loss = loss / args.accumulation_steps  # 如果使用梯度累积,将损失除以累积步骤数
  • args.accumulation_steps
    • 梯度累积步骤数,用于在有限的 GPU 内存下处理较大的批次。
  • loss / args.accumulation_steps
    • 将损失值除以梯度累积步骤数,以便在每次反向传播时只计算一部分梯度。

总结

这段代码的主要功能是:

  1. 使用上下文管理器 ctx 控制计算精度(AMP 或普通精度)。
  2. 将输入数据 X 传递给模型,获取模型输出 res
  3. 计算模型输出的 logits 和目标 Y 的损失值。
  4. 使用损失掩码 loss_mask 对损失值进行加权,并计算加权损失的平均值。
  5. 如果模型有辅助损失,将其加到总损失中。
  6. 如果使用梯度累积,将损失值除以累积步骤数。

上述过程在深度学习训练中非常常见,尤其是在处理复杂模型(如 MoE)和使用梯度累积时。

相关推荐
Java技术小馆1 小时前
GitDiagram如何让你的GitHub项目可视化
java·后端·面试
UGOTNOSHOT1 小时前
7.4项目一问题准备
面试
YaHuiLiang3 小时前
小微互联网公司与互联网创业公司 -- 学历之殇
前端·后端·面试
爱学习的茄子6 小时前
深度解析JavaScript中的call方法实现:从原理到手写实现的完整指南
前端·javascript·面试
莫空00006 小时前
Vue组件通信方式详解
前端·面试
呆呆的小鳄鱼6 小时前
cin,cin.get()等异同点[面试题系列]
java·算法·面试
顾林海7 小时前
ViewModel 销毁时机详解
android·面试·android jetpack
bo521007 小时前
解决跨域的几种种方法, 你都知道几种?
前端·面试·浏览器
掘金安东尼8 小时前
前端周刊第421期(2025年7月1日–7月6日)
前端·面试·github
前端小巷子8 小时前
web从输入网址到页面加载完成
前端·面试·浏览器