基于BERT的文本分类模型训练全流程:从环境搭建到显存优化实战

作为NLP领域的经典任务,文本分类是入门Transformer模型微调的绝佳实践。本文将以二分类任务为例,详细记录基于BERT预训练模型的微调全流程,从环境配置、数据准备到模型训练、显存优化,带你踩平所有坑点,最终实现一个高性能的文本分类模型。

一、任务背景与技术选型

本次任务聚焦文本二分类 (例如情感分析、垃圾邮件识别等),选择BERT预训练模型进行微调------这是NLP领域"预训练+微调"范式的典型应用,既能复用BERT在大规模语料上学到的语言知识,又能快速适配特定任务。

技术栈

  • 框架:PyTorch + Hugging Face Transformers/Datasets
  • 模型:BERT-base(或自定义预训练模型)
  • 优化:AdamW优化器 + 混合精度训练 + 显存碎片优化

二、环境搭建与依赖安装

首先搭建稳定的训练环境,推荐使用conda管理环境:

bash 复制代码
# 创建并激活虚拟环境
conda create -n torch_env python=3.10
conda activate torch_env

# 安装核心依赖
pip install torch transformers datasets pandas

若需使用GPU加速,需额外安装对应版本的CUDA和cuDNN(可参考PyTorch官方安装指南)。

三、数据准备:Hugging Face Dataset格式解析

本次实验使用本地自定义数据集 ,需遵循Hugging Face datasets的格式规范:

  • 数据集包含train/test/validation拆分
  • 每个样本需包含text(输入文本)和label(0/1标签)字段

将数据按如下结构存放(与项目路径保持一致):

复制代码
Remote/
├── data/
│   ├── train-00000-of-00001-xxx.parquet  # 训练集
│   ├── test-00000-of-00001-xxx.parquet   # 测试集
│   └── validation-00000-of-00001-xxx.parquet  # 验证集
├── mode/  # BERT预训练模型文件(config.json、pytorch_model.bin等)
└── train.py  # 训练脚本

通过datasets.load_dataset加载本地数据:

python 复制代码
from datasets import load_dataset

# 加载本地数据集
hf_train_dataset = load_dataset("./Remote/data", split="train")
print(f"数据集规模:{len(hf_train_dataset)}条样本")

四、模型构建:BERT + 分类头的经典架构

我们基于Hugging Face的BertModel构建自定义分类模型,核心逻辑是复用BERT的预训练特征,添加单层全连接层做二分类

python 复制代码
import torch
from transformers import BertModel, AutoTokenizer

class BertClassifier(torch.nn.Module):
    def __init__(self, model_name="./Remote/mode"):
        super().__init__()
        # 加载预训练BERT模型
        self.bert = BertModel.from_pretrained(model_name)
        # 分类头:将BERT的768维输出映射到2个类别
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 2)
    
    def forward(self, **kwargs):
        # BERT前向传播,获取[CLS] token的池化输出
        outputs = self.bert(**kwargs)
        pooled_output = outputs.pooler_output
        # 分类头输出logits
        logits = self.classifier(pooled_output)
        return logits

# 初始化Tokenizer和模型
tokenizer = AutoTokenizer.from_pretrained("./Remote/mode")
model = BertClassifier().to(device)

五、训练全流程:从数据加载到模型收敛

训练流程的核心是数据批量处理、设备一致性保证、梯度更新三个环节,以下是关键步骤的代码与解析。

5.1 数据集封装与批量编码

为适配PyTorch的DataLoader,需自定义数据集类,并在collate_fn中完成文本编码和设备迁移:

python 复制代码
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, raw_dataset):
        self.raw_dataset = raw_dataset
    
    def __len__(self):
        return len(self.raw_dataset)
    
    def __getitem__(self, idx):
        return {
            "text": self.raw_dataset[idx]["text"],
            "label": self.raw_dataset[idx]["label"]
        }

