trl的安装与单GPU多GPU测试

文章目录

  • [0 相关资料](#0 相关资料)
  • [1 源码安装](#1 源码安装)
  • [2 Qwen2.5-0.5B-Instruct 模型下载](#2 Qwen2.5-0.5B-Instruct 模型下载)
  • [3 训练demo](#3 训练demo)
  • [4 在多个 GPU/节点上进行训练](#4 在多个 GPU/节点上进行训练)
  • 总结

0 相关资料

https://github.com/huggingface/trl
https://blog.csdn.net/weixin_42486623/article/details/134326187

TRL 是一个先进的库,专为训练后基础模型而设计,采用了监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术。TRL 建立在 🤗 Transformers 生态系统之上,支持多种模型架构和模态,并可在各种硬件配置上进行扩展。

b站视频:https://www.bilibili.com/video/BV18ndfYfEcz/

PyTorch / 2.3.0 / 3.12(ubuntu22.04) / 12.1

1 源码安装

复制代码
source /etc/network_turbo
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .

source /etc/network_turbo
pip install trl transformers datasets accelerate

2 Qwen2.5-0.5B-Instruct 模型下载

https://www.modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct

bash 复制代码
source /etc/network_turbo
pip install modelscope

采用SDK方式下载

bash 复制代码
from modelscope import snapshot_download

# 指定模型的下载路径
cache_dir = '/root/'
# 调用 snapshot_download 函数下载模型
model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir=cache_dir)

print(f"模型已下载到: {model_dir}")

3 训练demo

demo.py

执行脚本前,输入:

复制代码
source /etc/network_turbo

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="/root/Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
复制代码
00:15<1:57:58,

4 在多个 GPU/节点上进行训练

执行脚本前,输入:

复制代码
source /etc/network_turbo
bash 复制代码
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes 2 demo.py --all_arguments_of_the_script

总结

一块L20 GPU 48G,需要2小时

两块L20 GPU 48G,需要0.5小时

速度提升明显

相关推荐
会飞的老朱1 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º2 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee4 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º5 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys5 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56785 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子5 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能6 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144876 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile6 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算