基于huggingface库Trainer实现Bert文本分类实战

Trainer介绍

Trainer 是 Hugging Face transformers 库中的一个核心API,它为PyTorch模型提供了一个功能完整的训练和评估循环。它的主要目标是简化训练流程,让你不需要手动编写繁琐的训练代码,可以更专注于模型、数据和参数本身。

简单来说,你只需要把模型、数据集和训练参数"喂"给 Trainer,调用 .train() 方法,它就会自动处理背后几乎所有复杂的事情,比如:

  • 自动训练循环 :自动处理前向传播、计算损失、反向传播和梯度更新

  • 分布式训练支持:轻松地在多GPU、TPU上进行训练。

  • 混合精度训练 :通过简单的参数开启 fp16bf16 训练,以加速训练并节省显存。

  • 回调与日志:支持TensorBoard等日志工具,并可通过回调自定义行为。

  • 使用流程如下:

    1. 加载数据集

    2. 数据预处理

    3. 准备训练参数

    4. 准备模型

    5. 创建Trainer并开始训练

config文件

python 复制代码
import torch
import datetime
from transformers.models import BertModel, BertTokenizer, BertConfig

# 获取当前日期字符串,用于模型文件等命名
# print('当前日期--->\n', datetime.datetime.now().date())
current_date = datetime.datetime.now().date().strftime("%Y%m%d")
# print('当前日期--->\n', type(current_date), current_date)


class Config(object):
    """
    配置类Config

    该类用于集中管理项目开发和训练/推理阶段涉及到的所有重要参数,包括数据路径、模型参数、类别信息等。

    属性说明:
        model_name (str):         模型名称(一般为'bert')。
        data_path (str):          数据集存放的主目录路径。
        train_path (str):         训练集文件路径,格式通常为每行'文本\t标签'。
        dev_path (str):           验证集(开发集)文件路径,用于模型调优。
        test_path (str):          测试集文件路径,用于最终评估。
        class_path (str):         类别文件路径,存放所有标签类别(每行一类)。
        class_list (List[str]):   标签类别列表,从class_path读取,每项为类别名。
        model_save_path (str):    模型训练好后权重/配置的保存路径。
        device (torch.device):    训练/推理时使用的硬件设备(cpu或cuda)。
        num_classes (int):        类别数,根据class_list自适应计算。
        num_epochs (int):         训练总轮数。
        batch_size (int):         每一批次(batch)的数据条数。
        pad_size (int):           句子最大填充/截断长度(超出截断,不足补0)。
        learning_rate (float):    优化器学习率。
        bert_path (str):          预训练BERT模型的文件目录。
        bert_model (BertModel):   加载的BERT主干神经网络(transformers实现)。
        tokenizer (BertTokenizer):BERT分词器(与模型完全对应)。
        bert_config (BertConfig): BERT结构参数对象,方便后续定义自有模型头部。
        hidden_size (int):        BERT编码输出的向量维度(base版通常为768)。
        output_dir (str):         transformers Trainer输出目录(如模型、日志等)。
        logging_dir (str):        日志保存目录。
        warmup_steps (int):       学习率预热步数。
        weight_decay (float):     优化器的权重衰减系数。
        logging_steps (int):      打印日志的步频。
        eval_steps (int):         验证评估的间隔步数。
        save_steps (int):         检查点保存的间隔步数。
        save_total_limit (int):   最多仅保留多少个最新模型,超过会自动淘汰旧文件。
    """

    def __init__(self):
        # ========== 路径相关 ==========
        self.model_name = "bert"  # 模型名称前缀(可用于文件命名等)
        self.data_path = "../../01-data"  # 数据集存放根目录
        self.train_path = self.data_path + "/train.txt"  # 训练集文件完整路径
        self.dev_path = self.data_path + "/dev3.txt"  # 验证集文件路径
        self.test_path = self.data_path + "/test.txt"  # 测试集文件路径

        self.class_path = self.data_path + "/class.txt"  # 存储类别标签名称的文本文件路径
        # 读取类别列表,每行一个类别
        # 结果如 ['体育', '财经', ...], 按文件顺序
        self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]

        self.model_save_path = (f"../save_models/bertclassifier_model_{current_date}")  # 模型训练后保存的主文件夹路径

        # ========== 硬件参数 ==========
        # 检测cuda(GPU)是否可用,优先用GPU,否则回退为CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设备类型

        # ========== 数据/超参数 ==========
        self.num_classes = len(self.class_list)  # 标签类别个数(自动适应类别文件内容)
        self.num_epochs = 2  # 训练总轮次,轮数通常取决于数据量可调整
        self.batch_size = 8  # 单次batch处理的样本数
        self.pad_size = 32  # 每句话最大长度(超出则截断,短的补齐)
        self.learning_rate = 5e-5  # 优化器学习率

        # ========== 预训练BERT模型参数 ==========
        self.bert_path = "../bert-base-chinese"  # 磁盘中的预训练BERT主目录
        # 加载BERT主干模型(transformers BertModel),需要与任务gpu/cpu匹配
        self.bert_model = BertModel.from_pretrained(self.bert_path)
        # 加载与BERT结构匹配的分词器
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        # 加载BERT配置对象(如hidden_size、层数等,可以用于自定义模型)
        self.bert_config = BertConfig.from_pretrained(self.bert_path)
        # 通常base模型为768,large则为1024
        # self.hidden_size = 768
        self.hidden_size = self.bert_config.hidden_size

        # ========== 训练Trainer相关配置 ==========
        self.output_dir = "./training_output"  # transformers训练输出文件夹
        self.logging_dir = "./logs"  # 日志文件保存目录
        self.warmup_steps = 500  # 学习率预热步数(可视具体任务适当调整)
        self.weight_decay = 0.01  # 权重衰减,防止过拟合
        self.logging_steps = 100  # 多久打印一次训练日志
        self.eval_steps = 500  # 每隔多少步进行一次评估
        self.save_steps = 500  # 每隔多少步保存一次模型checkpoint
        self.save_total_limit = 2  # 只保留最近N个训练模型,防止磁盘爆满