# 构建DataLoader,关键在于collate_fn中完成编码和设备迁移
train_loader = DataLoader(
    TextDataset(hf_train_dataset),
    batch_size=16,
    shuffle=True,
    collate_fn=lambda batch: (
        {k: v.to(device) for k, v in tokenizer.batch_encode_plus(
            [item["text"] for item in batch],
            padding="longest",
            truncation=True,
            return_tensors="pt"
        ).items()},
        torch.tensor([item["label"] for item in batch], dtype=torch.long).to(device)
    )
)

5.2 训练循环与梯度更新

使用AdamW优化器(Transformer微调的首选),结合混合精度训练提升效率:

python 复制代码
from torch.cuda.amp import autocast, GradScaler

# 初始化损失函数、优化器、梯度缩放器
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-5)
scaler = GradScaler()  # 混合精度训练用

# 训练循环
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch_idx, (encoding, labels) in enumerate(train_loader, 1):
        with autocast():  # 启用混合精度
            logits = model(**encoding)
            loss = criterion(logits, labels)
        
        # 梯度更新(混合精度版本)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        if batch_idx % 10 == 0:
            avg_loss = total_loss / batch_idx
            print(f"Epoch {epoch+1}/{num_epochs} | Batch {batch_idx} | Avg Loss: {avg_loss:.4f}")
    
    # 每个Epoch后清理显存,防止碎片累积
    torch.cuda.empty_cache()

# 保存模型
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": num_epochs
}, "Remote/qing_model.pth")

六、踩坑实录:显存不足与设备不一致的解决方案

训练过程中最常见的两类问题及解决方案如下:

6.1 显存不足(CUDA out of memory

  • 直接方案 :减小batch_size(如从16改为8)。
  • 进阶方案
    • 启用混合精度训练(如上述代码中的autocastGradScaler),显存占用直接减半。

    • 梯度累积:通过accumulation_steps模拟大batch,代码示例:

      python 复制代码
      accumulation_steps = 2
      for batch_idx, (encoding, labels) in enumerate(train_loader, 1):
          with autocast():
              logits = model(**encoding)
              loss = criterion(logits, labels) / accumulation_steps
          
          loss.backward()
          if batch_idx % accumulation_steps == 0:
              optimizer.step()
              optimizer.zero_grad()

6.2 设备不一致(Expected all tensors to be on the same device

错误根源是模型、数据、标签不在同一设备 ,解决方案是在collate_fn中显式将所有张量移至目标设备:

python 复制代码
collate_fn=lambda batch: (
    {k: v.to(device) for k, v in tokenizer.batch_encode_plus(...).items()},
    torch.tensor(...).to(device)
)

七、优化技巧:让训练更高效、更稳定

除了上述问题解决,这些优化技巧能进一步提升训练体验:

  1. 提前加载Tokenizer :避免在collate_fn中重复加载,减少IO开销:

    python 复制代码
    tokenizer = AutoTokenizer.from_pretrained(model_name)  # 全局提前加载
  2. 梯度裁剪 :防止梯度爆炸,在loss.backward()后添加:

    python 复制代码
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  3. 学习率调度 :使用ReduceLROnPlateau根据损失自动调整学习率:

    python 复制代码
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    # 训练循环中调用
    scheduler.step(avg_loss)

八、总结与延伸

本次基于BERT的文本分类训练实战,覆盖了从环境搭建、数据处理、模型构建到训练优化的全流程,核心收获包括:

  • 掌握Transformer模型微调的经典范式:预训练模型 + 分类头
  • 解决NLP训练中最常见的"显存不足"和"设备不一致"问题。
  • 学会混合精度训练、梯度累积等高级优化技巧,提升训练效率。

若需进一步提升模型性能,可尝试:

  • 更换更大的预训练模型(如bert-large)或领域适配的预训练模型(如医疗领域的BioBERT)。
  • 引入数据增强技术(如回译、同义词替换),提升模型泛化能力。
  • 尝试知识蒸馏,将大模型的知识迁移到小模型,平衡性能与效率。

核心代码如下:

train.py

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, AutoTokenizer
from datasets import load_dataset
import os

# -------------------------- 全局配置 --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")  # 新增:打印当前使用的设备

# 动态路径拼接(确保与项目根目录一致)
current_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_script_dir)
model_name = os.path.join(project_root, "Remote", "mode")
hf_dataset_dir = os.path.join(project_root, "Remote", "data")

batch_size = 8
num_epochs = 3
learning_rate = 2e-5

# 提前加载Tokenizer(避免每次batch重复加载,提升效率)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# -------------------------- 自定义数据集 --------------------------
class MineDataset(Dataset):
 def __init__(self, raw_dataset):
     self.raw_dataset = raw_dataset
 
 def __len__(self):
     return len(self.raw_dataset)
 
 def __getitem__(self, idx):
     return {
         "text": self.raw_dataset[idx]["text"],
         "label": self.raw_dataset[idx]["label"]
     }

# -------------------------- 模型定义 --------------------------
class MineModel(nn.Module):
 def __init__(self, model_name=model_name):
     super().__init__()
     self.bert = BertModel.from_pretrained(model_name)
     self.classifier = nn.Linear(self.bert.config.hidden_size, 2)  # 二分类任务
 
 def forward(self, **kwargs):
     # 简化:直接传递所有编码后的参数(input_ids/attention_mask/token_type_ids)
     # BertModel原生支持这些参数,无需过滤
     outputs = self.bert(**kwargs)
     pooled_output = outputs.pooler_output  # 使用BERT的池化输出
     logits = self.classifier(pooled_output)
     return logits

# -------------------------- 数据加载 --------------------------
print("Starting training preparation...")
print(f"Model path: {model_name}")
print(f"Dataset directory: {hf_dataset_dir}")

# 1. 加载本地数据集(确保数据集格式符合Hugging Face datasets规范)
hf_train_dataset = load_dataset(hf_dataset_dir, split="train")
print(f"Loaded {len(hf_train_dataset)} training samples.")

# 2. 封装自定义数据集
train_dataset = MineDataset(hf_train_dataset)

# 3. 构建DataLoader(核心修复:collate_fn中统一设备)
train_loader = DataLoader(
 train_dataset,
 batch_size=batch_size,
 shuffle=True,
 collate_fn=lambda batch: (
     # 编码文本并将所有张量移到目标设备
     {k: v.to(device) for k, v in tokenizer.batch_encode_plus(
         [item["text"] for item in batch],
         padding="longest",  # 按批次最长文本填充
         truncation=True,    # 截断超过模型最大长度的文本
         return_tensors="pt" # 返回PyTorch张量
     ).items()},
     # 标签张量移到目标设备
     torch.tensor([item["label"] for item in batch], dtype=torch.long).to(device)
 )
)

# -------------------------- 初始化组件 --------------------------
# 模型移到目标设备
model = MineModel().to(device)
# 损失函数(二分类交叉熵)
criterion = nn.CrossEntropyLoss().to(device)  # 移到设备(非必须,但统一更稳妥)
# 优化器(AdamW是Transformer模型常用优化器)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)  # 新增权重衰减,防止过拟合

# -------------------------- 训练循环 --------------------------
print("\nStarting training...")
for epoch in range(num_epochs):
 model.train()  # 开启训练模式(启用Dropout等)
 total_loss = 0.0
 total_steps = len(train_loader)
 
 for batch_idx, (encoding, labels) in enumerate(train_loader, 1):
     # 验证设备(可选,用于调试)
     # print(f"Batch {batch_idx}: Input device={encoding['input_ids'].device}, Labels device={labels.device}")
     
     # 前向传播
     logits = model(**encoding)  # 直接传递编码后的所有参数
     loss = criterion(logits, labels)  # 计算损失
     
     # 反向传播+参数更新
     optimizer.zero_grad()  # 清空梯度
     loss.backward()        # 计算梯度
     optimizer.step()       # 更新参数
     
     # 累计损失
     total_loss += loss.item()
     
     # 打印批次信息(每10个批次打印一次,避免输出过多)
     if batch_idx % 10 == 0 or batch_idx == total_steps:
         avg_loss = total_loss / batch_idx
         print(f"Epoch [{epoch+1}/{num_epochs}] | Batch [{batch_idx}/{total_steps}] | Loss: {loss.item():.4f} | Avg Loss: {avg_loss:.4f}")

# -------------------------- 保存模型 --------------------------
# 保存模型状态字典(推荐,占用空间小,灵活度高)
save_path = "qing_model.pth"
torch.save({
 "model_state_dict": model.state_dict(),
 "optimizer_state_dict": optimizer.state_dict(),
 "epoch": num_epochs,
 "total_loss": total_loss
}, save_path)

print(f"\nTraining completed! Model saved to: {save_path}")

test_model.py

python 复制代码
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, BertModel
from datasets import load_dataset
import os

# -------------------------- 全局配置 --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 项目路径配置(与训练代码保持一致)
current_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_script_dir)
model_path = os.path.join(project_root, "Remote","qing_model.pth")  # 训练好的模型路径
model_name = os.path.join(project_root, "Remote", "mode")   # BERT模型路径
hf_dataset_dir = os.path.join(project_root, "Remote", "data")  # 数据集路径

