基于Bert的模型迁移文本分类项目

一个完整的 BERT 文本分类系统,涵盖数据加载、模型训练、验证评估、模型保存、API 部署和前端展示。代码采用模块化设计,支持多卡训练(accelerate),每 100 个 batch 验证一次并保存最优模型。后续计划加入 TensorBoard 日志、单元测试和 Docker 部署
config文件实现文件统一分发,包括文件路径和bert模型超参以及bert模型的

bert_model 预训练bert模型

tokenizer bert分词器

bert_config bert模型配置

hidden_size bert模型输出


utils实现文件加载封装,dataset封装,dataloader封装,dataloder自定义方法处理


BertClassifier类实现bert模型➕输出头结构全参微调


模型训练与评估实现

dataloader加载

模型加载

optimizer优化器加载

crossEntropyLoss损失函数加载

accelerator实现训练加速,实现多卡并行

模型训练

前向传播

损失计算

梯度清零

反向传播

参数更新

调用验证函数实现模型验证 并保存验证后优秀的模型

计算评估指标,打印评估报告


封装推理方法

基于flask实现后端接口api服务

基于streamlit 实现前端服务

Bert模型介绍

python 复制代码
"""
bert的模型结构:(输入样本数 batch_size为2 )
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
BertModel                                               [2, 768]                  --
├─BertEmbeddings: 1-1                                   [2, 128, 768]             --
│    └─Embedding: 2-1                                   [2, 128, 768]             16,226,304
│    └─Embedding: 2-2                                   [2, 128, 768]             1,536
│    └─Embedding: 2-3                                   [1, 128, 768]             393,216
│    └─LayerNorm: 2-4                                   [2, 128, 768]             1,536
│    └─Dropout: 2-5                                     [2, 128, 768]             --
├─BertEncoder: 1-2                                      [2, 128, 768]             --
│    └─ModuleList: 2-6                                  --                        --
│    │    └─BertLayer: 3-1                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-2                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-3                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-4                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-5                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-6                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-7                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-8                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-9                              [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-10                             [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-11                             [2, 128, 768]             7,087,872
│    │    └─BertLayer: 3-12                             [2, 128, 768]             7,087,872
├─BertPooler: 1-3                                       [2, 768]                  --
│    └─Linear: 2-7                                      [2, 768]                  590,592
│    └─Tanh: 2-8                                        [2, 768]                  --
=========================================================================================================
"""

config文件统一分发

config文件实现

文件路径 / 模型超参数 / bert_model / tokenizer / bert_config / hidden_size

等统一管理与分发

python 复制代码
import torch
import datetime
from transformers.models import BertModel, BertTokenizer, BertConfig

# 获取当前日期
# current_date = datetime.datetime.now().date().strftime("%Y%m%d")
# print('current_date--->', current_date)


