在PyTorch里面利用transformers的Trainer微调预训练大模型

背景

transformers提供了非常便捷的api来进行大模型的微调,下面就讲一讲利用Trainer来微调大模型的步骤

第一步:加载预训练的大模型

python 复制代码
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

第二步:设置训练超参

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="path/to/save/folder/",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
)

比如这个里面设置了epoch等于2

第三步:获取分词器tokenizer

python 复制代码
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

第四步:加载数据集

python 复制代码
from datasets import load_dataset

dataset = load_dataset("rotten_tomatoes")  # doctest: +IGNORE_RESULT

第五步:创建一个分词函数,指定数据集需要进行分词的字段:

python 复制代码
def tokenize_dataset(dataset):
    return tokenizer(dataset["text"])

第六步:调用map()来将该分词函数应用于整个数据集

python 复制代码
dataset = dataset.map(tokenize_dataset, batched=True)

第七步:使用DataCollatorWithPadding来批量填充数据,加速填充过程:

python 复制代码
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

第八步:初始化Trainer

python 复制代码
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)  # doctest: +SKIP

第九步:开始训练

python 复制代码
trainer.train()

总结:

利用Trainer提供的api,只需要简简单单的九步,十几行代码就能进行大模型的微调,你要不要动手试一试?

相关推荐
Kenneth風车3 分钟前
【第十三章:Sentosa_DSML社区版-机器学习聚类】
人工智能·低代码·机器学习·数据分析·聚类
jndingxin10 分钟前
OpenCV运动分析和目标跟踪(4)创建汉宁窗函数createHanningWindow()的使用
人工智能·opencv·目标跟踪
机器之心12 分钟前
o1 带火的 CoT 到底行不行?新论文引发了论战
android·人工智能
机器之心18 分钟前
从架构、工艺到能效表现,全面了解 LLM 硬件加速,这篇综述就够了
android·人工智能
jndingxin1 小时前
OpenCV特征检测(1)检测图像中的线段的类LineSegmentDe()的使用
人工智能·opencv·计算机视觉
@月落1 小时前
alibaba获得店铺的所有商品 API接口
java·大数据·数据库·人工智能·学习
z千鑫1 小时前
【人工智能】如何利用AI轻松将java,c++等代码转换为Python语言?程序员必读
java·c++·人工智能·gpt·agent·ai编程·ai工具
MinIO官方账号1 小时前
从 HDFS 迁移到 MinIO 企业对象存储
人工智能·分布式·postgresql·架构·开源
aWty_2 小时前
机器学习--K-Means
人工智能·机器学习·kmeans
草莓屁屁我不吃2 小时前
AI大语言模型的全面解读
人工智能·语言模型·自然语言处理·chatgpt