if __name__ == "__main__":
    # 测试用途:打印部分关键信息以确认配置加载无误
    conf = Config()
    print("BERT模型结构配置:\n", conf.bert_config)
    print("BERT模型结构:\n", conf.bert_model)
    # 测试分词器将中文token转换为BERT词表下的ID
    input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中国", "人"])
    print("分词器ID编码示例:", input_size)
    print("类别列表:", conf.class_list)

utils文件

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from config import Config

# 实例化config类对象
conf = Config()


# todo:加载数据集
def load_data(path):
    """
    :param path: 文件路径
    :return: [(文本句子, 标签下标), (文本句子, 标签下标), ...]
    """
    datas_list = []
    # 读取文件数据集
    with open(path, 'r', encoding='utf-8') as f:
        # 循环遍历文件中的每一行
        for line in tqdm(f, desc='Loading data'):
            # 去掉末尾换行符
            line = line.strip()
            # 判断行数据是否为空
            # 为空跳过
            if not line:
                continue
            # 不为空, 进行分割 \t
            text, label = line.split('\t')
            # 将分割结果保存到元组并保存到列表中
            datas_list.append((text, int(label)))

    return datas_list


# todo:构建dataset数据集对象
class TextDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """
        :param item: 数据集中行索引 样本索引
        :return:
        """
        # 获取样本数据中的x和y两部分
        x = self.data[item][0]
        y = self.data[item][1]
        return x, y


# 封装函数, 获取三份数据集对象
def build_datasets():
    train_data = load_data(conf.train_path)
    test_data = load_data(conf.test_path)
    dev_data = load_data(conf.dev_path)

    train_dataset = TextDataset(train_data)
    test_dataset = TextDataset(test_data)
    dev_dataset = TextDataset(dev_data)
    # Trainer参数要求, 接收dataset数据集对象
    return train_dataset, test_dataset, dev_dataset


# todo:构建dataloader数据加载器
# collate_fn自定义函数
def collate_fn(batch):
    """
    批次样本数据处理
    :param batch: 批次样本 [(文本句子, 标签下标), (文本句子, 标签下标), ...]
    :return: {input_ids:xxx, attention_mask:xxx, labels:xxx}
    """
    # print('batch--->\n', batch)
    # 获取批次样本的texts和labels两部分数据, 存储到两个列表中
    texts = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # print('texts--->\n', texts)
    # print('labels--->\n', labels)

    # 通过分词器将texts进行数据处理
    # inputs = conf.tokenizer(texts,
    #                         padding='max_length',
    #                         truncation=True,
    #                         max_length=conf.pad_size,
    #                         return_tensors='pt')
    inputs = conf.tokenizer(texts,
                            padding=True,
                            return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # labels转换成张量对象
    labels = torch.tensor(data=labels, dtype=torch.long)

    # 返回字典, 后续trainer对象中模型预测是通过 model(**inputs) 方式实现, 对字典进行拆包
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


def build_dataloaders():
    # 加载数据集
    train_dataset, test_dataset, dev_dataset = build_datasets()

    # 创建dataloader对象
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=conf.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=conf.batch_size,
                                 shuffle=False,
                                 collate_fn=collate_fn)

    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=conf.batch_size,
                                shuffle=False,
                                collate_fn=collate_fn)

    return train_dataloader, test_dataloader, dev_dataloader


