Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型---解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)---解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl 的简介

TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、 PPO是如何工作的 PPO对语言模型微调三步骤 ,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Rollout | Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。 |
| Evaluation | Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值。 |
| Optimization | Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。 |

这个过程在下面的示意图中说明。

trl 的安装

pip install trl

trl 的使用方法

1、基础用法

(1)、 如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

python 复制代码
# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()

(2)、 如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

python 复制代码
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()

(3)、 如何使用库中的 PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

python 复制代码
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor  = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型---解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)---解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621

相关推荐
好评笔记2 小时前
AIGC视频扩散模型新星:Video 版本的SD模型
论文阅读·深度学习·机器学习·计算机视觉·面试·aigc·transformer
周杰伦_Jay1 天前
Ollama能本地部署Llama 3等大模型的原因解析(ollama核心架构、技术特性、实际应用)
数据结构·人工智能·深度学习·架构·transformer·llama
好评笔记2 天前
AIGC视频生成模型:ByteDance的PixelDance模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·aigc·transformer
珊珊而川2 天前
BERT和Transformer模型有什么区别
人工智能·bert·transformer
feifeikon2 天前
深度学习 DAY2:Transformer(一部分)
人工智能·深度学习·transformer
不去幼儿园3 天前
【博客之星】2024年度个人成长、强化学习算法领域总结
人工智能·python·算法·机器学习·强化学习·个人总结
无意21213 天前
【自动驾驶BEV感知之Transformer】
人工智能·自动驾驶·transformer
bug404_3 天前
Restormer: Efficient Transformer for High-Resolution Image Restoration解读
人工智能·深度学习·transformer
GISer Liu3 天前
Transformer详解:Attention机制原理
人工智能·python·gpt·深度学习·机器学习·语言模型·transformer
王了了哇4 天前
精度论文:【Focaler-IoU: More Focused Intersection over Union Loss】
人工智能·pytorch·深度学习·计算机视觉·transformer