class Config(object):
    def __init__(self):
        """
        配置类,包含模型和训练所需的各种参数。
        """
        self.model_name = "bert"  # 模型名称
        self.data_path = "../../01-data"  # 数据集的根路径
        self.train_path = self.data_path + "/train.txt"  # 训练集
        self.dev_path = self.data_path + "/dev3.txt"  # 少量验证集,快速验证
        self.test_path = self.data_path + "/test.txt"  # 测试集

        self.class_path = self.data_path + "/class.txt"  # 类别文件
        self.class_list = [line.strip() for line in open(self.class_path, 'r', encoding='utf-8')]

        self.model_save_path = "../save_models/bertclassifier_model.pt"  # 模型训练结果保存路径

        # 模型训练+预测的时候
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 训练设备,如果GPU可用,则为cuda,否则为cpu

        self.num_classes = len(self.class_list)  # 类别数
        self.num_epochs = 2  # epoch数
        self.batch_size = 32  # mini-batch大小
        self.pad_size = 32  # 每句话处理成的长度(短填长切)
        self.learning_rate = 5e-5  # 学习率
        self.bert_path = "../bert-base-chinese"  # 预训练BERT模型的路径
        self.bert_model = BertModel.from_pretrained(self.bert_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.bert_config = BertConfig.from_pretrained(self.bert_path)
        self.hidden_size = 768  # BERT模型的隐藏层大小
        # self.hidden_size = self.bert_config.hidden_size  # BERT模型的隐藏层大小


if __name__ == '__main__':
    conf = Config()
    print(conf.class_list)
    print(conf.bert_config)
    inputs = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中国", "人"])
    print(inputs)

utils封装

数据加载load_data方法封装

TextDataset 构建dataset类封装

build_dataloader数据集构建方法封装

collate_fn构建数据加载自定义函数封装

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from config import Config

# 实例化config类对象
conf = Config()


# todo:加载数据集
def load_data(path):
    """
    加载数据集, 进行格式转换
    :param path: 原始文件路径
    :return: [(句子1, 标签1), (句子2, 标签2), ...]
    """
    # todo:1-初始化空列表
    data_list = []
    # todo:2-加载数据集
    with open(path, 'r', encoding='utf-8') as f:
        # todo:3-按行处理数据
        for line in tqdm(f, desc='加载数据...'):
            # 去掉末尾换行符
            line = line.strip()
            # print('line--->\n', line)
            # 如果line为空, 跳出当前循环
            if not line:
                continue
            # 使用\t分割符进行分割处理
            # 返回列表, 进行列表拆包操作
            text, label = line.split('\t')
            # print('text--->\n', text)
            # print('label--->\n', label)
            # 将句子和标签以元组形式保存到列表中
            data_list.append((text, int(label)))

    return data_list


# todo:构建dataset类
class TextDataset(Dataset):
    # todo:1-init初始化方法
    def __init__(self, data):
        self.data = data

    # todo:2-len方法
    def __len__(self):
        return len(self.data)

    # todo:3-getitem方法
    def __getitem__(self, item):
        # 获取当前行样本的x和y部分
        x = self.data[item][0]
        # print('x--->\n', x)
        y = self.data[item][1]
        # print('y--->\n', y)
        return x, y


# todo:构建数据加载, 自定义函数
def collate_fn(batch):
    # print('batch--->\n', batch)
    # 获取批次的x和y数据保存到对应列表中
    texts = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # print('texts--->\n', texts)
    # print('labels--->\n', labels)

    # 通过分词器对象对x进行数据处理
    inputs = conf.tokenizer(texts, padding=True, return_tensors='pt')
    # print('inputs--->\n', inputs)
    input_ids = inputs['input_ids'].to(conf.device)
    attention_mask = inputs['attention_mask'].to(conf.device)

    # 对y转换成张量对象
    labels = torch.tensor(labels, device=conf.device)

    # 返回x和y张量对象
    return input_ids, attention_mask, labels


def build_dataloader():
    # 加载数据集
    train_data = load_data(conf.train_path)
    test_data = load_data(conf.test_path)
    dev_data = load_data(conf.dev_path)
    # print(train_data[:10])
    # print(test_data[:10])
    # print(dev_data[:10])

    # 实例化dataset对象
    train_dataset = TextDataset(train_data)
    # print('train_dataset--->', train_dataset)
    # print(len(train_dataset))
    # print(train_dataset[0])
    test_dataset = TextDataset(test_data)
    dev_dataset = TextDataset(dev_data)

    # 实例化数据加器对象
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=conf.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=conf.batch_size,
                                 shuffle=False,
                                 collate_fn=collate_fn)
    dev_dataloader = DataLoader(dataset=dev_dataset,
                                batch_size=conf.batch_size,
                                shuffle=False,
                                collate_fn=collate_fn)
    return train_dataloader, test_dataloader, dev_dataloader


if __name__ == '__main__':
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()
    # 循环遍历数据加载对象
    for input_ids, attention_mask, labels in train_dataloader:
        print('input_ids--->\n', input_ids)
        print('attention_mask--->\n', attention_mask)
        print('labels--->\n', labels)
        exit()

Bert模型➕输出头结构全参微调

python 复制代码
# 模型搭建: bert训练模型结构+输出头结构
# 全参微调(适用于小模型) -> 修改bert预训练模型的所有参数
import torch
import torch.nn as nn
from config import Config
from utils import build_dataloader

# 实例化config类对象
config = Config()


