trl的安装与单GPU多GPU测试

文章目录

  • [0 相关资料](#0 相关资料)
  • [1 源码安装](#1 源码安装)
  • [2 Qwen2.5-0.5B-Instruct 模型下载](#2 Qwen2.5-0.5B-Instruct 模型下载)
  • [3 训练demo](#3 训练demo)
  • [4 在多个 GPU/节点上进行训练](#4 在多个 GPU/节点上进行训练)
  • 总结

0 相关资料

https://github.com/huggingface/trl
https://blog.csdn.net/weixin_42486623/article/details/134326187

TRL 是一个先进的库,专为训练后基础模型而设计,采用了监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术。TRL 建立在 🤗 Transformers 生态系统之上,支持多种模型架构和模态,并可在各种硬件配置上进行扩展。

b站视频:https://www.bilibili.com/video/BV18ndfYfEcz/

PyTorch / 2.3.0 / 3.12(ubuntu22.04) / 12.1

1 源码安装

复制代码
source /etc/network_turbo
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .

source /etc/network_turbo
pip install trl transformers datasets accelerate

2 Qwen2.5-0.5B-Instruct 模型下载

https://www.modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct

bash 复制代码
source /etc/network_turbo
pip install modelscope

采用SDK方式下载

bash 复制代码
from modelscope import snapshot_download

# 指定模型的下载路径
cache_dir = '/root/'
# 调用 snapshot_download 函数下载模型
model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir=cache_dir)

print(f"模型已下载到: {model_dir}")

3 训练demo

demo.py

执行脚本前,输入:

复制代码
source /etc/network_turbo

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="/root/Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
复制代码
00:15<1:57:58,

4 在多个 GPU/节点上进行训练

执行脚本前,输入:

复制代码
source /etc/network_turbo
bash 复制代码
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes 2 demo.py --all_arguments_of_the_script

总结

一块L20 GPU 48G,需要2小时

两块L20 GPU 48G,需要0.5小时

速度提升明显

相关推荐
新智元3 分钟前
刚刚,苹果大模型团队负责人叛逃 Meta!华人 AI 巨星 + 1,年薪飙至 9 位数
人工智能·openai
Cyltcc18 分钟前
如何安装和使用 Claude Code 教程 - Windows 用户篇
人工智能·claude·visual studio code
吹风看太阳1 小时前
机器学习16-总体架构
人工智能·机器学习
moonsims2 小时前
全国产化行业自主无人机智能处理单元-AI飞控+通信一体化模块SkyCore-I
人工智能·无人机
MUTA️2 小时前
ELMo——Embeddings from Language Models原理速学
人工智能·语言模型·自然语言处理
海豚调度2 小时前
Linux 基金会报告解读:开源 AI 重塑经济格局,有人失业,有人涨薪!
大数据·人工智能·ai·开源
T__TIII2 小时前
Dify 插件非正式打包
人工智能
jerwey2 小时前
大语言模型(LLM)按架构分类
人工智能·语言模型·分类
令狐少侠20112 小时前
ai之RAG本地知识库--基于OCR和文本解析器的新一代RAG引擎:RAGFlow 认识和源码剖析
人工智能·ai
小叮当爱咖啡2 小时前
Seq2seq+Attention 机器翻译
人工智能·自然语言处理·机器翻译