Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型---解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)---解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl 的简介

TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、 PPO是如何工作的 PPO对语言模型微调三步骤 ,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Rollout | Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。 |
| Evaluation | Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值。 |
| Optimization | Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。 |

这个过程在下面的示意图中说明。

trl 的安装

pip install trl

trl 的使用方法

1、基础用法

(1)、 如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

python 复制代码
# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()

(2)、 如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

python 复制代码
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()

(3)、 如何使用库中的 PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

python 复制代码
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor  = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型---解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)---解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621

相关推荐
余生H1 小时前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
代码不行的搬运工1 小时前
神经网络12-Time-Series Transformer (TST)模型
人工智能·神经网络·transformer
罗小罗同学2 小时前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer
不去幼儿园4 小时前
【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
人工智能·python·算法·机器学习·强化学习
rommel rain1 天前
SpecInfer论文阅读
人工智能·语言模型·transformer
Just Jump2 天前
机器翻译基础与模型 之三:基于自注意力的模型
自然语言处理·transformer·机器翻译
cv君2 天前
视频修复技术和实时在线处理
深度学习·音视频·transformer·视频修复
机器学习之心2 天前
POD-Transformer多变量回归预测(Matlab)
matlab·回归·transformer·pod-transformer
regret~3 天前
【论文笔记】LoFLAT: Local Feature Matching using Focused Linear Attention Transformer
论文阅读·深度学习·transformer
迪菲赫尔曼3 天前
即插即用篇 | YOLOv11 引入高效的直方图Transformer模块 | 突破天气障碍:Histoformer引领高效图像修复新路径
人工智能·深度学习·yolo·目标检测·计算机视觉·transformer·注意力机制