# 创建自定义模型类
class BertClassifier(nn.Module):
    # init方法
    def __init__(self):
        super().__init__()  # 调用父类方法
        # 预训练模型bert结构
        self.bert = config.bert_model
        # 下游任务的输出层结构
        # in_features: 上一层(预训练模型)的输出维度
        # out_features: 类别数
        self.fc = nn.Linear(in_features=config.hidden_size,
                            out_features=config.num_classes)

    # forward方法
    def forward(self, input_ids, attention_mask):
        # 获取bert预训练模型的输出(特征提取/语义向量表示) 全参微调,不进行冻结
        # return_dict=False: 返回一个元组(last_hidden_state, pooler_output)
        # pooler_output: (batch_size, hidden_dim) -> 对句子中最后一层隐层cls值进行了一次池化处理(线性映射) 句子语义表示
        last_hidden_output, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                                                      return_dict=False)
        # print('last_hidden_output--->\n', last_hidden_output.shape, last_hidden_output)
        # print('pooled_output--->\n', pooled_output.shape, pooled_output)

        # 计算文本类别预测结果
        # output: (32, 10)
        output = self.fc(pooled_output)
        # print('output--->\n', output.shape, output)
        return output


if __name__ == '__main__':
    # 实例化数据加载器对象
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()

    # 实例化模型对象
    model = BertClassifier().to(config.device)
    model.train()
    # print('model--->\n', model)

    # 循环遍历数据加载器对象
    for input_dis, attention_mask, labels in train_dataloader:
        # 调用模型进行进行训练
        output = model(input_dis, attention_mask)
        # 预测标签下标
        pred_labels = torch.argmax(input=output, dim=-1)
        print('pred_labels--->\n', pred_labels)
        print('labels--->\n', labels)
        exit()

模型训练与评估

python 复制代码
import torch
import torch.nn as nn
from torch.optim import AdamW
# 评估指标 分类报告 f1分数 准确率 精确率 召回率
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score
from tqdm import tqdm
from config import Config
from utils import build_dataloader
from bert_classifier_model import BertClassifier
from accelerate import Accelerator
# 忽略的警告信息
import warnings

warnings.filterwarnings("ignore")

# 实例化config类对象
config = Config()


# todo:1-训练函数
def model2train():
    # 构建数据加载器对象
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()

    # 获取config对象的属性
    epochs = config.num_epochs  # 训练轮次
    device = config.device  # 设备
    learning_rate = config.learning_rate  # 学习率
    model_save_path = config.model_save_path  # 模型保存路径

    accelerator = Accelerator()

    # 实例化自定义模型对象
    model = BertClassifier().to(device)
    model.train()

    # 实例化优化器 损失器
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(train_dataloader,
                                                                              dev_dataloader,
                                                                              model,
                                                                              optimizer)

    # 模型训练
    # 初始化最佳模型的f1分数, 默认为0
    best_dev_f1 = 0.0
    # 双层循环
    for epoch in range(epochs):
        total_loss = 0.0
        total_iters = 0
        # 预测标签和真实标签存储列表
        pred_labels_list, true_labels_list = [], []
        for batch, (input_ids, attention_mask, labels) in tqdm(enumerate(train_dataloader, start=1),
                                                               desc=f"Bert Classifier Training Epoch {epoch + 1}/{epochs}...."):
            # 前向传播
            pred_output = model(input_ids, attention_mask)
            # print('pred_output--->\n', pred_output.shape, pred_output)

            # 损失计算
            loss = criterion(pred_output, labels)
            # print('loss--->\n', loss)
            total_loss += loss.item()  # 累加损失
            total_iters += 1  # 累加批次数
            avg_loss = total_loss / total_iters  # 平均损失

            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            # loss.backward()
            accelerator.backward(loss)
            # 参数更新
            optimizer.step()

            # 获取预测标签下标
            pred_labels = pred_output.argmax(dim=-1)
            # print('pred_labels--->\n', pred_labels)
            # 将预测标签下标和真实标签下标保存到列表中
            pred_labels_list.extend(pred_labels.tolist())
            true_labels_list.extend(labels.tolist())
            # print('pred_labels_list--->\n', pred_labels_list)
            # print('true_labels_list--->\n', true_labels_list)

            # 打印训练信息
            if batch % 100 == 0:
                print(f"Epoch {epoch + 1}/{epochs}")
                print(f"Train Loss: {avg_loss:.4f}")
                # 调用验证函数实现模型验证
                report, f1score, accuracy, precision = model2dev(model, dev_dataloader)
                print(f"Dev f1score: {f1score}")
                print(f"Dev accuracy: {accuracy}")

                # 保存模型, 基于最高f1分数进行保存
                if f1score > best_dev_f1:
                    # 更新最佳f1分数
                    best_dev_f1 = f1score
                    torch.save(model.state_dict(), model_save_path)
                    print(f"Saved model to {model_save_path}")

        # 打印每轮分类评估报告
        train_report = classification_report(true_labels_list, pred_labels_list, labels=config.class_list, output_dict=True)
        print('train_report--->\n', train_report)