if __name__ == '__main__':
    train_dataloader, test_dataloader, dev_dataloader = build_dataloaders()

    for i in train_dataloader:
        print(i['input_ids'].shape, i['input_ids'])
        print(i['attention_mask'])
        print(i['labels'])
        exit()

model定义文件

python 复制代码
import torch
import torch.nn as nn
from transformers import PreTrainedModel, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from config import Config
from utils import build_dataloaders

# 实例化全局配置对象,包含模型参数、类别数、device等关键信息
conf = Config()


# 创建自定义网络模型类继承PreTrainedModel
class BertClassifier(PreTrainedModel):
    # 指定本模型所对应的配置类,供 from_pretrained 使用
    config_class = BertConfig
    # 指定基础模型前缀,帮助权重加载时对齐 state_dict 键前缀(如 "bert.")
    base_model_prefix = "bert"

    # init方法
    def __init__(self, config=None):
        if config is None:
            config = conf.bert_config
        # 调用父类PreTrainedModel的init方法, 初始化模型参数config
        super(BertClassifier, self).__init__(config)

        # 实例化预训练模型结构
        self.bert = conf.bert_model

        # 实例化输出层
        self.fc = nn.Linear(conf.hidden_size, conf.num_classes)

    # forward方法
    def forward(self, input_ids, attention_mask, labels=None, return_dict=True):
        # 预训练模型计算
        outputs, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        # print('outputs--->\n', outputs.shape, outputs)
        # print('pooled_output--->\n', pooled_output.shape, pooled_output)

        # 输出层计算
        logits = self.fc(pooled_output)
        # print('logits--->\n', logits.shape, logits)

        # 计算损失值
        loss = None
        if labels is not None:
            # 实例化损失器对象
            criterion = nn.CrossEntropyLoss()
            # 调用对象
            # 损失器形状要求: 预测值(batch_size, num_classes) 真实值(batch_size,)
            loss = criterion(logits.view(-1, conf.num_classes), labels.view(-1))

        if return_dict:  # 为True时, 返回SequenceClassifierOutput对象
            return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=None, attentions=None)
        else:  # 返False时, 并且labels不为None时, 返回(loss, logits), labels为None时, 仅返回logits
            return (loss, logits) if labels is not None else (logits,)


if __name__ == '__main__':
    # 构建数据加载器
    train_dataloader, test_dataloader, dev_dataloader = build_dataloaders()
    # 实例化模型对象
    model = BertClassifier().to(conf.device)
    # print('model--->\n', model)
    # 循环遍历数据加载器
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(conf.device)
        attention_mask = batch['attention_mask'].to(conf.device)
        labels = batch['labels'].to(conf.device)
        output = model(input_ids, attention_mask, labels=labels, return_dict=True)
        print('output--->\n', type(output), output)
        logits = output.logits
        print('logits--->\n', logits.shape, logits)
        loss = output.loss
        print('loss--->\n', loss.shape, loss)
        exit()

模型训练

python 复制代码
from sklearn.metrics import f1_score, accuracy_score, precision_score
from transformers import TrainingArguments, Trainer
from config import Config
from utils import build_datasets, collate_fn
from bert_classifier_model import BertClassifier
import warnings

warnings.filterwarnings("ignore")

# 加载配置对象,包含模型、数据路径、训练超参数等
conf = Config()

# 评估函数
def compute_metrics(eval_preds):
    """
    :param eval_preds: 固定参数, 固定格式 元组类型(预测值logits, 真实值labels)
    :return: 评估指标
    """
    predictions, labels = eval_preds
    # 将logits转换为分类id
    predictions = predictions.argmax(axis=-1)
    # 微平均 F1
    f1 = f1_score(labels, predictions, average="micro")
    # 总体准确率
    accuracy = accuracy_score(labels, predictions)
    # 微平均精确率
    precision = precision_score(labels, predictions, average="micro")
    # 返回供Trainer自动记录的指标(eval_ 前缀不可变)
    return {"eval_f1": f1, "eval_accuracy": accuracy, "eval_precision": precision}

