trl的安装与单GPU多GPU测试

文章目录

  • [0 相关资料](#0 相关资料)
  • [1 源码安装](#1 源码安装)
  • [2 Qwen2.5-0.5B-Instruct 模型下载](#2 Qwen2.5-0.5B-Instruct 模型下载)
  • [3 训练demo](#3 训练demo)
  • [4 在多个 GPU/节点上进行训练](#4 在多个 GPU/节点上进行训练)
  • 总结

0 相关资料

https://github.com/huggingface/trl
https://blog.csdn.net/weixin_42486623/article/details/134326187

TRL 是一个先进的库,专为训练后基础模型而设计,采用了监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术。TRL 建立在 🤗 Transformers 生态系统之上,支持多种模型架构和模态,并可在各种硬件配置上进行扩展。

b站视频:https://www.bilibili.com/video/BV18ndfYfEcz/

PyTorch / 2.3.0 / 3.12(ubuntu22.04) / 12.1

1 源码安装

复制代码
source /etc/network_turbo
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .

source /etc/network_turbo
pip install trl transformers datasets accelerate

2 Qwen2.5-0.5B-Instruct 模型下载

https://www.modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct

bash 复制代码
source /etc/network_turbo
pip install modelscope

采用SDK方式下载

bash 复制代码
from modelscope import snapshot_download

# 指定模型的下载路径
cache_dir = '/root/'
# 调用 snapshot_download 函数下载模型
model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir=cache_dir)

print(f"模型已下载到: {model_dir}")

3 训练demo

demo.py

执行脚本前,输入:

复制代码
source /etc/network_turbo

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="/root/Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
复制代码
00:15<1:57:58,

4 在多个 GPU/节点上进行训练

执行脚本前,输入:

复制代码
source /etc/network_turbo
bash 复制代码
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes 2 demo.py --all_arguments_of_the_script

总结

一块L20 GPU 48G,需要2小时

两块L20 GPU 48G,需要0.5小时

速度提升明显

相关推荐
陈苏同学22 分钟前
MPC控制器从入门到进阶(小车动态避障变道仿真 - Python)
人工智能·python·机器学习·数学建模·机器人·自动驾驶
努力毕业的小土博^_^43 分钟前
【深度学习|学习笔记】 Generalized additive model广义可加模型(GAM)详解,附代码
人工智能·笔记·深度学习·神经网络·学习
小小鱼儿小小林1 小时前
用AI制作黑神话悟空质感教程,3D西游记裸眼效果,西游人物跳出书本
人工智能·3d·ai画图
浪淘沙jkp1 小时前
AI大模型学习二十、利用Dify+deepseekR1 使用知识库搭建初中英语学习智能客服机器人
人工智能·llm·embedding·agent·知识库·dify·deepseek
AndrewHZ3 小时前
【图像处理基石】什么是油画感?
图像处理·人工智能·算法·图像压缩·视频处理·超分辨率·去噪算法
Robot2514 小时前
「华为」人形机器人赛道投资首秀!
大数据·人工智能·科技·microsoft·华为·机器人
J先生x4 小时前
【IP101】图像处理进阶:从直方图均衡化到伽马变换,全面掌握图像增强技术
图像处理·人工智能·学习·算法·计算机视觉
Narutolxy7 小时前
大模型数据分析破局之路20250512
人工智能·chatgpt·数据分析
浊酒南街7 小时前
TensorFlow中数据集的创建
人工智能·tensorflow
2301_787552878 小时前
console-chat-gpt开源程序是用于 AI Chat API 的 Python CLI
人工智能·python·gpt·开源·自动化