# 提前加载Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# -------------------------- 模型定义(与训练代码完全一致) --------------------------
class MineModel(torch.nn.Module):
  def __init__(self, model_name=model_name):
      super().__init__()
      self.bert = BertModel.from_pretrained(model_name)
      self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 2)  # 二分类
  
  def forward(self, **kwargs):
      outputs = self.bert(**kwargs)
      pooled_output = outputs.pooler_output
      logits = self.classifier(pooled_output)
      return logits

# -------------------------- 自定义测试数据集(与训练代码逻辑一致) --------------------------
class TestDataset(Dataset):
  def __init__(self, raw_dataset):
      self.raw_dataset = raw_dataset
  
  def __len__(self):
      return len(self.raw_dataset)
  
  def __getitem__(self, idx):
      return {
          "text": self.raw_dataset[idx]["text"],
          "label": self.raw_dataset[idx]["label"]
      }

# -------------------------- 数据加载与预处理 --------------------------
def load_test_data():
  # 加载测试数据集
  hf_test_dataset = load_dataset(hf_dataset_dir, split="test")
  print(f"Loaded {len(hf_test_dataset)} test samples.")
  
  # 封装为自定义Dataset
  test_dataset = TestDataset(hf_test_dataset)
  
  # 构建DataLoader(与训练时的collate_fn逻辑一致,确保设备一致)
  test_loader = DataLoader(
      test_dataset,
      batch_size=8,  # 测试时batch_size可适当减小,降低显存压力
      shuffle=False,  # 测试时无需打乱
      collate_fn=lambda batch: (
          {k: v.to(device) for k, v in tokenizer.batch_encode_plus(
              [item["text"] for item in batch],
              padding="longest",
              truncation=True,
              return_tensors="pt"
          ).items()},
          torch.tensor([item["label"] for item in batch], dtype=torch.long).to(device)
      )
  )
  return test_loader

# -------------------------- 模型测试与评估 --------------------------
def test_model():
  # 初始化模型并加载权重
  model = MineModel().to(device)
  model.load_state_dict(torch.load(model_path)["model_state_dict"])
  model.eval()  # 切换为评估模式(禁用Dropout)
  
  # 加载测试数据
  test_loader = load_test_data()
  
  # 推理与评估
  total = 0
  correct = 0
  with torch.no_grad():  # 测试时禁用梯度计算,节省显存
      for encoding, labels in test_loader:
          logits = model(**encoding)
          predictions = torch.argmax(logits, dim=1)  # 取概率最大的类别
          correct += (predictions == labels).sum().item()
          total += labels.size(0)
  
  # 计算准确率
  accuracy = correct / total
  print(f"Test Accuracy: {accuracy:.4f}")