# 训练函数
def model2train():
    # 加载数据集对象
    train_dataset, test_dataset, dev_dataset = build_datasets()
    # 实例化模型对象
    model = BertClassifier()  # 不需要选择设备, 不需要切换模型模式

    # 实例化模型参数对象
    train_args = TrainingArguments(output_dir=conf.output_dir,  # 输出目录
                                   num_train_epochs=conf.num_epochs,  # 训练轮数
                                   per_device_train_batch_size=conf.batch_size,  # 每卡/每设备训练batch
                                   per_device_eval_batch_size=conf.batch_size,  # 每卡/每设备验证batch
                                   warmup_steps=conf.warmup_steps,  # 学习率预热步数
                                   weight_decay=conf.weight_decay,  # 权重衰减
                                   learning_rate=conf.learning_rate,  # 学习率
                                   logging_dir=conf.logging_dir,  # 日志输出目录
                                   logging_steps=conf.logging_steps,  # 日志打印间隔(步)
                                   eval_strategy="steps",  # 评估触发模式(新参数名)  旧参数名:valuation_strategy
                                   eval_steps=conf.eval_steps,  # 评估间隔(步)
                                   save_strategy="steps",  # 保存模式
                                   save_steps=conf.save_steps,  # 保存间隔(步)
                                   save_total_limit=conf.save_total_limit,  # 最多保存几个模型
                                   load_best_model_at_end=True,  # 训练结束后自动恢复最优模型
                                   metric_for_best_model="eval_f1",  # 以何指标为"最优模型"
                                   greater_is_better=True)  # 指标越大越好

    # 实例化训练器对象 trainer对象
    trainer = Trainer(model=model,
                      tokenizer=conf.tokenizer,
                      args=train_args,
                      train_dataset=train_dataset,
                      eval_dataset=dev_dataset,
                      compute_metrics=compute_metrics,
                      data_collator=collate_fn)

    # 训练模型
    print("开始训练...")
    trainer.train()

    # 保存模型
    trainer.save_model(conf.model_save_path)
    print(f"模型已保存到: {conf.model_save_path}")

    # 模型测试
    print("在测试集上评估...")
    test_results = trainer.evaluate(test_dataset)
    print("测试集结果:")
    for key, value in test_results.items():
        print(f"{key}: {value:.4f}")


if __name__ == "__main__":
    # 主程序入口:调用训练主流程
    print("开始使用TrainingArguments和Trainer进行训练...")
    model2train()
    print("\n训练完成!")

开放模型推理

python 复制代码
import torch
from bert_classifier_model import BertClassifier
from config import Config

# 加载配置对象(Config类负责所有全局超参数、路径、类别名、分词器等)
conf = Config()
device = conf.device  # 设备(cuda/cpu)
tokenizer = conf.tokenizer  # BERT分词器


# 加载模型对象
# model = BertClassifier.from_pretrained(r'E:\TMF\code\04-bert\save_models\bertclassifier_model_20251017')
model = BertClassifier.from_pretrained(conf.model_save_path).to(conf.device)
model.eval()

# 封装推理函数
def predict(text):
    """
    :param text: {text:xxxxx}
    :return: {text:xxxx, pred_class:xxx}
    """
    # 获取文本数据 x
    text = text.get('text', "")
    # print('text--->\n', text)
    # 判断x数据类型 以及 是否为空, 返回预测值None
    if not isinstance(text, str) or not text.strip():
        return {'text': text, 'pred_class': None}
    # 调用分词器对象进行处理
    inputs = tokenizer.encode_plus(text, return_tensors='pt')
    # print('inputs--->\n', inputs)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # 模型预测
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        # print('outputs--->\n', outputs)
        # isinstance:判断对象是否为指定类型
        # hasattr: 判断对象是否包含指定属性
        if isinstance(outputs, tuple) or hasattr(outputs, 'logits'):
            output =outputs.logits if hasattr(outputs, 'logits') else outputs[0]
        else:
            output = outputs
        # 获取预测下标
        pred_index = torch.argmax(output, dim=-1)
        # print('pred_index--->\n', pred_index)
        # 获取预测类别名称
        pred_class = conf.class_list[pred_index]
        # print('pred_class--->\n', pred_class)
        return {'text': text, 'pred_class': pred_class}


