【NLP 58、利用trl框架训练LLM】

孤独总比忍受傻逼好得多

------ 25.4.11

源代码网页:

项目文件预览 - trl:Train transformer language models with reinforcement learning. - GitCode

TRL ------ 变压器强化学习

trl:一个用于后训练基础模型的全面库

1.概述

TRL 是一个利用监督微调(SFT)、近端策略优化(PPO)和直接偏好优化(DPO)等先进技术后训练基础模型的尖端库。构建在 🤗 Transformers 生态系统之上,TRL 支持多种模型架构和模态,并且可以跨各种硬件配置进行扩展。


2.特色

Ⅰ、高效且可拓展

  • 利用 🤗 Accelerate 实现从单 GPU 到多节点集群的扩展,采用 DDP 和 DeepSpeed 等方法。
  • PEFT 完全集成,通过 量化 和 LoRA/QLoRA 在普通硬件上训练大型模型。
  • 集成 Unsloth 以使用优化核心加速训练。

Ⅱ、命令行界面(CLI)

一个简单的界面让您能够在不编写代码的情况下微调并与模型交互。

Ⅲ、训练器

通过如 SFTTrainerDPOTrainerRewardTrainerORPOTrainer 等训练器轻松访问各种微调方法。

Ⅳ、自动模型

使用预定义的模型类如 AutoModelForCausalLMWithValueHead 来简化与大型语言模型(LLM)的强化学习(RL)。


3.安装

Ⅰ、Python 包

使用 pip 安装该库:

python 复制代码
pip3 install trl

使用国产源下载:

python 复制代码
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple trl

Ⅱ、从源代码安装

python 复制代码
pip3 install git+https://github.com/huggingface/trl.git

Ⅲ、仓库

python 复制代码
git clone https://github.com/huggingface/trl.git

4.命令行界面(CLI)

使用 TRL命令行界面(CLI)快速入门监督微调(SFT)和直接偏好优化(DPO),或者使用聊天CLI来检测你的模型表现

网址: 命令行界面 (CLI)

Ⅰ、SFT

复制代码
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT

Ⅱ、数据保护官

复制代码
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO 

Ⅲ、聊天

复制代码
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct

详细了解,请点击上文的**《网址》**,查看相关文档部分


5.使用方法

为了提供更多灵活性和对训练过程的控制,TRL 提供了专门的训练器类,用于在自定义数据集上对语言模型或 PEFT 适配器进行后训练 。TRL 中的每个训练器都是 🤗 Transformers 训练器的轻量级封装,并原生支持分布式训练方法,如 DDP、DeepSpeed ZeRO 和 FSDP。

Ⅰ、SFTTrainer

python 复制代码
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()

Ⅱ、奖励训练器使用基础示例

python 复制代码
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()

Ⅲ、GRPOTrainer

GRPOTrainer 实现了群组相对策略优化(GRPO)算法该算法相较于PPO在内存效率上更优,并被用于训练Deepseek AI的R1

python 复制代码
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Ⅳ、DPOTrainer

DPOTrainer实现了广受欢迎的直接偏好优化(DPO)算法,该算法被用于对Llama 3以及其他众多模型进行后训练。

python 复制代码
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
相关推荐
Wendy144111 分钟前
【灰度实验】——图像预处理(OpenCV)
人工智能·opencv·计算机视觉
中杯可乐多加冰22 分钟前
五大低代码平台横向深度测评:smardaten 2.0领衔AI原型设计
人工智能
无线图像传输研究探索33 分钟前
单兵图传终端:移动场景中的 “实时感知神经”
网络·人工智能·5g·无线图传·5g单兵图传
Aronup34 分钟前
NLP学习开始01-线性回归
学习·自然语言处理·线性回归
zzywxc7872 小时前
AI在编程、测试、数据分析等领域的前沿应用(技术报告)
人工智能·深度学习·机器学习·数据挖掘·数据分析·自动化·ai编程
铭keny2 小时前
YOLOv8 基于RTSP流目标检测
人工智能·yolo·目标检测
墨尘游子2 小时前
11-大语言模型—Transformer 盖楼,BERT 装修,RoBERTa 直接 “拎包入住”|预训练白话指南
人工智能·语言模型·自然语言处理
金井PRATHAMA2 小时前
主要分布于内侧内嗅皮层的层Ⅲ的网格-速度联合细胞(Grid × Speed Conjunctive Cells)对NLP中的深层语义分析的积极影响和启示
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·知识图谱
天道哥哥3 小时前
InsightFace(RetinaFace + ArcFace)人脸识别项目(预训练模型,鲁棒性很好)
人工智能·目标检测
幻风_huanfeng3 小时前
学习人工智能所需知识体系及路径详解
人工智能·学习