网络训练中使用上下文管理器 ctx 控制计算精度

先看代码:

python 复制代码
with ctx: # 使用上下文管理器 ctx,如果在 GPU 上训练且启用了自动混合精度(AMP),则使用 torch.cuda.amp.autocast();否则,使用 nullcontext()
    res = model(X) # 将输入数据 X 传递给模型,获取模型输出 res
    loss = loss_fct(
        res.logits.view(-1, res.logits.size(-1)),
        Y.view(-1)
    ).view(Y.size()) # 计算损失函数。这里将模型输出的 logits 和目标 Y 传递给损失函数

    loss = (loss * loss_mask).sum() / loss_mask.sum() # 使用损失掩码 loss_mask 对损失进行加权而后计算加权损失的平均值

    loss += res.aux_loss # 如果模型有辅助损失(如 MoE 模型中的专家损失),将其加到总损失中

    loss = loss / args.accumulation_steps # 如果使用梯度累积,将损失除以累积步骤数

代码解析

python 复制代码
with ctx:  # 使用上下文管理器 ctx,如果在 GPU 上训练且启用了自动混合精度(AMP),则使用 torch.cuda.amp.autocast();否则,使用 nullcontext()
  • ctx :上下文管理器,用于控制计算的精度模式。
    • 如果启用了自动混合精度(AMP),ctx 通常是 torch.cuda.amp.autocast()。AMP 是一种 PyTorch 提供的技术,用于在 GPU 上自动切换浮点精度(FP32 和 FP16),以加快训练速度并减少内存占用。
    • 如果没有启用 AMP,ctx 可能是 contextlib.nullcontext(),这是一个空的上下文管理器,不会对计算过程产生任何影响。

模型前向传播

python 复制代码
res = model(X)  # 将输入数据 X 传递给模型,获取模型输出 res
  • model(X)
    • X 是输入数据,通常是一个张量,形状为 (batch_size, sequence_length, feature_dim) 或其他根据模型设计的形状。
    • model 是一个 PyTorch 模型实例,调用 model(X) 会执行模型的前向传播,返回模型的输出 res
  • res
    • res 是模型的输出,通常是一个包含多个属性的对象(如 logitsaux_loss)。logits 是模型的预测结果,形状通常为 (batch_size, sequence_length, num_classes)

损失计算

python 复制代码
loss = loss_fct(
    res.logits.view(-1, res.logits.size(-1)),
    Y.view(-1)
).view(Y.size())  # 计算损失函数。这里将模型输出的 logits 和目标 Y 传递给损失函数
  • res.logits.view(-1, res.logits.size(-1))
    • res.logits 是模型的预测结果,形状为 (batch_size, sequence_length, num_classes)
    • view(-1, res.logits.size(-1))logits 展平为二维张量,形状为 (batch_size * sequence_length, num_classes)。这是为了适配 nn.CrossEntropyLoss 的输入要求。
  • Y.view(-1)
    • Y 是目标标签,形状通常为 (batch_size, sequence_length)
    • view(-1)Y 展平为一维张量,形状为 (batch_size * sequence_length,)
  • loss_fct(...)
    • loss_fct 是损失函数,通常是一个 nn.CrossEntropyLoss 实例。
    • 调用 loss_fct 计算每个样本的损失值,返回的损失值是一个一维张量,形状为 (batch_size * sequence_length,)
  • .view(Y.size())
    • 将计算得到的损失值重新调整为与 Y 相同的形状,即 (batch_size, sequence_length)

损失加权

python 复制代码
loss = (loss * loss_mask).sum() / loss_mask.sum()  # 使用损失掩码 loss_mask 对损失进行加权而后计算加权损失的平均值
  • loss * loss_mask
    • loss_mask 是一个掩码张量,形状与 loss 相同,通常用于指示哪些位置的损失需要被计算。
    • loss * loss_mask 对损失值进行加权,掩码值为 0 的位置对应的损失值会被忽略。
  • .sum()
    • 对加权后的损失值求和。
  • / loss_mask.sum()
    • 将总损失值除以掩码中非零值的数量,计算加权损失的平均值。

辅助损失

python 复制代码
loss += res.aux_loss  # 如果模型有辅助损失(如 MoE 模型中的专家损失),将其加到总损失中
  • res.aux_loss
    • aux_loss 是模型的辅助损失,例如在 MoE(Mixture of Experts)模型中,每个专家的损失可以作为辅助损失。
    • 如果模型有辅助损失,将其加到总损失中。

梯度累积

python 复制代码
loss = loss / args.accumulation_steps  # 如果使用梯度累积,将损失除以累积步骤数
  • args.accumulation_steps
    • 梯度累积步骤数,用于在有限的 GPU 内存下处理较大的批次。
  • loss / args.accumulation_steps
    • 将损失值除以梯度累积步骤数,以便在每次反向传播时只计算一部分梯度。

总结

这段代码的主要功能是:

  1. 使用上下文管理器 ctx 控制计算精度(AMP 或普通精度)。
  2. 将输入数据 X 传递给模型,获取模型输出 res
  3. 计算模型输出的 logits 和目标 Y 的损失值。
  4. 使用损失掩码 loss_mask 对损失值进行加权,并计算加权损失的平均值。
  5. 如果模型有辅助损失,将其加到总损失中。
  6. 如果使用梯度累积,将损失值除以累积步骤数。

上述过程在深度学习训练中非常常见,尤其是在处理复杂模型(如 MoE)和使用梯度累积时。

相关推荐
蓝婷儿15 分钟前
前端面试每日三题 - Day 34
前端·面试·职场和发展
测试界萧萧10 小时前
15:00开始面试,15:06就出来了,问的问题有点变态。。。
自动化测试·软件测试·功能测试·程序人生·面试·职场和发展
Warren9810 小时前
Java面试八股Spring篇(4500字)
java·开发语言·spring boot·后端·spring·面试
是麟渊11 小时前
【大模型面试每日一题】Day 17:解释MoE(Mixture of Experts)架构如何实现模型稀疏性,并分析其训练难点
人工智能·自然语言处理·面试·职场和发展·架构
HBR666_17 小时前
面试--HTML
面试·html
码农飞哥18 小时前
互联网大厂Java求职面试实战:Spring Boot到微服务的技术问答解析
java·spring boot·缓存·面试·消息队列·技术栈·microservices
大学生小郑19 小时前
Go语言八股之Mysql事务
mysql·面试
八股文领域大手子1 天前
磁盘I/O瓶颈排查:面试通关“三部曲”心法
面试·职场和发展
大学生小郑1 天前
Go语言八股之Mysql基础详解
mysql·面试
八股文领域大手子2 天前
Java死锁排查:线上救火实战指南
java·开发语言·面试