if __name__ == '__main__':
    # 测试示例
    sample_data = {"text": "中华女子学院:本科层次仅1专业招男生"}  # 示例输入数据
    result = predict(sample_data)
    print(result)

后端api服务

python 复制代码
from flask import Flask, request, jsonify
from predict_fun import predict
import warnings

warnings.filterwarnings('ignore')

# todo:1-创建app对象
app = Flask(__name__)


# todo:2-创建路由
@app.route('/predict', methods=['POST'])
def predict_api():
    # 获取前端数据
    data = request.get_json()
    print('data--->\n', data)
    # 判断是否有数据, 没有收集异常信息
    if not data or 'text' not in data:
        # 状态码: 2xx->请求成功 3xx->重定向 4xx->请求端报错 5xx->服务端报错
        return jsonify({'error': 'Missing text field in JSON'}), 400
    # 调用模型预测接口实现预测
    result = predict(data)
    print('result--->\n', result)
    # 返回json结果
    return jsonify(result)


if __name__ == '__main__':
    # 启动服务端
    app.run(host='0.0.0.0', port=8000, debug=True)

后端api测试

python 复制代码
# 不要求掌握
import requests
import time

# 定义预测接口地址
url = 'http://127.0.0.1:8000/predict'

# 构造请求数据
data = {'text': "中国人民公安大学2012年硕士研究生目录及书目"}

start_time = time.time()

try:
    # 发送post请求, 获取响应对象
    response = requests.post(url, json=data)
    print('response--->\n', response)
    # 耗时
    duration = (time.time() - start_time) * 1000  # ms
    print(f'耗时: {duration:.2f}ms')
    # 判断状态码是否为200, 如果是, 获取响应数据
    if response.status_code == 200:
        result = response.json()
        print('result--->\n', type(result), result)
        print('预测结果--->\n', result['pred_class'])
    # 如果不是, 获取错误信息
    else:
        error = response.json()['error']
        print(print(f"请求失败: {response.status_code}, {error}"))
except Exception as e:
    print(f"请求出错: {str(e)}")

前端服务

python 复制代码
import streamlit as st
import requests
import time

# todo:1-设置页面标题
st.title('文本分类系统')

# todo:2-创建输入框
data_text = st.text_area('请输入预测文本:', "中国人民公安大学2012年硕士研究生目录及书目")

# todo:3-创建预测按钮
if st.button('预测'):
    # todo:4-调用模型推理接口实现预测
    start_time = time.time()
    try:
        # 构造请求数据
        data = {'text': data_text}
        url = 'http://127.0.0.1:8000/predict'
        # 发送post请求, 获取响应对象
        response = requests.post(url, json=data)
        duration = (time.time() - start_time) * 1000
        # 判断状态码是否为200
        if response.status_code == 200:
            result = response.json()
            # todo:5-显示预测结果
            st.success(f"预测结果: {result['pred_class']}")
            st.info(f"请求耗时: {duration:.2f}ms")
        else:
            st.error(f"请求失败: {response.json()['error']}")
    except Exception as e:
        st.error(f"请求出错: {str(e)}")

# todo:6-页面提示内容
st.write("请确保 Flask API 服务已在 localhost:8000 运行")
相关推荐
Gale2World2 小时前
专题九:【终局演进】从“单体网关”到去中心化集群:分布式数字员工(Swarm)的宏大涌现
人工智能·agent
天天代码码天天2 小时前
C# OnnxRuntime BEN2 前景分割
人工智能
moers2 小时前
从cosh到AgentSecCore:拆解阿里云Agentic OS的四个技术决策
人工智能
饼干哥哥2 小时前
RPA也被AI干死了!!一键生成监听100个小红书博主的工作流
人工智能
前端付豪2 小时前
实现消息级操作栏
前端·人工智能·后端
HarryPoint2 小时前
🔥Claude Code 源码分析报告新鲜出炉了
人工智能
Clavis2 小时前
我给 Mac 的 Photo Booth 写了自动化脚本。为什么隐私比你想的重要得多
人工智能·python
AI问答工程师2 小时前
从"检索一次就完事"到"Agent 自主决策":Agentic RAG 架构深度解析与实战
人工智能
Codebee2 小时前
当软件从"工具"进化为"伙伴"ooderAgent 产品设计解析
人工智能