# todo:2-验证函数, 一边训练一边验证模型效果
def model2dev(model: BertClassifier, dataloader):
    # 模型切换成推理模式
    model.eval()
    # 准备两个列表, 保存预测标签和真实标签
    pred_labels_list, true_labels_list = [], []
    # 循环遍历集数据加载器对象
    for input_ids, attention_mask, labels in tqdm(dataloader, desc="Bert Classifier Evaluating..."):
        with torch.no_grad():
            # 模型预测
            logits = model(input_ids, attention_mask)
            # print('logits--->\n', logits.shape, logits)
            # 获取预测标签下标
            pred_labels = torch.argmax(logits, dim=-1)
            # 将预测标签下标和真实标签下标保存到列表中
            pred_labels_list.extend(pred_labels.tolist())
            true_labels_list.extend(labels.tolist())

    # 计算评估指标
    report = classification_report(true_labels_list, pred_labels_list)
    f1score = f1_score(true_labels_list, pred_labels_list, average='micro')
    accuracy = accuracy_score(true_labels_list, pred_labels_list)
    precision = precision_score(true_labels_list, pred_labels_list, average='micro')
    # 返回评估指标
    return report, f1score, accuracy, precision


if __name__ == '__main__':
    model2train()

    # 1. 加载测试集数据
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()
    # 2. 初始化 BERT 分类模型
    model = BertClassifier()
    # 3. 加载预训练模型权重
    model.load_state_dict(torch.load(config.model_save_path))
    # 4. 将模型移动到指定设备
    model.to(config.device)
    # 5. 在测试集上评估模型
    test_report, f1score, accuracy, precision = model2dev(model, test_dataloader)
    # 6. 打印测试集评估结果
    print("Test Set Evaluation:")
    print(f"Test F1: {f1score:.4f}")
    print("Test Classification Report:")
    print(test_report)

模型推理func封装

python 复制代码
import torch
from bert_classifier_model import BertClassifier
from config import Config
import time

# 初始化配置
conf = Config()
device = conf.device
tokenizer = conf.tokenizer
model_save_path = conf.model_save_path

# 实例化模型对象
model = BertClassifier().to(device)
# 加载最优模型
model.load_state_dict(torch.load(model_save_path))
model.eval()


# 推理函数
def predict(data):
    """
    :param data: dict类型 {text: xxxxxxxxx}
    :return: dict类型 {text: xxxxxxxxx, pred_class: xxx}
    """
    # 获取文本数据 原始x
    text = data['text']
    # print('text--->\n', text)
    if not text.strip():
        print('文本为空')
        return {'text': text, 'pred_class': None}

    # 通过分词器进行数据处理
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # 调用模型进行预测
    with torch.no_grad():
        # 开始时间
        start_time = time.time()
        # 模型预测
        logits = model(input_ids, attention_mask)
        # print('logits--->\n', logits)
        # 获取预测标签下标
        pred_label = torch.argmax(logits, dim=-1)
        # print('pred_label--->\n', pred_label)
        # 获取标签名
        pred_class = conf.class_list[pred_label.item()]
        # print('pred_class--->\n', pred_class)
        print('预测耗时--->\n', (time.time() - start_time) * 1000)
    return {'text': text, 'pred_class': pred_class}


if __name__ == '__main__':
    # 测试输入
    sample_data = {"text": "中华女子学院:本科层次仅1专业招男生"}
    result = predict(sample_data)
    print('result--->\n', result)

