在PyTorch里面利用transformers的Trainer微调预训练大模型

背景

transformers提供了非常便捷的api来进行大模型的微调,下面就讲一讲利用Trainer来微调大模型的步骤

第一步:加载预训练的大模型

python 复制代码
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

第二步:设置训练超参

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="path/to/save/folder/",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
)

比如这个里面设置了epoch等于2

第三步:获取分词器tokenizer

python 复制代码
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

第四步:加载数据集

python 复制代码
from datasets import load_dataset

dataset = load_dataset("rotten_tomatoes")  # doctest: +IGNORE_RESULT

第五步:创建一个分词函数,指定数据集需要进行分词的字段:

python 复制代码
def tokenize_dataset(dataset):
    return tokenizer(dataset["text"])

第六步:调用map()来将该分词函数应用于整个数据集

python 复制代码
dataset = dataset.map(tokenize_dataset, batched=True)

第七步:使用DataCollatorWithPadding来批量填充数据,加速填充过程:

python 复制代码
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

第八步:初始化Trainer

python 复制代码
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)  # doctest: +SKIP

第九步:开始训练

python 复制代码
trainer.train()

总结:

利用Trainer提供的api,只需要简简单单的九步,十几行代码就能进行大模型的微调,你要不要动手试一试?

相关推荐
DogDaoDao几秒前
【GitHub】AgentMemory 深度解析:让 AI 编程代理拥有持久化记忆的 16K+ Star 开源方案
人工智能·开源·大模型·github·aigc·ai编程·aiagent
电子科技圈1 分钟前
大理5G研究院加速建设面向南亚东南亚新一代信息技术产业化合作新通道
人工智能·物联网·5g·网络安全·信息与通信
徐安安ye1 分钟前
FlashAttention 算子深度解析:让大模型在昇腾NPU上跑得更快
python·transformer
山屿落星辰3 分钟前
cann-tools - 昇腾CANN 工具集使用指南
人工智能·pytorch·python
一切皆是因缘际会3 分钟前
终结拟合式智能:记忆博弈心智架构重塑硅基生命进化逻辑
大数据·人工智能·深度学习·机器学习·架构
renhongxia14 分钟前
用知识图谱重构搜索引擎
人工智能·搜索引擎·重构·分类·语音识别·知识图谱
一起聊电气4 分钟前
不止保安全!智慧用电系统解锁照明安全节能双赛道
大数据·网络·人工智能·安全·智能家居·空调
AI技术控5 分钟前
Long-range Brain Graph Transformer 论文解读:用长程依赖建模理解脑网络通信
人工智能·python·深度学习·分类
Mem0rin6 分钟前
[LLM初步]Transformer 模型分类(从架构出发)
深度学习·分类·transformer
肖有米XTKF86466 分钟前
肖有米开发团队:昕之康模式系统开发-昕之康小程序制度商城
大数据·人工智能·团队开发·csdn开发云