trl的安装与单GPU多GPU测试

文章目录

  • [0 相关资料](#0 相关资料)
  • [1 源码安装](#1 源码安装)
  • [2 Qwen2.5-0.5B-Instruct 模型下载](#2 Qwen2.5-0.5B-Instruct 模型下载)
  • [3 训练demo](#3 训练demo)
  • [4 在多个 GPU/节点上进行训练](#4 在多个 GPU/节点上进行训练)
  • 总结

0 相关资料

https://github.com/huggingface/trl
https://blog.csdn.net/weixin_42486623/article/details/134326187

TRL 是一个先进的库,专为训练后基础模型而设计,采用了监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术。TRL 建立在 🤗 Transformers 生态系统之上,支持多种模型架构和模态,并可在各种硬件配置上进行扩展。

b站视频:https://www.bilibili.com/video/BV18ndfYfEcz/

PyTorch / 2.3.0 / 3.12(ubuntu22.04) / 12.1

1 源码安装

复制代码
source /etc/network_turbo
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .

source /etc/network_turbo
pip install trl transformers datasets accelerate

2 Qwen2.5-0.5B-Instruct 模型下载

https://www.modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct

bash 复制代码
source /etc/network_turbo
pip install modelscope

采用SDK方式下载

bash 复制代码
from modelscope import snapshot_download

# 指定模型的下载路径
cache_dir = '/root/'
# 调用 snapshot_download 函数下载模型
model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir=cache_dir)

print(f"模型已下载到: {model_dir}")

3 训练demo

demo.py

执行脚本前,输入:

复制代码
source /etc/network_turbo

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="/root/Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
复制代码
00:15<1:57:58,

4 在多个 GPU/节点上进行训练

执行脚本前,输入:

复制代码
source /etc/network_turbo
bash 复制代码
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes 2 demo.py --all_arguments_of_the_script

总结

一块L20 GPU 48G,需要2小时

两块L20 GPU 48G,需要0.5小时

速度提升明显

相关推荐
潮湿的心情4 分钟前
中宇联:以“智云融合+AI”赋能全栈云MSP服务,深化阿里云生态合作
人工智能·阿里云·云计算
云布道师4 分钟前
【云故事探索】NO.16:阿里云弹性计算加速精准学 AI 教育普惠落地
人工智能·阿里云·云计算
kevin 17 分钟前
AI文档比对和Word的“比较”功能有什么区别?
人工智能·word
1892280486130 分钟前
NX947NX955美光固态闪存NX962NX966
大数据·服务器·网络·人工智能·科技
2501_9248787340 分钟前
无人机光伏巡检缺陷检出率↑32%:陌讯多模态融合算法实战解析
开发语言·人工智能·算法·视觉检测·无人机
沉睡的无敌雄狮43 分钟前
无人机光伏巡检漏检率↓78%!陌讯多模态融合算法实战解析
人工智能·算法·计算机视觉·目标跟踪
Shan12051 小时前
人工智能篇之计算机视觉
人工智能
真智AI1 小时前
打破数据质量瓶颈:用n8n实现30秒专业数据质量报告自动化
大数据·运维·人工智能·python·自动化
echola_mendes1 小时前
Dify案例2:基于Workflow的小红书笔记AI智能体以及AI绘图过程中遇到的问题
人工智能
攻城狮7号1 小时前
中科院开源HYPIR图像复原大模型:1.7秒,老照片变8K画质
人工智能·中科院·开源大模型·hypir图像复原大模型