Trainer介绍
Trainer 是 Hugging Face transformers 库中的一个核心API,它为PyTorch模型提供了一个功能完整的训练和评估循环。它的主要目标是简化训练流程,让你不需要手动编写繁琐的训练代码,可以更专注于模型、数据和参数本身。
简单来说,你只需要把模型、数据集和训练参数"喂"给 Trainer,调用 .train() 方法,它就会自动处理背后几乎所有复杂的事情,比如:
-
自动训练循环 :自动处理前向传播、计算损失、反向传播和梯度更新。
-
分布式训练支持:轻松地在多GPU、TPU上进行训练。
-
混合精度训练 :通过简单的参数开启
fp16或bf16训练,以加速训练并节省显存。 -
回调与日志:支持TensorBoard等日志工具,并可通过回调自定义行为。
-
使用流程如下:
-
加载数据集
-
数据预处理
-
准备训练参数
-
准备模型
-
创建Trainer并开始训练
-
config文件
python
import torch
import datetime
from transformers.models import BertModel, BertTokenizer, BertConfig
# 获取当前日期字符串,用于模型文件等命名
# print('当前日期--->\n', datetime.datetime.now().date())
current_date = datetime.datetime.now().date().strftime("%Y%m%d")
# print('当前日期--->\n', type(current_date), current_date)
class Config(object):
"""
配置类Config
该类用于集中管理项目开发和训练/推理阶段涉及到的所有重要参数,包括数据路径、模型参数、类别信息等。
属性说明:
model_name (str): 模型名称(一般为'bert')。
data_path (str): 数据集存放的主目录路径。
train_path (str): 训练集文件路径,格式通常为每行'文本\t标签'。
dev_path (str): 验证集(开发集)文件路径,用于模型调优。
test_path (str): 测试集文件路径,用于最终评估。
class_path (str): 类别文件路径,存放所有标签类别(每行一类)。
class_list (List[str]): 标签类别列表,从class_path读取,每项为类别名。
model_save_path (str): 模型训练好后权重/配置的保存路径。
device (torch.device): 训练/推理时使用的硬件设备(cpu或cuda)。
num_classes (int): 类别数,根据class_list自适应计算。
num_epochs (int): 训练总轮数。
batch_size (int): 每一批次(batch)的数据条数。
pad_size (int): 句子最大填充/截断长度(超出截断,不足补0)。
learning_rate (float): 优化器学习率。
bert_path (str): 预训练BERT模型的文件目录。
bert_model (BertModel): 加载的BERT主干神经网络(transformers实现)。
tokenizer (BertTokenizer):BERT分词器(与模型完全对应)。
bert_config (BertConfig): BERT结构参数对象,方便后续定义自有模型头部。
hidden_size (int): BERT编码输出的向量维度(base版通常为768)。
output_dir (str): transformers Trainer输出目录(如模型、日志等)。
logging_dir (str): 日志保存目录。
warmup_steps (int): 学习率预热步数。
weight_decay (float): 优化器的权重衰减系数。
logging_steps (int): 打印日志的步频。
eval_steps (int): 验证评估的间隔步数。
save_steps (int): 检查点保存的间隔步数。
save_total_limit (int): 最多仅保留多少个最新模型,超过会自动淘汰旧文件。
"""
def __init__(self):
# ========== 路径相关 ==========
self.model_name = "bert" # 模型名称前缀(可用于文件命名等)
self.data_path = "../../01-data" # 数据集存放根目录
self.train_path = self.data_path + "/train.txt" # 训练集文件完整路径
self.dev_path = self.data_path + "/dev3.txt" # 验证集文件路径
self.test_path = self.data_path + "/test.txt" # 测试集文件路径
self.class_path = self.data_path + "/class.txt" # 存储类别标签名称的文本文件路径
# 读取类别列表,每行一个类别
# 结果如 ['体育', '财经', ...], 按文件顺序
self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]
self.model_save_path = (f"../save_models/bertclassifier_model_{current_date}") # 模型训练后保存的主文件夹路径
# ========== 硬件参数 ==========
# 检测cuda(GPU)是否可用,优先用GPU,否则回退为CPU
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备类型
# ========== 数据/超参数 ==========
self.num_classes = len(self.class_list) # 标签类别个数(自动适应类别文件内容)
self.num_epochs = 2 # 训练总轮次,轮数通常取决于数据量可调整
self.batch_size = 8 # 单次batch处理的样本数
self.pad_size = 32 # 每句话最大长度(超出则截断,短的补齐)
self.learning_rate = 5e-5 # 优化器学习率
# ========== 预训练BERT模型参数 ==========
self.bert_path = "../bert-base-chinese" # 磁盘中的预训练BERT主目录
# 加载BERT主干模型(transformers BertModel),需要与任务gpu/cpu匹配
self.bert_model = BertModel.from_pretrained(self.bert_path)
# 加载与BERT结构匹配的分词器
self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
# 加载BERT配置对象(如hidden_size、层数等,可以用于自定义模型)
self.bert_config = BertConfig.from_pretrained(self.bert_path)
# 通常base模型为768,large则为1024
# self.hidden_size = 768
self.hidden_size = self.bert_config.hidden_size
# ========== 训练Trainer相关配置 ==========
self.output_dir = "./training_output" # transformers训练输出文件夹
self.logging_dir = "./logs" # 日志文件保存目录
self.warmup_steps = 500 # 学习率预热步数(可视具体任务适当调整)
self.weight_decay = 0.01 # 权重衰减,防止过拟合
self.logging_steps = 100 # 多久打印一次训练日志
self.eval_steps = 500 # 每隔多少步进行一次评估
self.save_steps = 500 # 每隔多少步保存一次模型checkpoint
self.save_total_limit = 2 # 只保留最近N个训练模型,防止磁盘爆满
if __name__ == "__main__":
# 测试用途:打印部分关键信息以确认配置加载无误
conf = Config()
print("BERT模型结构配置:\n", conf.bert_config)
print("BERT模型结构:\n", conf.bert_model)
# 测试分词器将中文token转换为BERT词表下的ID
input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中国", "人"])
print("分词器ID编码示例:", input_size)
print("类别列表:", conf.class_list)
utils文件
python
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from config import Config
# 实例化config类对象
conf = Config()
# todo:加载数据集
def load_data(path):
"""
:param path: 文件路径
:return: [(文本句子, 标签下标), (文本句子, 标签下标), ...]
"""
datas_list = []
# 读取文件数据集
with open(path, 'r', encoding='utf-8') as f:
# 循环遍历文件中的每一行
for line in tqdm(f, desc='Loading data'):
# 去掉末尾换行符
line = line.strip()
# 判断行数据是否为空
# 为空跳过
if not line:
continue
# 不为空, 进行分割 \t
text, label = line.split('\t')
# 将分割结果保存到元组并保存到列表中
datas_list.append((text, int(label)))
return datas_list
# todo:构建dataset数据集对象
class TextDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, item):
"""
:param item: 数据集中行索引 样本索引
:return:
"""
# 获取样本数据中的x和y两部分
x = self.data[item][0]
y = self.data[item][1]
return x, y
# 封装函数, 获取三份数据集对象
def build_datasets():
train_data = load_data(conf.train_path)
test_data = load_data(conf.test_path)
dev_data = load_data(conf.dev_path)
train_dataset = TextDataset(train_data)
test_dataset = TextDataset(test_data)
dev_dataset = TextDataset(dev_data)
# Trainer参数要求, 接收dataset数据集对象
return train_dataset, test_dataset, dev_dataset
# todo:构建dataloader数据加载器
# collate_fn自定义函数
def collate_fn(batch):
"""
批次样本数据处理
:param batch: 批次样本 [(文本句子, 标签下标), (文本句子, 标签下标), ...]
:return: {input_ids:xxx, attention_mask:xxx, labels:xxx}
"""
# print('batch--->\n', batch)
# 获取批次样本的texts和labels两部分数据, 存储到两个列表中
texts = [item[0] for item in batch]
labels = [item[1] for item in batch]
# print('texts--->\n', texts)
# print('labels--->\n', labels)
# 通过分词器将texts进行数据处理
# inputs = conf.tokenizer(texts,
# padding='max_length',
# truncation=True,
# max_length=conf.pad_size,
# return_tensors='pt')
inputs = conf.tokenizer(texts,
padding=True,
return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
# labels转换成张量对象
labels = torch.tensor(data=labels, dtype=torch.long)
# 返回字典, 后续trainer对象中模型预测是通过 model(**inputs) 方式实现, 对字典进行拆包
return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}
def build_dataloaders():
# 加载数据集
train_dataset, test_dataset, dev_dataset = build_datasets()
# 创建dataloader对象
train_dataloader = DataLoader(train_dataset,
batch_size=conf.batch_size,
shuffle=True,
collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset,
batch_size=conf.batch_size,
shuffle=False,
collate_fn=collate_fn)
dev_dataloader = DataLoader(dev_dataset,
batch_size=conf.batch_size,
shuffle=False,
collate_fn=collate_fn)
return train_dataloader, test_dataloader, dev_dataloader
if __name__ == '__main__':
train_dataloader, test_dataloader, dev_dataloader = build_dataloaders()
for i in train_dataloader:
print(i['input_ids'].shape, i['input_ids'])
print(i['attention_mask'])
print(i['labels'])
exit()
model定义文件
python
import torch
import torch.nn as nn
from transformers import PreTrainedModel, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from config import Config
from utils import build_dataloaders
# 实例化全局配置对象,包含模型参数、类别数、device等关键信息
conf = Config()
# 创建自定义网络模型类继承PreTrainedModel
class BertClassifier(PreTrainedModel):
# 指定本模型所对应的配置类,供 from_pretrained 使用
config_class = BertConfig
# 指定基础模型前缀,帮助权重加载时对齐 state_dict 键前缀(如 "bert.")
base_model_prefix = "bert"
# init方法
def __init__(self, config=None):
if config is None:
config = conf.bert_config
# 调用父类PreTrainedModel的init方法, 初始化模型参数config
super(BertClassifier, self).__init__(config)
# 实例化预训练模型结构
self.bert = conf.bert_model
# 实例化输出层
self.fc = nn.Linear(conf.hidden_size, conf.num_classes)
# forward方法
def forward(self, input_ids, attention_mask, labels=None, return_dict=True):
# 预训练模型计算
outputs, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
# print('outputs--->\n', outputs.shape, outputs)
# print('pooled_output--->\n', pooled_output.shape, pooled_output)
# 输出层计算
logits = self.fc(pooled_output)
# print('logits--->\n', logits.shape, logits)
# 计算损失值
loss = None
if labels is not None:
# 实例化损失器对象
criterion = nn.CrossEntropyLoss()
# 调用对象
# 损失器形状要求: 预测值(batch_size, num_classes) 真实值(batch_size,)
loss = criterion(logits.view(-1, conf.num_classes), labels.view(-1))
if return_dict: # 为True时, 返回SequenceClassifierOutput对象
return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=None, attentions=None)
else: # 返False时, 并且labels不为None时, 返回(loss, logits), labels为None时, 仅返回logits
return (loss, logits) if labels is not None else (logits,)
if __name__ == '__main__':
# 构建数据加载器
train_dataloader, test_dataloader, dev_dataloader = build_dataloaders()
# 实例化模型对象
model = BertClassifier().to(conf.device)
# print('model--->\n', model)
# 循环遍历数据加载器
for batch in train_dataloader:
input_ids = batch['input_ids'].to(conf.device)
attention_mask = batch['attention_mask'].to(conf.device)
labels = batch['labels'].to(conf.device)
output = model(input_ids, attention_mask, labels=labels, return_dict=True)
print('output--->\n', type(output), output)
logits = output.logits
print('logits--->\n', logits.shape, logits)
loss = output.loss
print('loss--->\n', loss.shape, loss)
exit()
模型训练
python
from sklearn.metrics import f1_score, accuracy_score, precision_score
from transformers import TrainingArguments, Trainer
from config import Config
from utils import build_datasets, collate_fn
from bert_classifier_model import BertClassifier
import warnings
warnings.filterwarnings("ignore")
# 加载配置对象,包含模型、数据路径、训练超参数等
conf = Config()
# 评估函数
def compute_metrics(eval_preds):
"""
:param eval_preds: 固定参数, 固定格式 元组类型(预测值logits, 真实值labels)
:return: 评估指标
"""
predictions, labels = eval_preds
# 将logits转换为分类id
predictions = predictions.argmax(axis=-1)
# 微平均 F1
f1 = f1_score(labels, predictions, average="micro")
# 总体准确率
accuracy = accuracy_score(labels, predictions)
# 微平均精确率
precision = precision_score(labels, predictions, average="micro")
# 返回供Trainer自动记录的指标(eval_ 前缀不可变)
return {"eval_f1": f1, "eval_accuracy": accuracy, "eval_precision": precision}
# 训练函数
def model2train():
# 加载数据集对象
train_dataset, test_dataset, dev_dataset = build_datasets()
# 实例化模型对象
model = BertClassifier() # 不需要选择设备, 不需要切换模型模式
# 实例化模型参数对象
train_args = TrainingArguments(output_dir=conf.output_dir, # 输出目录
num_train_epochs=conf.num_epochs, # 训练轮数
per_device_train_batch_size=conf.batch_size, # 每卡/每设备训练batch
per_device_eval_batch_size=conf.batch_size, # 每卡/每设备验证batch
warmup_steps=conf.warmup_steps, # 学习率预热步数
weight_decay=conf.weight_decay, # 权重衰减
learning_rate=conf.learning_rate, # 学习率
logging_dir=conf.logging_dir, # 日志输出目录
logging_steps=conf.logging_steps, # 日志打印间隔(步)
eval_strategy="steps", # 评估触发模式(新参数名) 旧参数名:valuation_strategy
eval_steps=conf.eval_steps, # 评估间隔(步)
save_strategy="steps", # 保存模式
save_steps=conf.save_steps, # 保存间隔(步)
save_total_limit=conf.save_total_limit, # 最多保存几个模型
load_best_model_at_end=True, # 训练结束后自动恢复最优模型
metric_for_best_model="eval_f1", # 以何指标为"最优模型"
greater_is_better=True) # 指标越大越好
# 实例化训练器对象 trainer对象
trainer = Trainer(model=model,
tokenizer=conf.tokenizer,
args=train_args,
train_dataset=train_dataset,
eval_dataset=dev_dataset,
compute_metrics=compute_metrics,
data_collator=collate_fn)
# 训练模型
print("开始训练...")
trainer.train()
# 保存模型
trainer.save_model(conf.model_save_path)
print(f"模型已保存到: {conf.model_save_path}")
# 模型测试
print("在测试集上评估...")
test_results = trainer.evaluate(test_dataset)
print("测试集结果:")
for key, value in test_results.items():
print(f"{key}: {value:.4f}")
if __name__ == "__main__":
# 主程序入口:调用训练主流程
print("开始使用TrainingArguments和Trainer进行训练...")
model2train()
print("\n训练完成!")
开放模型推理
python
import torch
from bert_classifier_model import BertClassifier
from config import Config
# 加载配置对象(Config类负责所有全局超参数、路径、类别名、分词器等)
conf = Config()
device = conf.device # 设备(cuda/cpu)
tokenizer = conf.tokenizer # BERT分词器
# 加载模型对象
# model = BertClassifier.from_pretrained(r'E:\TMF\code\04-bert\save_models\bertclassifier_model_20251017')
model = BertClassifier.from_pretrained(conf.model_save_path).to(conf.device)
model.eval()
# 封装推理函数
def predict(text):
"""
:param text: {text:xxxxx}
:return: {text:xxxx, pred_class:xxx}
"""
# 获取文本数据 x
text = text.get('text', "")
# print('text--->\n', text)
# 判断x数据类型 以及 是否为空, 返回预测值None
if not isinstance(text, str) or not text.strip():
return {'text': text, 'pred_class': None}
# 调用分词器对象进行处理
inputs = tokenizer.encode_plus(text, return_tensors='pt')
# print('inputs--->\n', inputs)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
# 模型预测
with torch.no_grad():
outputs = model(input_ids, attention_mask)
# print('outputs--->\n', outputs)
# isinstance:判断对象是否为指定类型
# hasattr: 判断对象是否包含指定属性
if isinstance(outputs, tuple) or hasattr(outputs, 'logits'):
output =outputs.logits if hasattr(outputs, 'logits') else outputs[0]
else:
output = outputs
# 获取预测下标
pred_index = torch.argmax(output, dim=-1)
# print('pred_index--->\n', pred_index)
# 获取预测类别名称
pred_class = conf.class_list[pred_index]
# print('pred_class--->\n', pred_class)
return {'text': text, 'pred_class': pred_class}
if __name__ == '__main__':
# 测试示例
sample_data = {"text": "中华女子学院:本科层次仅1专业招男生"} # 示例输入数据
result = predict(sample_data)
print(result)
后端api服务
python
from flask import Flask, request, jsonify
from predict_fun import predict
import warnings
warnings.filterwarnings('ignore')
# todo:1-创建app对象
app = Flask(__name__)
# todo:2-创建路由
@app.route('/predict', methods=['POST'])
def predict_api():
# 获取前端数据
data = request.get_json()
print('data--->\n', data)
# 判断是否有数据, 没有收集异常信息
if not data or 'text' not in data:
# 状态码: 2xx->请求成功 3xx->重定向 4xx->请求端报错 5xx->服务端报错
return jsonify({'error': 'Missing text field in JSON'}), 400
# 调用模型预测接口实现预测
result = predict(data)
print('result--->\n', result)
# 返回json结果
return jsonify(result)
if __name__ == '__main__':
# 启动服务端
app.run(host='0.0.0.0', port=8000, debug=True)
后端api测试
python
# 不要求掌握
import requests
import time
# 定义预测接口地址
url = 'http://127.0.0.1:8000/predict'
# 构造请求数据
data = {'text': "中国人民公安大学2012年硕士研究生目录及书目"}
start_time = time.time()
try:
# 发送post请求, 获取响应对象
response = requests.post(url, json=data)
print('response--->\n', response)
# 耗时
duration = (time.time() - start_time) * 1000 # ms
print(f'耗时: {duration:.2f}ms')
# 判断状态码是否为200, 如果是, 获取响应数据
if response.status_code == 200:
result = response.json()
print('result--->\n', type(result), result)
print('预测结果--->\n', result['pred_class'])
# 如果不是, 获取错误信息
else:
error = response.json()['error']
print(print(f"请求失败: {response.status_code}, {error}"))
except Exception as e:
print(f"请求出错: {str(e)}")
前端服务
python
import streamlit as st
import requests
import time
# todo:1-设置页面标题
st.title('文本分类系统')
# todo:2-创建输入框
data_text = st.text_area('请输入预测文本:', "中国人民公安大学2012年硕士研究生目录及书目")
# todo:3-创建预测按钮
if st.button('预测'):
# todo:4-调用模型推理接口实现预测
start_time = time.time()
try:
# 构造请求数据
data = {'text': data_text}
url = 'http://127.0.0.1:8000/predict'
# 发送post请求, 获取响应对象
response = requests.post(url, json=data)
duration = (time.time() - start_time) * 1000
# 判断状态码是否为200
if response.status_code == 200:
result = response.json()
# todo:5-显示预测结果
st.success(f"预测结果: {result['pred_class']}")
st.info(f"请求耗时: {duration:.2f}ms")
else:
st.error(f"请求失败: {response.json()['error']}")
except Exception as e:
st.error(f"请求出错: {str(e)}")
# todo:6-页面提示内容
st.write("请确保 Flask API 服务已在 localhost:8000 运行")
