trl的安装与单GPU多GPU测试

文章目录

  • [0 相关资料](#0 相关资料)
  • [1 源码安装](#1 源码安装)
  • [2 Qwen2.5-0.5B-Instruct 模型下载](#2 Qwen2.5-0.5B-Instruct 模型下载)
  • [3 训练demo](#3 训练demo)
  • [4 在多个 GPU/节点上进行训练](#4 在多个 GPU/节点上进行训练)
  • 总结

0 相关资料

https://github.com/huggingface/trl
https://blog.csdn.net/weixin_42486623/article/details/134326187

TRL 是一个先进的库,专为训练后基础模型而设计,采用了监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术。TRL 建立在 🤗 Transformers 生态系统之上,支持多种模型架构和模态,并可在各种硬件配置上进行扩展。

b站视频:https://www.bilibili.com/video/BV18ndfYfEcz/

PyTorch / 2.3.0 / 3.12(ubuntu22.04) / 12.1

1 源码安装

复制代码
source /etc/network_turbo
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .

source /etc/network_turbo
pip install trl transformers datasets accelerate

2 Qwen2.5-0.5B-Instruct 模型下载

https://www.modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct

bash 复制代码
source /etc/network_turbo
pip install modelscope

采用SDK方式下载

bash 复制代码
from modelscope import snapshot_download

# 指定模型的下载路径
cache_dir = '/root/'
# 调用 snapshot_download 函数下载模型
model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir=cache_dir)

print(f"模型已下载到: {model_dir}")

3 训练demo

demo.py

执行脚本前,输入:

复制代码
source /etc/network_turbo

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="/root/Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
复制代码
00:15<1:57:58,

4 在多个 GPU/节点上进行训练

执行脚本前,输入:

复制代码
source /etc/network_turbo
bash 复制代码
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes 2 demo.py --all_arguments_of_the_script

总结

一块L20 GPU 48G,需要2小时

两块L20 GPU 48G,需要0.5小时

速度提升明显

相关推荐
怎么全是重名5 分钟前
DeepLab(V3)
人工智能·深度学习·图像分割
m0_6501082410 分钟前
Vision-Language-Action 模型在自动驾驶中的应用(VLA4AD)
论文阅读·人工智能·自动驾驶·端到端自动驾驶·vla4ad·自动驾驶与多模态大模型交叉
爱笑的眼睛1117 分钟前
文本分类的范式演进:从统计概率到语言模型提示工程
java·人工智能·python·ai
星川皆无恙22 分钟前
基于知识图谱+深度学习的大数据NLP医疗知识问答可视化系统(全网最详细讲解及源码/建议收藏)
大数据·人工智能·python·深度学习·自然语言处理·知识图谱
美狐美颜SDK开放平台27 分钟前
自研还是接入第三方?直播美颜sdk与滤镜功能的技术选型分析
人工智能·美颜sdk·直播美颜sdk·美颜api·美狐美颜sdk
weixin_4166600728 分钟前
插件分享:将AI生成的数学公式无损导出为Word文档
人工智能·ai·word·论文·数学公式·deepseek
PM老周30 分钟前
DORA2025:如何用AI提升研发效能(以 ONES MCP Server 为例)
大数据·人工智能
皇族崛起32 分钟前
【众包 + AI智能体】AI境生态巡查平台边防借鉴价值专项调研——以广西边境线治理为例
大数据·人工智能
zhaodiandiandian1 小时前
AI大模型:重构产业生态的核心引擎
人工智能·重构