基于flask的后端接口api

python 复制代码
from flask import Flask, request, jsonify
from predict_fun import predict
import warnings

warnings.filterwarnings('ignore')

# todo:1-创建app对象
app = Flask(__name__)


# todo:2-创建路由
@app.route('/predict', methods=['POST'])
def predict_api():
    # 获取前端数据
    data = request.get_json()
    print('data--->\n', data)
    # 判断是否有数据, 没有收集异常信息
    if not data or 'text' not in data:
        # 状态码: 2xx->请求成功 3xx->重定向 4xx->请求端报错 5xx->服务端报错
        return jsonify({'error': 'Missing text field in JSON'}), 400
    # 调用模型预测接口实现预测
    result = predict(data)
    print('result--->\n', result)
    # 返回json结果
    return jsonify(result)


if __name__ == '__main__':
    # 启动服务端
    app.run(host='0.0.0.0', port=8000, debug=True)

flask api测试

python 复制代码
# 不要求掌握
import requests
import time

# 定义预测接口地址
url = 'http://127.0.0.1:8000/predict'

# 构造请求数据
data = {'text': "中国人民公安大学2012年硕士研究生目录及书目"}

start_time = time.time()

try:
    # 发送post请求, 获取响应对象
    response = requests.post(url, json=data)
    print('response--->\n', response)
    # 耗时
    duration = (time.time() - start_time) * 1000  # ms
    print(f'耗时: {duration:.2f}ms')
    # 判断状态码是否为200, 如果是, 获取响应数据
    if response.status_code == 200:
        result = response.json()
        print('result--->\n', type(result), result)
        print('预测结果--->\n', result['pred_class'])
    # 如果不是, 获取错误信息
    else:
        error = response.json()['error']
        print(print(f"请求失败: {response.status_code}, {error}"))
except Exception as e:
    print(f"请求出错: {str(e)}")

基于streamlit前端服务

python 复制代码
import streamlit as st
import requests
import time

# todo:1-设置页面标题
st.title('文本分类系统')

# todo:2-创建输入框
data_text = st.text_area('请输入预测文本:', "中国人民公安大学2012年硕士研究生目录及书目")

# todo:3-创建预测按钮
if st.button('预测'):
    # todo:4-调用模型推理接口实现预测
    start_time = time.time()
    try:
        # 构造请求数据
        data = {'text': data_text}
        url = 'http://127.0.0.1:8000/predict'
        # 发送post请求, 获取响应对象
        response = requests.post(url, json=data)
        duration = (time.time() - start_time) * 1000
        # 判断状态码是否为200
        if response.status_code == 200:
            result = response.json()
            # todo:5-显示预测结果
            st.success(f"预测结果: {result['pred_class']}")
            st.info(f"请求耗时: {duration:.2f}ms")
        else:
            st.error(f"请求失败: {response.json()['error']}")
    except Exception as e:
        st.error(f"请求出错: {str(e)}")

# todo:6-页面提示内容
st.write("请确保 Flask API 服务已在 localhost:8000 运行")

项目细节

相关推荐
寒月小酒2 小时前
3.29+3.30
数据结构·算法
杜子不疼.2 小时前
高并发场景下 Spring MVC + 虚拟线程 vs WebFlux 选型对比
java·人工智能·spring·mvc
ZoeJoy82 小时前
算法筑基(六):分治算法——大事化小,小事化了
算法·排序算法·动态规划·哈希算法·图搜索算法
新加坡内哥谈技术2 小时前
AI代理可能会让自由软件再次变得重要
人工智能·ai编程
美式请加冰2 小时前
BFS算法(下)
算法·宽度优先
少许极端2 小时前
算法奇妙屋(三十七)-贪心算法学习之路4
学习·算法·贪心算法·田忌赛马
Fleshy数模2 小时前
从零实现Word2Vec之CBOW模型:理解词向量的核心原理
人工智能·自然语言处理·word2vec
We་ct2 小时前
LeetCode 373. 查找和最小的 K 对数字:题解+代码详解
前端·算法·leetcode·typescript·二分·
Ricky_Theseus2 小时前
探索群体智慧:蚁群算法(ACO)从原理到实践——python实现
python·算法·机器学习