[AI算法] LLM训练-构建transformers custom model

文章目录

    • [1. 继承与实现基础结构](#1. 继承与实现基础结构)
    • [2. 支持 DeepSpeed 和 Accelerate 的注意事项](#2. 支持 DeepSpeed 和 Accelerate 的注意事项)
      • [a. 模型输出格式](#a. 模型输出格式)
      • [b. 设备管理](#b. 设备管理)
      • [c. 分布式训练兼容性](#c. 分布式训练兼容性)
      • [d. DeepSpeed 特定优化](#d. DeepSpeed 特定优化)
    • [3. 训练脚本集成建议](#3. 训练脚本集成建议)
    • [4. 测试与调试建议](#4. 测试与调试建议)

在使用 Hugging Face 的 transformers 库时,若要自定义一个继承自 PreTrainedModel 的模型,并确保其在训练过程中支持 DeepSpeed 或 Accelerate 等加速框架,需要注意以下关键点:

1. 继承与实现基础结构

继承 PreTrainedModel

python 复制代码
  from transformers import PreTrainedModel, PretrainedConfig

  class MyCustomModel(PreTrainedModel):
      config_class = MyCustomConfig  # 自定义配置类
      base_model_prefix = "my_model"  # 模型前缀名

      def __init__(self, config):
          super().__init__(config)
          # 初始化模型结构
实现必要的方法
forward():必须正确返回 loss(用于训练)和输出。
save_pretrained() / from_pretrained():确保模型可保存和加载。

2. 支持 DeepSpeed 和 Accelerate 的注意事项

a. 模型输出格式

返回的输出应为 Seq2SeqLMOutput 或 CausalLMOutputWithPast 等标准输出类型,包含 loss, logits 等字段。

例如:

python 复制代码
  from transformers.modeling_outputs import CausalLMOutputWithPast

  def forward(...):
      ...
      return CausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=past_key_values,
          hidden_states=hidden_states,
          attentions=attentions,
      )

b. 设备管理

不要在模型内部硬编码 .to(device),让 Accelerate 或 DeepSpeed 控制设备放置。

使用 accelerator.prepare(model, optimizer, dataloader) 来自动处理设备分配。

c. 分布式训练兼容性

避免使用不支持分布式训练的操作(如某些自定义 gather/scatter 操作)。

使用 PyTorch 原生支持的并行方式(如 nn.parallel.DistributedDataParallel)。

d. DeepSpeed 特定优化

若使用 DeepSpeed ZeRO,请避免在模型中使用 torch.nn.DataParallel。

使用 deepspeed.initialize() 替代常规优化器初始化。

在 deepspeed 配置文件中指定 train_batch_size、gradient_accumulation_steps 等参数。

3. 训练脚本集成建议

  • 使用 Accelerate
python 复制代码
from accelerate import Accelerator
accelerator = Accelerator()
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

for batch in train_dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
  • 使用 DeepSpeed
python 复制代码
安装 DeepSpeed 并使用其启动脚本:
  deepspeed --num_gpus=4 train.py --deepspeed --deepspeed_config ds_config.json
示例 ds_config.json:
json
  {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "optimizer": {
      "type": "AdamW",
      "params": {
        "lr": 3e-5
      }
    },
    "zero_optimization": {
      "stage": 2
    }
  }

4. 测试与调试建议

  • 使用 transformers.Trainer 进行快速验证是否能正常训练。
  • 启用 fp16 或 bf16 加速训练时,确保模型计算图支持混合精度。
  • 使用 torch.compile() 可进一步提升性能(PyTorch 2.0+)。
相关推荐
立志成为大牛的小牛5 小时前
数据结构——四十、折半查找(王道408)
数据结构·学习·程序人生·考研·算法
王哈哈^_^5 小时前
【完整源码+数据集】蓝莓数据集,yolo11蓝莓成熟度检测数据集 3023 张,蓝莓成熟度数据集,目标检测蓝莓识别算法系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·ai·视觉检测
王哈哈^_^5 小时前
【完整源码+数据集】高空作业数据集,yolo高空作业检测数据集 2076 张,人员高空作业数据集,目标检测高空作业识别系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·目标跟踪·视觉检测
一条数据库5 小时前
猫狗识别数据集:34,441张高质量标注图像,深度学习二分类任务训练数据集,计算机视觉算法研发,CNN模型训练,图像识别分类,机器学习实践项目完整数据资
深度学习·算法·机器学习
bloxd yzh6 小时前
图论基础概念
算法
小白程序员成长日记6 小时前
2025.11.09 力扣每日一题
算法·leetcode·职场和发展
hansang_IR6 小时前
【题解】洛谷 P1477 [NOI2008] 假面舞会 [思维 + 图论]
c++·算法·图论·思维
天选之女wow6 小时前
【代码随想录算法训练营——Day59】图论——47.参加科学大会、94.城市间货物运输I
算法·图论
CoovallyAIHub6 小时前
1.2MB超轻量模型实现草莓苗精准分级检测与定位,准确率超96%
深度学习·算法·计算机视觉
CoovallyAIHub6 小时前
终结AI偏见!Sony AI发布Nature论文与FHIBE数据集,重塑公平性评估基准
深度学习·算法·计算机视觉