通过微调预训练模型得到自己的模型

通过微调预训练模型得到自己的模型

目录

  1. 简介
  2. 环境准备
  3. 数据准备
  4. 加载预训练模型和Tokenizer
  5. 数据预处理
  6. 设置训练参数
  7. 初始化Trainer并开始训练
  8. 评估和保存模型
  9. 总结

简介

在这篇博客中,我们将介绍如何通过预训练模型进行微调来得到自己的模型。我们将使用Hugging Face的Transformers库和一个BART模型进行示例演示。整个过程包括环境准备、数据准备、模型加载、数据预处理、训练参数设置、训练、评估和保存模型。

环境准备

首先,我们需要安装必要的Python库:

bash 复制代码
pip install transformers datasets torch

数据准备

假设我们有三个数据集:训练集、验证集和测试集,分别存储在JSON文件中。我们将这些数据集加载到内存中。

python 复制代码
import os
from datasets import load_dataset

train_data_name = 'train_data'
valid_data_name = 'valid_data'
test_data_name = 'test_data'

# 顶级数据目录
top_data_dir = '../../data/sql'

raw_data_dir = os.path.join(top_data_dir, 'raw_data/')
train_raw_data_path = os.path.join(raw_data_dir, f'{train_data_name}.json')
valid_raw_data_path = os.path.join(raw_data_dir, f'{valid_data_name}.json')
test_raw_data_path = os.path.join(raw_data_dir, f'{test_data_name}.json')

# 加载JSON数据集,忽略无法解码的字符
dataset = load_dataset('json', data_files={
    'train': train_raw_data_path,
    'validation': valid_raw_data_path,
    'test': test_raw_data_path
})

加载预训练模型和Tokenizer

我们将使用Hugging Face的Transformers库加载预训练的BART模型和对应的Tokenizer。

python 复制代码
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("./bart-base")
model = BartForConditionalGeneration.from_pretrained("./bart-base").to(device)

数据预处理

定义数据预处理函数,将输入和目标文本进行tokenize,并确保长度一致。

python 复制代码
def preprocess_function(examples):
    inputs = examples['code']
    targets = examples['text']

    # 使用 `max_length` 和 `padding` 确保一致的长度
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')
    labels = tokenizer(text_target=targets, max_length=512, truncation=True, padding='max_length')

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# 应用预处理函数到训练集和验证集
tokenized_datasets = dataset.map(preprocess_function, batched=True)

设置训练参数

设置训练参数,包括输出目录、批量大小、训练轮数等。

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # 输出结果的目录
    evaluation_strategy="epoch",     # 每个epoch进行一次评估
    per_device_train_batch_size=4,   # 每个设备的训练批量大小
    per_device_eval_batch_size=4,    # 每个设备的评估批量大小
    num_train_epochs=3,              # 训练的epoch数量
    save_strategy="epoch",           # 保存策略
    logging_dir='./logs',            # 日志目录
    logging_steps=10,                # 日志记录的步数
    no_cuda=False,                   # 强制使用CPU
    learning_rate=5e-5,              # 调整学习率
    gradient_accumulation_steps=8,   # 梯度累

初始化trainer并开始训练

python 复制代码
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)
trainer.train()

评估保存模型

python 复制代码
results = trainer.evaluate(eval_dataset=tokenized_datasets['validation'])
print(f"Validation Results: {results}")

model.save_pretrained('./trained_model')
tokenizer.save_pretrained('./trained_model')
相关推荐
Niuguangshuo9 小时前
深入解析Stable Diffusion基石——潜在扩散模型(LDMs)
人工智能·计算机视觉·stable diffusion
迈火9 小时前
SD - Latent - Interposer:解锁Stable Diffusion潜在空间的创意工具
人工智能·gpt·计算机视觉·stable diffusion·aigc·语音识别·midjourney
wfeqhfxz25887829 小时前
YOLO13-C3k2-GhostDynamicConv烟雾检测算法实现与优化
人工智能·算法·计算机视觉
芝士爱知识a10 小时前
2026年AI面试软件推荐
人工智能·面试·职场和发展·大模型·ai教育·考公·智蛙面试
Li emily10 小时前
解决港股实时行情数据 API 接入难题
人工智能·python·fastapi
Aaron158810 小时前
基于RFSOC的数字射频存储技术应用分析
c语言·人工智能·驱动开发·算法·fpga开发·硬件工程·信号处理
J_Xiong011710 小时前
【Agents篇】04:Agent 的推理能力——思维链与自我反思
人工智能·ai agent·推理
星爷AG I11 小时前
9-26 主动视觉(AGI基础理论)
人工智能·计算机视觉·agi
爱吃泡芙的小白白11 小时前
CNN参数量计算全解析:从基础公式到前沿优化
人工智能·神经网络·cnn·参数量
拐爷11 小时前
vibe‑coding 九阳神功之喂:把链接喂成“本地知识”,AI 才能稳定干活(API / 设计 / 报道 / 截图)
人工智能