在PyTorch里面利用transformers的Trainer微调预训练大模型

背景

transformers提供了非常便捷的api来进行大模型的微调,下面就讲一讲利用Trainer来微调大模型的步骤

第一步:加载预训练的大模型

python 复制代码
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

第二步:设置训练超参

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="path/to/save/folder/",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
)

比如这个里面设置了epoch等于2

第三步:获取分词器tokenizer

python 复制代码
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

第四步:加载数据集

python 复制代码
from datasets import load_dataset

dataset = load_dataset("rotten_tomatoes")  # doctest: +IGNORE_RESULT

第五步:创建一个分词函数,指定数据集需要进行分词的字段:

python 复制代码
def tokenize_dataset(dataset):
    return tokenizer(dataset["text"])

第六步:调用map()来将该分词函数应用于整个数据集

python 复制代码
dataset = dataset.map(tokenize_dataset, batched=True)

第七步:使用DataCollatorWithPadding来批量填充数据,加速填充过程:

python 复制代码
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

第八步:初始化Trainer

python 复制代码
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)  # doctest: +SKIP

第九步:开始训练

python 复制代码
trainer.train()

总结:

利用Trainer提供的api,只需要简简单单的九步,十几行代码就能进行大模型的微调,你要不要动手试一试?

相关推荐
什么都想学的阿超1 分钟前
【大语言模型 01】注意力机制数学推导:从零实现Self-Attention
人工智能·语言模型·自然语言处理
大千AI助手2 小时前
SWE-bench:真实世界软件工程任务的“试金石”
人工智能·深度学习·大模型·llm·软件工程·代码生成·swe-bench
天上的光2 小时前
17.迁移学习
人工智能·机器学习·迁移学习
后台开发者Ethan3 小时前
Python需要了解的一些知识
开发语言·人工智能·python
猫头虎3 小时前
猫头虎AI分享|一款Coze、Dify类开源AI应用超级智能体快速构建工具:FastbuildAI
人工智能·开源·prompt·github·aigc·ai编程·ai-native
重启的码农3 小时前
ggml 介绍 (6) 后端 (ggml_backend)
c++·人工智能·神经网络
重启的码农3 小时前
ggml介绍 (7)后端缓冲区 (ggml_backend_buffer)
c++·人工智能·神经网络
数据智能老司机3 小时前
面向企业的图学习扩展——图简介
人工智能·机器学习·ai编程
盼小辉丶3 小时前
PyTorch生成式人工智能——使用MusicGen生成音乐
pytorch·python·深度学习·生成模型