网络训练中使用上下文管理器 ctx 控制计算精度

先看代码:

python 复制代码
with ctx: # 使用上下文管理器 ctx,如果在 GPU 上训练且启用了自动混合精度(AMP),则使用 torch.cuda.amp.autocast();否则,使用 nullcontext()
    res = model(X) # 将输入数据 X 传递给模型,获取模型输出 res
    loss = loss_fct(
        res.logits.view(-1, res.logits.size(-1)),
        Y.view(-1)
    ).view(Y.size()) # 计算损失函数。这里将模型输出的 logits 和目标 Y 传递给损失函数

    loss = (loss * loss_mask).sum() / loss_mask.sum() # 使用损失掩码 loss_mask 对损失进行加权而后计算加权损失的平均值

    loss += res.aux_loss # 如果模型有辅助损失(如 MoE 模型中的专家损失),将其加到总损失中

    loss = loss / args.accumulation_steps # 如果使用梯度累积,将损失除以累积步骤数

代码解析

python 复制代码
with ctx:  # 使用上下文管理器 ctx,如果在 GPU 上训练且启用了自动混合精度(AMP),则使用 torch.cuda.amp.autocast();否则,使用 nullcontext()
  • ctx :上下文管理器,用于控制计算的精度模式。
    • 如果启用了自动混合精度(AMP),ctx 通常是 torch.cuda.amp.autocast()。AMP 是一种 PyTorch 提供的技术,用于在 GPU 上自动切换浮点精度(FP32 和 FP16),以加快训练速度并减少内存占用。
    • 如果没有启用 AMP,ctx 可能是 contextlib.nullcontext(),这是一个空的上下文管理器,不会对计算过程产生任何影响。

模型前向传播

python 复制代码
res = model(X)  # 将输入数据 X 传递给模型,获取模型输出 res
  • model(X)
    • X 是输入数据,通常是一个张量,形状为 (batch_size, sequence_length, feature_dim) 或其他根据模型设计的形状。
    • model 是一个 PyTorch 模型实例,调用 model(X) 会执行模型的前向传播,返回模型的输出 res
  • res
    • res 是模型的输出,通常是一个包含多个属性的对象(如 logitsaux_loss)。logits 是模型的预测结果,形状通常为 (batch_size, sequence_length, num_classes)

损失计算

python 复制代码
loss = loss_fct(
    res.logits.view(-1, res.logits.size(-1)),
    Y.view(-1)
).view(Y.size())  # 计算损失函数。这里将模型输出的 logits 和目标 Y 传递给损失函数
  • res.logits.view(-1, res.logits.size(-1))
    • res.logits 是模型的预测结果,形状为 (batch_size, sequence_length, num_classes)
    • view(-1, res.logits.size(-1))logits 展平为二维张量,形状为 (batch_size * sequence_length, num_classes)。这是为了适配 nn.CrossEntropyLoss 的输入要求。
  • Y.view(-1)
    • Y 是目标标签,形状通常为 (batch_size, sequence_length)
    • view(-1)Y 展平为一维张量,形状为 (batch_size * sequence_length,)
  • loss_fct(...)
    • loss_fct 是损失函数,通常是一个 nn.CrossEntropyLoss 实例。
    • 调用 loss_fct 计算每个样本的损失值,返回的损失值是一个一维张量,形状为 (batch_size * sequence_length,)
  • .view(Y.size())
    • 将计算得到的损失值重新调整为与 Y 相同的形状,即 (batch_size, sequence_length)

损失加权

python 复制代码
loss = (loss * loss_mask).sum() / loss_mask.sum()  # 使用损失掩码 loss_mask 对损失进行加权而后计算加权损失的平均值
  • loss * loss_mask
    • loss_mask 是一个掩码张量,形状与 loss 相同,通常用于指示哪些位置的损失需要被计算。
    • loss * loss_mask 对损失值进行加权,掩码值为 0 的位置对应的损失值会被忽略。
  • .sum()
    • 对加权后的损失值求和。
  • / loss_mask.sum()
    • 将总损失值除以掩码中非零值的数量,计算加权损失的平均值。

辅助损失

python 复制代码
loss += res.aux_loss  # 如果模型有辅助损失(如 MoE 模型中的专家损失),将其加到总损失中
  • res.aux_loss
    • aux_loss 是模型的辅助损失,例如在 MoE(Mixture of Experts)模型中,每个专家的损失可以作为辅助损失。
    • 如果模型有辅助损失,将其加到总损失中。

梯度累积

python 复制代码
loss = loss / args.accumulation_steps  # 如果使用梯度累积,将损失除以累积步骤数
  • args.accumulation_steps
    • 梯度累积步骤数,用于在有限的 GPU 内存下处理较大的批次。
  • loss / args.accumulation_steps
    • 将损失值除以梯度累积步骤数,以便在每次反向传播时只计算一部分梯度。

总结

这段代码的主要功能是:

  1. 使用上下文管理器 ctx 控制计算精度(AMP 或普通精度)。
  2. 将输入数据 X 传递给模型,获取模型输出 res
  3. 计算模型输出的 logits 和目标 Y 的损失值。
  4. 使用损失掩码 loss_mask 对损失值进行加权,并计算加权损失的平均值。
  5. 如果模型有辅助损失,将其加到总损失中。
  6. 如果使用梯度累积,将损失值除以累积步骤数。

上述过程在深度学习训练中非常常见,尤其是在处理复杂模型(如 MoE)和使用梯度累积时。

相关推荐
智商低情商凑1 小时前
CAS(Compare And Swap)
java·jvm·面试
uhakadotcom2 小时前
人工智能如何改变医疗行业:简单易懂的基础介绍与实用案例
算法·面试·github
zizisuo2 小时前
面试篇:Spring Boot
spring boot·面试·职场和发展
uhakadotcom4 小时前
企业智能体网络(Agent Mesh)入门指南:基础知识与实用示例
后端·面试·github
独孤歌5 小时前
告别频繁登录:打造用户无感的 Token 刷新机制
安全·面试
Eliauk__5 小时前
深入剖析 Vue 双向数据绑定机制 —— 从响应式原理到 v-model 实现全解析
前端·javascript·面试
慕仲卿5 小时前
模型初始化:加载分词器和模型
面试
海底火旺5 小时前
寻找缺失的最小正整数:从暴力到最优的算法演进
javascript·算法·面试
顾林海5 小时前
深入探究 Android Native 代码的崩溃捕获机制
android·面试·性能优化
慕仲卿5 小时前
缩放器和优化器的定义
面试