# -------------------------- 主函数 --------------------------
if __name__ == "__main__":
  test_model()  

SFtest_modle.py

python 复制代码
import torch
from transformers import AutoTokenizer, BertModel
import os

# -------------------------- 全局配置 --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 项目路径配置
current_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_script_dir)
model_path = os.path.join(project_root, "Remote", "qing_model.pth")  # 训练好的模型路径
model_name = os.path.join(project_root, "Remote", "mode")           # BERT模型路径

# 类别映射(根据你的任务修改,如0-负类,1-正类)
label_map = {0: "类别0", 1: "类别1"}  

# 提前加载Tokenizer和模型
tokenizer = AutoTokenizer.from_pretrained(model_name)

# -------------------------- 模型定义 --------------------------
class MineModel(torch.nn.Module):
  def __init__(self, model_name=model_name):
      super().__init__()
      self.bert = BertModel.from_pretrained(model_name)
      self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 2)  # 二分类
  
  def forward(self, **kwargs):
      outputs = self.bert(**kwargs)
      pooled_output = outputs.pooler_output
      logits = self.classifier(pooled_output)
      return logits

# -------------------------- 加载模型 --------------------------
def load_model():
  model = MineModel().to(device)
  model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])
  model.eval()  # 切换为评估模式
  return model

# -------------------------- 文本推理函数 --------------------------
def predict_text(text, model):
  # 文本编码
  encoding = tokenizer(
      text,
      padding="longest",
      truncation=True,
      return_tensors="pt"
  ).to(device)  # 确保张量在目标设备上
  
  # 模型推理
  with torch.no_grad():
      logits = model(**encoding)
      probs = torch.nn.functional.softmax(logits, dim=1)  # 转换为概率
      pred_label = torch.argmax(probs, dim=1).item()     # 预测类别
      confidence = probs[0, pred_label].item()           # 置信度
  
  return pred_label, confidence

# -------------------------- 主函数 --------------------------
if __name__ == "__main__":
  model = load_model()
  print("模型加载完成,输入文本进行验证(输入'exit'退出):")
  
  while True:
      text = input("\n请输入要验证的文本:")
      if text.lower() == "exit":
          break
      if not text.strip():
          print("文本不能为空,请重新输入!")
          continue
      
      pred_label, confidence = predict_text(text, model)
      print(f"预测类别:{label_map[pred_label]}")
      print(f"置信度:{confidence:.4f}")
相关推荐
Lxinccode2 小时前
python(59) : 多线程调用大模型ocr提取图片文本
开发语言·python·图片提取文字·批量提取文件·多线程ocr
梁辰兴2 小时前
PyCharm使用了Conda的虚拟环境创建的的Python项目,下载库(包)到该项目的虚拟环境中
python·pycharm·conda·错误·异常·异常报错
自由日记2 小时前
python简单线性回归
开发语言·python·线性回归
Halo_tjn3 小时前
Set集合专项实验
java·开发语言·前端·python
vvoennvv3 小时前
【Python TensorFlow】 BiTCN-LSTM双向时间序列卷积长短期记忆神经网络时序预测算法(附代码)
python·神经网络·tensorflow·lstm·tcn
q***42053 小时前
python的sql解析库-sqlparse
数据库·python·sql
大数据追光猿4 小时前
LangChain / LangGraph / AutoGPT / CrewAI / AutoGen 五大框架对比
经验分享·笔记·python·langchain·agent
wang_yb4 小时前
别急着转投 Polars!Pandas 3.0 带着“黑科技”杀回来了
python·databook
Jamesvalley4 小时前
flask处理所有logging
后端·python·flask