十一、基于 BERT 的中文文本情感分类模型训练全解析

在自然语言处理(NLP)领域,中文文本情感分类是一项兼具实用价值与技术代表性的任务,广泛应用于电商评论分析、舆情监控、客户反馈处理等场景。传统机器学习模型难以捕捉中文的深层语义与语境信息,而预训练语言模型的出现,为该任务提供了高效、高精度的解决方案。

本文将基于一套模块化的完整训练代码,从数据加载、模型构建、端到端训练三个核心环节入手,结合代码片段与详细解释,全方位拆解基于 BERT 的中文文本情感分类模型训练全流程,帮助读者不仅能看懂代码,更能理解背后的技术逻辑。

一、任务概述

本次训练的核心任务是中文文本二分类情感分析,即输入一段中文评价文本(如酒店、餐厅评论),模型输出该文本对应的情感倾向(正面 / 负面,分别用标签 1/0 表示)。

核心依赖与数据集

  1. 核心依赖:PyTorch(模型构建与训练)、Transformers(BERT 预训练模型加载与分词)。
  2. 数据集:ChnSentiCorp 中文情感分析数据集,已按train(训练集)、validation(验证集)、test(测试集)完成拆分,满足模型训练的基本数据要求。
  3. 代码架构:采用模块化设计,分为 3 个核心文件,职责明确、逻辑解耦:
    • MyData.py:标准化数据集加载模块。
    • net.py:基于 BERT 的下游分类网络构建模块。
    • trainer.py:端到端训练流程落地模块。

下载训练模型

复制代码
from datasets import load_from_disk, load_dataset
from huggingface_hub import snapshot_download
# 指定镜像站地址(国内推荐)
repo_id = "google-bert/bert-base-chinese"
local_dir = "./bert-base-chinese"  # 本地保存路径

# 下载数据集
snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    local_dir=local_dir,
    local_dir_use_symlinks=False,
    endpoint="https://hf-mirror.com"  # 使用国内镜像
)

print(f"模型已下载到: {local_dir}")

二、模块一:数据加载(MyData.py)------ 标准化数据接口实现

数据是模型训练的源头,高质量的数据加载逻辑是保证训练顺利进行的前提。该模块的核心目标是实现 PyTorch 生态下的自定义数据集加载,为后续训练提供统一、可复用的样本获取接口,解耦数据读取与训练逻辑。

完整代码

复制代码
# 从PyTorch的工具包中导入Dataset抽象类,自定义数据集必须继承该类并实现核心方法
from torch.utils.data import Dataset
# 从Hugging Face的datasets库中导入从本地磁盘加载数据集的函数
from datasets import load_from_disk

# 定义自定义数据集类,继承自torch.utils.data.Dataset
class Mydataset(Dataset):
    # 初始化数据(类的构造方法),用于加载和筛选数据集
    def __init__(self, split):
        """
        自定义数据集的初始化方法
        Args:
            split (str): 数据集划分标识,可选值为 "train"(训练集)、"validation"(验证集)、"test"(测试集)
        """
        # 从本地指定磁盘路径加载已保存的Hugging Face格式数据集(ChnSentiCorp是中文情感分析数据集)
        self.dataset = load_from_disk(r"D:\pyprojecgt\flaskProject\langchainstudy\modelscope\bert\trainstudy\data\ChnSentiCorp")
        
        # 根据传入的split参数,筛选对应的数据集划分
        if split == "train":
            # 筛选出训练集数据并赋值给实例变量self.dataset
            self.dataset = self.dataset["train"]
        elif split == "validation":
            # 筛选出验证集数据并赋值给实例变量self.dataset
            self.dataset = self.dataset["validation"]
        elif split == "test":
            # 筛选出测试集数据并赋值给实例变量self.dataset
            self.dataset = self.dataset["test"]
        else:
            # 若传入无效的split参数,打印错误提示信息
            print("数据集名称错误!请传入有效划分:train/validation/test")
    
    # 获取数据集的样本总数(必须实现的Dataset抽象方法)
    def __len__(self):
        """
        返回当前数据集划分的样本总数
        Returns:
            int: 数据集样本数量
        """
        # 直接返回加载的数据集划分的长度(即样本总数)
        return len(self.dataset)
    
    # 根据索引获取单个样本的定制化数据(必须实现的Dataset抽象方法)
    def __getitem__(self, item):
        """
        根据索引提取单个样本的文本和标签
        Args:
            item (int): 样本的索引值(从0开始,不超过__len__()返回的数值-1)
        Returns:
            tuple: 包含单个样本的(文本内容,情感标签)元组
        """
        # 从数据集中根据索引item提取对应样本的文本内容
        text = self.dataset[item]["text"]
        # 从数据集中根据索引item提取对应样本的情感标签(0表示负面,1表示正面,视数据集标注规则而定)
        label = self.dataset[item]["label"]
        
        # 返回处理后的单个样本数据(文本+标签)
        return text, label

# 主程序入口,仅在直接运行该脚本时执行(用于测试自定义数据集是否正常工作)
if __name__ == '__main__':
    # 实例化自定义数据集类,加载"validation"(验证集)数据
    dataset = Mydataset("validation")
    # 遍历验证集的所有样本,逐个打印样本内容(文本+标签),验证数据加载是否正常
    for data in dataset:
        print(data)

代码详细解释

这段代码的核心是自定义一个符合 PyTorch 数据加载规范的数据集类Mydataset ,用于加载本地磁盘上的中文情感分析数据集ChnSentiCorp,并支持按train(训练集)、validation(验证集)、test(测试集)划分数据,最终能通过索引或遍历获取每个样本的「文本内容」和「情感标签」,为后续的深度学习模型(如 BERT)训练做数据准备。

第一步:导入必要的库

复制代码
# 从PyTorch的工具包中导入Dataset抽象类
from torch.utils.data import Dataset
# 从Hugging Face的datasets库中导入从本地磁盘加载数据集的函数
from datasets import load_from_disk
  1. torch.utils.data.Dataset:这是 PyTorch 中所有自定义数据集的基类(抽象类) ,它规定了自定义数据集必须实现两个核心方法:__len__()__getitem__(),只有遵循这个规范,后续才能和 PyTorch 的DataLoader(数据加载器,用于批量加载、打乱数据等)配合使用。
  2. datasets.load_from_disk:这是 Hugging Face datasets库提供的函数,用于加载已经提前保存到本地磁盘的 Hugging Face 格式数据集 (通常是通过dataset.save_to_disk()保存的)。这种格式的数据集自带结构化信息,无需手动解析 txt/csv 等文件,使用起来更便捷。

第二步:定义自定义数据集类Mydataset

所有自定义 PyTorch 数据集都必须继承Dataset基类,这里定义的Mydataset就是为了适配ChnSentiCorp数据集的加载和提取。

1. 构造方法:__init__(self, split)

这是类的初始化方法,用于加载数据集、筛选数据划分,在实例化类的时候会自动执行。

复制代码
def __init__(self, split):
    """
    自定义数据集的初始化方法
    Args:
        split (str): 数据集划分标识,可选值为 "train"(训练集)、"validation"(验证集)、"test"(测试集)
    """
    # 1. 从本地指定路径加载Hugging Face格式数据集
    self.dataset = load_from_disk(r"D:\pyprojecgt\flaskProject\langchainstudy\modelscope\bert\trainstudy\data\ChnSentiCorp")
    
    # 2. 根据split参数筛选对应的数据集划分
    if split == "train":
        self.dataset = self.dataset["train"]
    elif split == "validation":
        self.dataset = self.dataset["validation"]
    elif split == "test":
        self.dataset = self.dataset["test"]
    else:
        print("数据集名称错误!请传入有效划分:train/validation/test")
  • self.dataset:这是类的实例变量 ,整个类的其他方法(__len____getitem__)都可以访问它。首先通过load_from_disk加载完整的ChnSentiCorp数据集(该数据集本身已经包含trainvalidationtest三个划分)。
  • split参数:用于指定要加载的数据集划分,比如传入"validation"就只保留验证集数据,后续操作都只针对这个划分。
  • 异常提示:如果传入无效的split值(比如"val"),会打印错误提示,避免程序直接崩溃(新手友好型处理)。
2. 样本数量方法:__len__(self)

这是 PyTorch Dataset 基类要求必须实现 的方法,用于返回当前数据集划分的样本总数

复制代码
def __len__(self):
    """
    返回当前数据集划分的样本总数
    Returns:
        int: 数据集样本数量
    """
    return len(self.dataset)
  • 逻辑非常简单:直接返回self.dataset的长度(即样本个数)。比如ChnSentiCorp的验证集大概有 1000 个样本,调用len(dataset)datasetMydataset实例)时,就会自动调用这个方法,返回 1000。
  • 作用:后续DataLoader加载数据时,需要通过这个方法知道数据集的总长度,从而确定遍历的边界。
3. 单个样本提取方法:__getitem__(self, item)

这也是 PyTorch Dataset 基类要求必须实现 的核心方法,用于根据索引item提取并返回单个样本的处理后数据

复制代码
def __getitem__(self, item):
    """
    根据索引提取单个样本的文本和标签
    Args:
        item (int): 样本的索引值(从0开始,不超过__len__()返回的数值-1)
    Returns:
        tuple: 包含单个样本的(文本内容,情感标签)元组
    """
    # 1. 提取索引对应的样本文本
    text = self.dataset[item]["text"]
    # 2. 提取索引对应的样本标签
    label = self.dataset[item]["label"]
    
    # 3. 返回处理后的单个样本
    return text, label
  • item参数:是样本的索引值 (整数),从 0 开始,最大不超过__len__()-1(比如数据集有 1000 个样本,item的取值范围是 0~999)。
  • self.dataset[item]["text"]:Hugging Face 格式的数据集支持「索引 + 键名」的方式提取数据,"text"是数据集里存储文本内容的键,"label"是存储情感标签的键(通常 0 表示负面情感,1 表示正面情感,由ChnSentiCorp数据集的标注规则决定)。
  • 返回值:以 ** 元组(text, label)** 的形式返回单个样本,这样后续遍历或批量加载时,能很方便地拆分文本和标签。

第三步:主程序测试(if __name__ == '__main__'

这部分代码的作用是测试自定义数据集是否能正常工作,只有在直接运行该脚本时才会执行(如果被其他脚本导入,这部分不会执行)。

复制代码
if __name__ == '__main__':
    # 1. 实例化Mydataset类,加载验证集(validation)
    dataset = Mydataset("validation")
    # 2. 遍历数据集的所有样本,逐个打印
    for data in dataset:
        print(data)
  1. 实例化类:dataset = Mydataset("validation"),此时会调用__init__方法,加载验证集数据。
  2. 遍历数据集:for data in dataset,这背后是 PyTorch 的迭代器支持:
    • 首先调用__len__()获取数据集总长度。
    • 然后自动生成从 0 到len(dataset)-1的索引,逐个传入__getitem__()方法。
    • __getitem__()返回的(text, label)元组赋值给data,并打印。
  3. 预期输出:会逐个打印验证集中的样本,格式类似("这个商品质量很好,推荐购买", 1)("物流太慢了,体验很差", 0)

补充:关键细节和使用场景

  1. 为什么要自定义 Dataset?
    • PyTorch 的DataLoader只能加载符合Dataset规范的数据集。
    • 实际项目中,数据集的格式、需要提取的字段(这里是textlabel)各不相同,自定义Dataset可以灵活适配不同的数据集需求。
  2. 后续扩展 :这段代码只提取了原始文本和标签,后续训练 BERT 等模型时,还需要在__getitem__()中添加文本编码 (比如用BertTokenizer把文本转换成 token id、attention mask 等),才能输入到模型中。
  3. 路径注意事项 :代码中的数据集路径是绝对路径(r"D:\pyprojecgt\..."),如果后续移动脚本或数据集,需要修改该路径,也可以改成相对路径提高可移植性。

总结

  1. 核心结构:自定义Mydataset继承torch.utils.data.Dataset,必须实现__init__(加载筛选)、__len__(返回样本数)、__getitem__(提取单个样本)三个方法。
  2. 数据流转:load_from_disk加载本地数据集 → 按split筛选划分 → 按索引提取(text, label)样本。
  3. 核心作用:为 PyTorch 模型训练提供符合规范的数据源,后续可配合DataLoader实现批量加载、数据打乱等功能。

模块核心价值

该模块将繁琐的数据集读取、格式解析、索引映射等工作封装在独立模块中,后续训练流程无需关注数据的底层存储格式与读取方式,只需通过Mydataset类即可快速获取标准化样本,极大提升了代码的可维护性与可复用性。同时,若后续需要更换数据集,只需修改__init__方法中的数据读取逻辑,无需改动训练核心代码,具备良好的扩展性。

三、模块二:网络模型构建(net.py)------ 基于 BERT 的下游分类网络搭建

模型是训练的核心载体,本次训练依托bert-base-chinese预训练模型,搭建轻量级下游情感分类网络,充分复用预训练模型的中文语义提取能力,降低下游任务的训练成本与难度,兼顾训练效率与分类精度。

完整代码

复制代码
# 分类系统案例:自然语言任务分类(文本生成、文本分类、问答系统)
# 从transformers库中导入自动分词器类,用于处理文本数据(分词、转token id等)
from transformers import (
    AutoTokenizer,
)

# 从transformers库中导入BERT预训练模型类,用于加载预训练的BERT主干网络
from transformers import BertModel

# 导入PyTorch核心库,用于搭建神经网络、张量计算等
import torch

# 1. 加载本地预训练模型和对应的分词器
# 定义本地bert-base-chinese模型的存储路径(该模型是中文通用预训练BERT模型)
model_dir = "D:\\本地模型\\google-bert\\bert-base-chinese"

# 加载分词器:AutoTokenizer会自动匹配预训练模型的分词规则,从本地路径加载
# 作用:将原始中文文本转换为BERT模型能够识别的输入格式(token id、attention mask等)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# 定义模型训练/推理使用的设备(优先使用GPU,无GPU则使用CPU)
# 以下是注释掉的DirectML设备配置(用于支持AMD显卡加速)
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")

# 优先使用NVIDIA CUDA显卡加速,若未安装CUDA或无NVIDIA显卡,则使用CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载BERT预训练主干模型:从本地路径加载bert-base-chinese,并将模型移至指定设备(GPU/CPU)
# pretrained:预训练模型实例,用于提取文本的高维语义特征(不包含下游分类头)
pretrained = BertModel.from_pretrained(model_dir).to(DEVICE)

# 以下两行是注释掉的调试代码,用于查看预训练模型的结构和词嵌入层信息
# print(pretrained)  # 打印完整的BERT模型结构(网络层、参数等)
# print(pretrained.embeddings.word_embeddings)  # 打印BERT的词嵌入层(将token id转换为词向量)

# 2. 定义下游任务模型(继承PyTorch的Module类)
# 作用:基于BERT预训练模型提取的特征,搭建简单的分类头,完成文本分类任务
class Model(torch.nn.Module):
    # 模型初始化方法:定义下游分类任务的网络结构
    def __init__(self):
        # 调用父类torch.nn.Module的初始化方法,必须保留
        super().__init__()
        
        # 定义全连接层(分类头):将BERT提取的768维特征映射到2分类输出
        # torch.nn.Linear(in_features, out_features)
        # in_features=768:bert-base-chinese模型的隐藏层输出维度固定为768
        # out_features=2:分类任务的类别数(此处为2分类,如情感分析的正面/负面)
        self.fc = torch.nn.Linear(768, 2)
    
    # 模型前向传播方法:定义数据的流转路径和计算逻辑(必须实现)
    # 输入:BERT模型要求的三个核心输入参数(input_ids、attention_mask、token_type_ids)
    def forward(self, input_ids, attention_mask, token_type_ids):
        # 冻结预训练模型(上游BERT主干网络不参与训练,不更新参数)
        # with torch.no_grad():上下文管理器,禁用梯度计算,减少显存占用,提升计算速度
        with torch.no_grad():
            # 调用预训练BERT模型,传入三个输入参数,获取文本特征输出
            # out:BERT模型的返回结果,包含last_hidden_state(最后一层隐藏层输出)等信息
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        
        # 提取有效特征:取last_hidden_state的第0个token(<CLS> token)的768维特征
        # BERT模型中,<CLS> token(对应索引0)用于聚合整句文本的语义信息,适合用于分类任务
        # out.last_hidden_state.shape 一般为 [batch_size, sequence_length, 768]
        # out.last_hidden_state[:, 0] 取所有样本(batch)的第0个token,形状变为 [batch_size, 768]
        cls_feature = out.last_hidden_state[:, 0]
        
        # 下游分类任务:将<CLS> token的768维特征输入全连接层,映射到2维输出
        out = self.fc(cls_feature)
        
        # 对全连接层输出进行softmax激活,将结果转换为概率分布(各分类的概率值,总和为1)
        # dim=1:在分类维度(第1维,对应2个类别)上进行softmax计算
        out = out.softmax(dim=1)
        
        # 返回最终的分类概率结果
        return out

代码详细解释

第一步:导入必要的库

复制代码
# 分类系统案例:自然语言任务分类(文本生成、文本分类、问答系统)
from transformers import (
    AutoTokenizer,
)
# 从transformers库导入BERT预训练模型类
from transformers import BertModel
# 导入PyTorch核心库,用于搭建神经网络、张量计算、设备管理等
import torch
  1. AutoTokenizer:Hugging Face transformers库提供的「自动分词器」类,用于将原始中文文本转换为 BERT 模型能够识别的输入格式(包括分词、转换为 token id、生成注意力掩码等),且能自动匹配对应预训练模型的分词规则。
  2. BertModel:BERT 模型的主干网络类,不包含下游任务的分类头,仅用于提取文本的语义特征,是预训练模型的核心部分。
  3. torch:PyTorch 深度学习框架的核心库,用于搭建自定义分类头、张量计算、设备配置等。

第二步:加载本地预训练模型 / 分词器 & 配置计算设备

复制代码
# 1. 加载本地模型和分词器
# 定义本地中文BERT模型的存储路径(bert-base-chinese是中文通用预训练模型)
model_dir = "D:\\本地模型\\google-bert\\bert-base-chinese"
# 从本地路径加载分词器,与预训练模型配套使用
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# 定义模型计算设备(优先使用GPU加速,无GPU则使用CPU)
# 注释掉的是AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 从本地路径加载BERT主干预训练模型,并将模型移至指定设备(GPU/CPU)
pretrained = BertModel.from_pretrained(model_dir).to(DEVICE)

# 注释掉的是调试代码,用于查看模型结构和词嵌入层信息
# print(pretrained)
# print(pretrained.embeddings.word_embeddings)
  1. 本地模型路径model_dir:这里存放的是bert-base-chinese预训练模型的所有文件(包括配置文件、权重文件等),通过from_pretrained(model_dir)可直接加载,无需重新从网络下载。
  2. 分词器加载tokenizer = AutoTokenizer.from_pretrained(model_dir):必须与预训练模型配套加载,保证分词规则、词汇表与 BERT 模型一致,否则会出现输入不匹配的错误。
  3. 计算设备配置DEVICE
    • torch.cuda.is_available():判断是否有可用的 NVIDIA GPU(需安装 CUDA、cuDNN)。
    • 模型移至设备.to(DEVICE):将 BERT 模型的所有参数加载到指定设备(GPU/CPU),后续所有计算都将在该设备上进行,避免「张量与模型不在同一设备」的报错。
  4. 调试代码注释:print(pretrained)可打印 BERT 的完整网络结构(包含嵌入层、多层 Transformer 编码器),print(pretrained.embeddings.word_embeddings)可查看词嵌入层(将 token id 转换为 768 维词向量)。

第三步:定义下游任务分类模型(Model类)

这是代码的核心,继承自torch.nn.Module(PyTorch 所有神经网络的基类),搭建用于二分类的「分类头」,并实现数据的前向传播逻辑。

1. 初始化方法__init__(self):定义网络结构
复制代码
class Model(torch.nn.Module):
    # 模型结构设计
    def __init__(self):
        # 调用父类torch.nn.Module的初始化方法,必须保留,否则模型无法正常工作
        super().__init__()
        # 定义全连接层(分类头):将768维特征映射到2分类输出
        self.fc = torch.nn.Linear(768, 2)
  • super().__init__():必须调用父类的初始化方法,用于初始化torch.nn.Module的内置属性(如模型参数、设备信息等),缺少这行代码会导致后续模型训练 / 推理报错。
  • self.fc = torch.nn.Linear(768, 2):定义一个单层全连接层(作为分类头),核心参数说明:
    • in_features=768:输入特征维度,bert-base-chinese模型的隐藏层输出维度固定为 768,这是预训练模型的固有属性,不能随意修改。
    • out_features=2:输出特征维度,对应二分类任务(如情感分析的「正面 / 负面」、文本的「相关 / 不相关」),若要做多分类,只需修改该数值为类别数即可。
    • 作用:将 BERT 提取的 768 维文本特征,映射到 2 个分类的得分值,完成从「特征提取」到「分类预测」的转换。
2. 前向传播方法forward(...):定义数据流转逻辑
复制代码
    def forward(self,input_ids,attention_mask,token_type_ids):
        # 上游任务不参与训练(冻结BERT预训练模型参数)
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
        
        # 下游任务参与训练(提取有效特征 + 分类 + 概率转换)
        out = self.fc(out.last_hidden_state[:,0])
        out = out.softmax(dim=1)
        
        return out

这是模型的核心方法,定义了输入数据如何通过网络层得到输出结果,调用模型实例时会自动执行该方法,关键步骤拆解:

  1. 冻结 BERT 预训练模型:with torch.no_grad()

    • torch.no_grad():上下文管理器,禁用该代码块内的梯度计算和参数更新。
    • 作用:冻结 BERT 主干模型的所有参数,使其在后续训练中保持不变(仅训练自定义的分类头self.fc)。这样做的好处是:减少显存占用、提升训练速度、避免破坏 BERT 预训练好的通用语义特征,适合新手入门和小数据集训练。
    • 补充:后续若想提升模型效果,可解冻 BERT 的部分层(如最后几层 Transformer),进行「联合微调」。
  2. 调用 BERT 提取文本特征:out = pretrained(...)

    • 输入参数:input_idsattention_masktoken_type_ids,这是 BERT 模型要求的三个核心输入(由tokenizer处理原始文本后生成),各自作用:
    • input_ids:原始文本分词后,每个 token 对应的数字编码(张量格式)。
    • attention_mask:标记有效 token(1)和填充 token(0),让 BERT 忽略填充的无效部分。
    • token_type_ids:区分句子对(如问答任务中的问题和答案),单句文本分类任务中该张量全为 0。
    • 输出结果out:BERT 模型的返回对象,包含多个属性,其中last_hidden_state是我们需要的「最后一层隐藏层输出」,形状为[batch_size, sequence_length, 768]
    • batch_size:批次大小(一次处理的样本数)。
    • sequence_length:每个样本的序列长度(分词后的 token 数量,固定为同一长度)。
    • 768:每个 token 的特征维度。
  3. 提取有效分类特征:out.last_hidden_state[:, 0]

    • BERT 模型在处理文本时,会在每个句子开头添加一个特殊 token <CLS>(对应索引 0),该 token 的特征会聚合整个句子的语义信息,是专门为分类任务设计的有效特征。
    • out.last_hidden_state[:, 0]:取所有样本(:对应batch_size)的第 0 个 token(<CLS>)特征,形状从[batch_size, sequence_length, 768]转换为[batch_size, 768],刚好匹配全连接层self.fc的输入维度。
  4. 分类头映射:out = self.fc(...)

    • <CLS> token 的 768 维特征输入全连接层self.fc,得到形状为[batch_size, 2]的输出,此时的输出是「未归一化的得分值(logits)」,还不能直接作为分类概率。
  5. 概率转换:out = out.softmax(dim=1)

    • softmax:激活函数,用于将多分类的得分值转换为概率分布(各分类概率之和为 1)。
    • dim=1:指定在「分类维度」(第 1 维,对应 2 个类别)上进行 softmax 计算,最终每个样本会得到两个概率值(如[0.05, 0.95]),分别对应两个类别的预测概率。
  6. 返回结果:return out

    • 返回形状为[batch_size, 2]的分类概率张量,后续可通过「取概率最大值对应的索引」得到最终的分类结果(如torch.argmax(out, dim=1))。

补充:关键细节与后续扩展

  1. 为什么选择<CLS> token?
    • 除了<CLS> token,也可以取所有 token 特征的平均值作为句子特征,但 BERT 设计时<CLS> token 专门用于聚合整句语义,在分类任务上效果更稳定、更常用。
  2. 后续需要补充的步骤
    • 这段代码仅搭建了模型结构,还需要完成:数据加载(结合之前的Mydataset)、文本编码(用tokenizer处理原始文本生成三个输入参数)、定义损失函数(如CrossEntropyLoss)、优化器(如AdamW)、训练循环、验证与推理,才能完成完整的文本分类任务。
  3. 多分类扩展
    • 若要实现 3 分类、10 分类,只需修改全连接层的输出维度:self.fc = torch.nn.Linear(768, 类别数)softmax(dim=1)会自动适配新的类别数,将得分转换为对应概率分布。

总结

  1. 核心思路:冻结 BERT 主干(提取 768 维语义特征)+ 单层全连接层(二分类映射)+ softmax(概率转换),搭建轻量化文本分类模型。
  2. 关键步骤:文本转 BERT 输入格式 → 冻结 BERT 提取<CLS>特征 → 全连接层分类 → softmax 转换为概率。
  3. 核心价值:为中文文本二分类任务提供了基础模型结构,可直接用于后续的训练和推理,且易于扩展到多分类任务。

模块核心价值

该模块实现了 "预训练模型复用 + 下游任务轻量化" 的网络构建思路,既利用了 BERT 的强大语义提取能力,保证了分类任务的精度,又通过冻结预训练层降低了训练成本,让模型在普通硬件设备上也能高效训练,非常适合入门者学习与小规模 NLP 任务落地。

四、模块三:训练主程序(trainer.py)------ 端到端训练流程落地

如果说数据加载与模型构建模块是 "基石",那么训练主程序模块就是 "粘合剂"。该模块整合了前两个模块的成果,配置了各类训练组件,实现了 "前向计算→损失计算→反向传播→参数更新" 的训练闭环,同时包含训练监控与模型保存功能,落地完整的训练流程。

完整代码

复制代码
# 导入PyTorch核心库,用于张量计算、模型训练、设备管理等
import torch
# 导入自定义的数据集类Mydataset(用于加载本地ChnSentiCorp数据集)
from MyData import Mydataset
# 导入PyTorch的数据加载器DataLoader,用于批量加载、打乱数据集
from torch.utils.data import DataLoader
# 导入自定义的下游分类模型Model(基于BERT的二分类模型)
from net import Model
# 导入BERT分词器,用于文本编码处理
from transformers import BertTokenizer
# 导入AdamW优化器(针对Transformer模型优化的Adam变体,常用在预训练模型微调)
from torch.optim import AdamW

# 定义模型训练/推理使用的设备(优先使用NVIDIA CUDA GPU,无GPU则使用CPU)
# 以下是注释掉的AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 打印当前使用的计算设备,方便确认是否成功启用GPU
print(f"使用设备: {DEVICE}")

# 定义模型训练的总轮次(Epoch):完整遍历训练集的次数
EPOCH = 30
# 定义本地中文BERT预训练模型的存储路径
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 从本地路径加载BERT分词器,与预训练模型配套使用,用于文本编码
token = BertTokenizer.from_pretrained(model_name)

# 自定义数据批处理函数collate_fn:用于DataLoader中处理批量样本,完成文本编码和格式转换
# 为什么需要collate_fn?
# 1. 神经网络要求输入尺寸固定,但原始文本长度各不相同,需要统一填充到相同长度
# 2. 批量处理可利用GPU并行计算,大幅提升训练效率,避免单样本处理的低效问题
# 3. 集中管理所有数据预处理逻辑(分词、填充、张量转换),让代码更整洁,减少重复
def collate_fn(data):
    # 1. 从批量样本中分离文本内容和对应的标签(data是Mydataset返回的(text, label)元组列表)
    sentes = [i[0] for i in data]  # 提取所有样本的文本,组成列表:["文本1", "文本2", ..., "文本N"]
    label = [i[1] for i in data]   # 提取所有样本的标签,组成列表:[标签1, 标签2, ..., 标签N](0或1,二分类)
    
    # 2. 使用BERT分词器对批量文本进行编码处理,转换为模型可识别的输入格式
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sentes,  # 待编码的批量文本列表
        truncation=True,                  # 开启截断:若文本长度超过max_length,截断到max_length
        padding="max_length",             # 开启填充:将所有文本填充到max_length指定的长度
        max_length=350,                   # 文本的最大序列长度(统一输入尺寸,根据数据集调整)
        return_tensors="pt",              # 返回PyTorch的张量(Tensor)格式,方便后续模型计算
        return_length=True,               # 返回每个文本编码后的实际长度(可选,此处未使用,保留用于调试)
    )
    
    # 3. 从编码结果中提取BERT模型需要的三个核心输入张量
    input_ids = data["input_ids"]          # 文本分词后的数字编码张量,形状:[batch_size, max_length]
    attention_mask = data["attention_mask"]# 注意力掩码张量:标记有效token(1)和填充token(0),形状:[batch_size, max_length]
    token_type_ids = data["token_type_ids"]# 句子对分隔张量:单句分类任务中全为0,形状:[batch_size, max_length]
    
    # 4. 将标签列表转换为PyTorch的长整型张量(LongTensor),适配CrossEntropyLoss损失函数
    labels = torch.LongTensor(label)
    
    # 5. 返回处理后的批量输入数据和标签,供模型训练使用
    return input_ids, attention_mask, token_type_ids, labels

# 6. 创建训练集数据集实例:加载Mydataset中的"train"划分数据
train_dataset = Mydataset("train")

# 7. 创建训练集数据加载器(DataLoader):用于批量迭代训练数据
train_laoder = DataLoader(
    dataset=train_dataset,    # 传入已创建的训练集数据集
    batch_size=32,            # 批次大小:每次迭代处理32个样本(根据GPU显存调整,显存不足可减小)
    shuffle=True,             # 开启数据打乱:每个Epoch前打乱训练集顺序,避免模型过拟合到数据顺序
    drop_last=True,           # 删除最后一个不完整的批次:若数据集总数无法被batch_size整除,丢弃最后一个不足32个样本的批次
    collate_fn=collate_fn     # 指定自定义的批处理函数,处理批量样本的编码和格式转换
)

# 主程序入口:仅在直接运行该脚本时执行训练逻辑(被其他脚本导入时不执行)
if __name__ == '__main__':
    # 打印设备信息,确认训练设备
    print(DEVICE)
    
    # 1. 实例化自定义的下游分类模型(基于BERT的二分类模型)
    model = Model()
    print(f"\n3. 移动模型到 {DEVICE}...")
    
    # 2. 关键步骤:将模型移至指定计算设备(GPU/CPU),后续模型计算均在该设备上进行
    # 避免「模型参数与输入张量不在同一设备」的报错
    model = model.to(DEVICE)
    print("  模型移动成功")
    
    # 3. 初始化AdamW优化器:用于更新模型参数(仅更新自定义分类头的参数,BERT主干已冻结)
    # model.parameters():获取模型中可训练的参数(此处仅全连接层fc的参数)
    # lr=5e-4:学习率(控制参数更新的步长,Transformer模型微调常用5e-5~1e-4)
    optimizer = AdamW(model.parameters(), lr=5e-4)
    print("  优化器创建成功")
    
    # 4. 定义损失函数:交叉熵损失(CrossEntropyLoss),适用于分类任务
    # 自动整合了log_softmax和nll_loss,适合直接处理模型输出的概率分布和标签
    loss_func = torch.nn.CrossEntropyLoss()
    
    # 5. 开启模型训练模式:启用Dropout、BatchNorm等训练相关层(此处模型无此类层,仅为规范操作)
    model.train()
    
    # 6. 开始训练循环:遍历所有Epoch
    for epoch in range(EPOCH):
        # 7. 遍历DataLoader中的批量数据:enumerate返回批次索引i和批量数据
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_laoder):
            # 8. 关键步骤:将批量输入数据和标签移至指定计算设备(与模型在同一设备)
            input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
            
            # 9. 模型前向传播:传入三个核心输入张量,得到分类概率输出
            out = model(input_ids, attention_mask, token_type_ids)
            print(f"  前向计算成功,输出形状: {out.shape}")  # 输出形状应为 [32, 2](batch_size=32,二分类)
            
            # 10. 重新实例化损失函数(注:此处可优化,无需在循环内重复实例化,外部定义一次即可)
            loss_func = torch.nn.CrossEntropyLoss()
            # 11. 计算批次损失:传入模型输出out和真实标签labels
            loss = loss_func(out, labels)
            print(f"  损失计算成功: {loss.item():.4f}")  # loss.item()提取张量中的标量值,保留4位小数
            
            # 12. 反向传播前:清零优化器的梯度(PyTorch梯度会累积,需手动清零)
            optimizer.zero_grad()
            # 13. 反向传播:计算损失对模型可训练参数的梯度
            loss.backward()
            print("  反向传播成功")
            
            # 14. 调试:检查是否存在有效梯度(确认模型参数是否在更新)
            grad_found = False
            for name, param in model.named_parameters():
                if param.grad is not None:
                    grad_found = True
                    grad_norm = param.grad.norm().item()  # 计算梯度的范数,用于观察梯度大小
                    print(f"  {name} 梯度范数: {grad_norm:.6f}")
                    break  # 仅打印第一个有梯度的参数,避免输出过多
            if not grad_found:
                print("  警告: 没有找到梯度")
            
            # 15. 优化器更新参数:根据计算出的梯度,更新模型可训练参数(此处为全连接层fc的参数)
            optimizer.step()
            print("  参数更新成功")
            
            # 16. 每5个批次打印一次训练信息(Epoch、批次索引、损失值、准确率),方便监控训练进度
            if i % 5 == 0:
                # 提取预测结果:取概率分布中最大值对应的索引(0或1),作为模型的分类预测
                out = out.argmax(dim=1)
                # 计算批次准确率:预测正确的样本数 / 总样本数
                acc = (out == labels).sum().item() / len(labels)
                # 打印训练监控信息
                print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
        
        # 17. 每个Epoch训练完成后,保存模型参数(state_dict),方便后续加载继续训练或推理
        # state_dict:保存模型中可训练参数的字典,是PyTorch保存模型的标准方式
        torch.save(model.state_dict(), f"params/{epoch}bert.pt")
        print(f"Epoch {epoch} 训练完成,参数保存成功!")

代码详细解释

这段代码是端到端的 BERT 文本二分类模型训练代码 ,承接之前搭建的Mydataset(自定义数据集)和Model(自定义分类模型),实现了从「批量数据预处理」到「模型参数保存」的完整训练流程,核心特点是:

  1. 采用「批量处理」提升训练效率,利用 GPU 并行计算优势。
  2. 仅训练自定义的全连接层分类头,冻结 BERT 主干模型参数。
  3. 包含完整的训练闭环:前向计算→损失计算→反向传播→参数更新→进度监控→模型保存。
  4. 最终输出训练过程中的损失和准确率,并保存每一轮的模型参数,方便后续推理和继续训练。

第一步:导入依赖库 / 模块

复制代码
import torch
from MyData import Mydataset
from torch.utils.data import DataLoader
from net import Model
from transformers import BertTokenizer
from torch.optim import AdamW

每个导入项的核心作用(呼应之前的代码,形成闭环):

  1. torch:PyTorch 核心库,负责张量计算、模型训练、设备管理、损失函数定义等。
  2. Mydataset:你自定义的数据集类,用于加载本地ChnSentiCorp训练集数据,返回(text, label)元组。
  3. DataLoader:PyTorch 的批量数据加载器,用于将Mydataset的样本打包成批次,支持打乱、并行加载等功能。
  4. Model:你自定义的下游分类模型,基于 BERT 提取特征 + 单层全连接层实现二分类。
  5. BertTokenizer:BERT 配套分词器,用于将原始文本转换为模型可识别的张量格式。
  6. AdamW:针对 Transformer 模型优化的 Adam 变体优化器(修正了权重衰减逻辑),是预训练模型微调的首选优化器。

第二步:配置计算设备 & 定义基础参数

复制代码
# 定义训练/推理使用的设备(优先GPU,无GPU则CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {DEVICE}")

# 定义训练总轮次(Epoch):完整遍历训练集的次数
EPOCH = 30
# 本地中文BERT预训练模型路径
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 加载与BERT配套的分词器
token = BertTokenizer.from_pretrained(model_name)
  1. 设备配置DEVICE

    • 核心逻辑:优先使用 NVIDIA CUDA GPU(需安装 CUDA/cuDNN),无 GPU 则降级为 CPU。
    • 关键作用:后续模型和所有输入数据都要移至该设备,避免「模型参数与张量不在同一设备」的报错(PyTorch 中,张量和模型必须在同一设备才能进行计算)。
    • 打印设备信息:方便你确认是否成功启用 GPU(若显示cuda:0则表示 GPU 可用,显示cpu则表示使用 CPU)。
  2. 超参数与分词器

    • EPOCH=30:训练轮次,即完整遍历训练集 30 次。轮次过多可能过拟合,过少可能欠拟合,后续可根据验证集效果调整。
    • token = BertTokenizer.from_pretrained(model_name):加载与 BERT 预训练模型配套的分词器,保证分词规则、词汇表与 BERT 一致,否则会出现输入不匹配错误。

第三步:核心自定义函数collate_fn(批量数据处理)

这是训练代码的关键辅助函数 ,用于给DataLoader提供批量数据的处理逻辑,解决「文本长度不一致」和「批量编码效率低」的问题。

复制代码
def collate_fn(data):
    # 1. 分离文本和标签(data是Mydataset返回的(text, label)元组列表)
    sentes = [i[0] for i in data]# ["文本1", "文本2", ...]
    label = [i[1] for i in data] # [标签1, 标签2, ...]
    
    # 2. 批量编码文本:转换为BERT所需的张量格式
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sentes,
        truncation=True,
        padding="max_length",
        max_length=350,
        return_tensors="pt", # 返回 PyTorch tensor
        return_length=True,
    )
    
    # 3. 提取BERT核心输入张量
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    
    # 4. 标签转换为长整型张量
    labels = torch.LongTensor(label)

    return input_ids,attention_mask,token_type_ids,labels
逐步骤解释:
  1. 分离文本和标签

    • dataDataLoader传入的「批量样本列表」,每个元素是Mydataset.__getitem__()返回的(text, label)元组。
    • 用列表推导式分离出文本列表sentes和标签列表label,方便后续分别处理。
  2. 批量编码token.batch_encode_plus() :这是BertTokenizer的批量编码方法,比单样本编码效率高 10 倍以上,核心参数解释:

    • batch_text_or_text_pairs=sentes:传入待编码的批量文本列表。
    • truncation=True:开启截断,若文本分词后长度超过max_length=350,则截断到 350 个 token。
    • padding="max_length":开启「填充到最大长度」,将所有文本填充到350个 token,保证批量内所有文本输入尺寸一致(神经网络要求固定尺寸输入)。
    • max_length=350:文本的最大序列长度,可根据你的数据集调整(显存不足可减小,如 128、256)。
    • return_tensors="pt":返回 PyTorch 张量格式,无需手动转换,可直接用于模型计算。
    • return_length=True:返回每个文本的实际编码长度(可选,此处未使用,仅用于调试)。
  3. 提取 BERT 核心输入张量 :编码后返回的data是一个字典,包含 BERT 模型必需的 3 个张量,形状均为[batch_size, 350]batch_size=32):

    • input_ids:文本分词后的数字编码(每个 token 对应一个唯一数字)。
    • attention_mask:注意力掩码,1表示有效 token,0表示填充 token,让 BERT 忽略填充部分。
    • token_type_ids:句子对分隔符,单句分类任务中全为0(仅在问答、句对匹配等任务中有用)。
  4. 标签转换为torch.LongTensor

    • 将标签列表转换为长整型张量,适配CrossEntropyLoss损失函数的输入要求 (该损失函数要求标签是LongTensor类型,不能是普通列表或浮点型张量)。
    • 最终返回 4 个张量,供模型训练使用。
为什么必须用collate_fn
  • 解决「文本长度不一致」的问题,统一输入尺寸。
  • 批量处理利用 GPU 并行计算,提升训练效率。
  • 集中管理数据预处理逻辑,让代码更整洁,避免重复编码。
  • 若不使用collate_fnDataLoader会直接返回原始元组列表,无法直接输入模型。

第四步:创建数据集和DataLoader(批量数据迭代器)

复制代码
#创建数据集
train_dataset = Mydataset("train")
#创建DataLoader
train_laoder = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)
  1. train_dataset = Mydataset("train"):实例化自定义数据集,加载ChnSentiCorp的训练集数据。
  2. DataLoader:创建批量数据迭代器,核心参数解释:
    • dataset=train_dataset:传入已加载的训练集。
    • batch_size=32:批次大小,每次迭代处理 32 个样本(根据 GPU 显存调整,显存不足可减小为 16、8)。
    • shuffle=True:每个 Epoch 前打乱训练集顺序,避免模型过拟合到数据顺序(若不打乱,模型可能记住样本顺序,而非学习文本语义)。
    • drop_last=True:删除最后一个不完整的批次(若训练集总数无法被 32 整除,丢弃最后一个不足 32 个样本的批次),保证每个批次的样本数一致,避免训练报错。
    • collate_fn=collate_fn:指定自定义的批量数据处理函数,DataLoader每次获取批量样本后,会调用该函数进行处理。

第五步:主程序训练逻辑(核心闭环)

if __name__ == '__main__':内的代码是完整的训练闭环,也是这段代码的核心,实现了从模型初始化到参数保存的全流程。

复制代码
if __name__ == '__main__':
    # 1. 实例化模型并移至指定设备
    model = Model()
    model = model.to(DEVICE)
    print(f"模型已移动到 {DEVICE}")
    
    # 2. 初始化优化器(AdamW)
    optimizer = AdamW(model.parameters(),lr=5e-4)
    print("  优化器创建成功")
    
    # 3. 定义损失函数(交叉熵损失)
    loss_func = torch.nn.CrossEntropyLoss()
    
    # 4. 开启模型训练模式
    model.train()
    
    # 5. 多轮训练循环(Epoch)
    for epoch in range(EPOCH):
        # 6. 批量迭代训练数据(Batch)
        for i,(input_ids,attention_mask,token_type_ids,labels) in enumerate(train_laoder):
            # 7. 数据移至指定设备(与模型同一设备)
            input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),labels.to(DEVICE)
            
            # 8. 前向传播:获取模型预测结果
            out = model(input_ids, attention_mask, token_type_ids)
            print(f"  前向计算成功,输出形状: {out.shape}")
            
            # 9. 计算批次损失
            loss = loss_func(out, labels)
            print(f"  损失计算成功: {loss.item():.4f}")
            
            # 10. 反向传播三部曲:清零梯度→计算梯度→更新参数
            optimizer.zero_grad()  # 清零梯度(PyTorch梯度会累积)
            loss.backward()        # 反向传播:计算损失对可训练参数的梯度
            optimizer.step()       # 优化器更新参数:根据梯度调整分类头权重
            
            # 11. 调试:检查是否存在有效梯度
            grad_found = False
            for name, param in model.named_parameters():
                if param.grad is not None:
                    grad_found = True
                    grad_norm = param.grad.norm().item()
                    print(f"  {name} 梯度范数: {grad_norm:.6f}")
                    break
            if not grad_found:
                print("  警告: 没有找到梯度")
            
            # 12. 每5个批次打印训练进度(损失+准确率)
            if i%5 ==0:
                out = out.argmax(dim=1)  # 提取概率最大值对应的分类索引(0/1)
                acc = (out == labels).sum().item()/len(labels)  # 计算批次准确率
                print(epoch,i,loss.item(),acc)
        
        # 13. 每个Epoch结束后保存模型参数
        torch.save(model.state_dict(),f"params/{epoch}bert.pt")
        print(epoch,"参数保存成功!")
逐步骤解释(训练闭环核心):
  1. 模型实例化与设备迁移

    • model = Model():实例化你自定义的二分类模型(包含 BERT 主干 + 全连接层分类头)。
    • model = model.to(DEVICE):将模型所有参数移至指定设备(GPU/CPU),这是必须步骤,否则后续模型计算会报错。
  2. 初始化优化器AdamW

    • optimizer = AdamW(model.parameters(), lr=5e-4)
    • model.parameters():获取模型中所有「可训练参数」(此处仅全连接层self.fc的参数,因为 BERT 主干已被torch.no_grad()冻结)。
    • lr=5e-4:学习率,控制参数更新的步长(Transformer 模型微调常用5e-5~1e-4,步长过大会导致训练不稳定,过小会导致训练缓慢)。
    • 为什么用AdamW?:相比普通AdamAdamW修正了权重衰减的计算逻辑,更适合 Transformer 模型,能有效防止过拟合。
  3. 定义损失函数CrossEntropyLoss

    • 适用于分类任务的经典损失函数,自动整合了log_softmax(概率归一化)和nll_loss(负对数似然损失),无需手动对模型输出做额外处理。
    • 输入要求:模型输出out(形状[32,2])和真实标签labels(形状[32]LongTensor类型)。
  4. 开启训练模式model.train()

    • 启用模型中的「训练相关层」(如DropoutBatchNorm),虽然你的模型中没有这些层,但这是 PyTorch 训练的规范操作,为后续扩展模型(添加 Dropout 防止过拟合)做准备。
    • 对应评估模式model.eval():验证 / 推理时使用,关闭Dropout等层,保证结果稳定。
  5. 外层Epoch循环(多轮训练)

    • 遍历30个 Epoch,每个 Epoch 完整遍历一次训练集,目的是让模型充分学习数据特征。
    • 每个 Epoch 结束后保存模型参数,方便后续选择「损失最低、准确率最高」的模型进行推理。
  6. 内层Batch循环(批量迭代)

    • enumerate(train_laoder):遍历DataLoader,返回批次索引i和处理后的批量数据(4 个张量)。
    • 数据设备迁移:将input_idsattention_masktoken_type_idslabels全部移至DEVICE保证与模型在同一设备,这是避免报错的关键步骤。
  7. 前向传播(模型预测)

    • out = model(input_ids, attention_mask, token_type_ids):调用模型的forward()方法,得到分类概率输出。
    • 输出形状out.shape[32, 2]batch_size=32,二分类),每个元素是对应类别的概率(总和为 1)。
  8. 损失计算

    • loss = loss_func(out, labels):计算该批次的平均损失,损失值越小表示模型预测结果越接近真实标签。
    • loss.item():提取张量中的标量值(去掉张量封装),方便打印和后续分析(避免打印多余的张量信息)。
  9. 反向传播三部曲(核心:参数更新) :这是 PyTorch 模型训练的核心流程,实现模型参数的自动优化:

    • optimizer.zero_grad()清零优化器的梯度缓存(PyTorch 中梯度会累积,若不清零,会导致后续批次的梯度叠加,影响参数更新)。
    • loss.backward()反向传播,从损失值出发,计算损失对所有可训练参数(此处为全连接层参数)的梯度(梯度表示参数更新的方向和幅度)。
    • optimizer.step()优化器更新参数 ,根据计算出的梯度,按照AdamW的优化逻辑,调整可训练参数的取值,让损失值逐渐降低。
  10. 梯度调试(可选)

    • 遍历模型参数,检查是否存在有效梯度,目的是确认模型参数是否在正常更新
    • 若显示「没有找到梯度」,可能的原因:BERT 主干冻结且分类头参数不可训练、输入输出不匹配、学习率为 0 等,需要排查问题。
  11. 训练进度监控(每 5 个批次打印)

    • out.argmax(dim=1):取概率分布中最大值对应的索引(01),作为模型的最终分类预测结果(如[0.05, 0.95]对应索引1,即正面情感)。
    • acc = (out == labels).sum().item()/len(labels):计算批次准确率,即预测正确的样本数占总样本数的比例,准确率越高表示模型效果越好。
    • 每 5 个批次打印一次,避免输出过多信息,方便监控训练趋势(损失是否逐渐降低,准确率是否逐渐提升)。
  12. 模型参数保存

    • torch.save(model.state_dict(), f"params/{epoch}bert.pt"):保存模型的「状态字典」(state_dict),即模型中所有可训练参数的键值对(此处仅全连接层参数)。
    • state_dict是 PyTorch 保存模型的标准方式,后续可通过model.load_state_dict(torch.load("xxx.pt"))加载模型参数,进行推理或继续训练。
    • 注意:需要提前创建params文件夹,否则会报「文件不存在」错误。

补充:关键细节与优化点

  1. 代码小优化loss_func = torch.nn.CrossEntropyLoss()无需在Batch循环内重复实例化,在循环外定义一次即可,减少冗余。
  2. 显存不足解决 :减小batch_size(如从 32 改为 16)、减小max_length(如从 350 改为 256)、使用 CPU 训练(速度较慢)。
  3. 过拟合预防:后续可添加验证集,监控验证集损失和准确率,若验证集损失上升,说明模型过拟合,可提前停止训练(早停法)。
  4. 梯度裁剪 :若梯度范数过大(如超过 10),可添加torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0),防止梯度爆炸。

总结

  1. 核心流程:配置设备→批量数据处理→初始化模型 / 优化器 / 损失函数→多轮训练(前向→损失→反向→更新)→监控进度→保存参数,形成完整的训练闭环。
  2. 关键细节:collate_fn统一输入尺寸,to(DEVICE)保证设备一致,反向传播三部曲实现参数更新,state_dict保存模型参数。
  3. 核心价值:实现了基于 BERT 的文本二分类模型的可落地训练,输出可监控的训练指标和可复用的模型参数,为后续推理和优化提供基础。

模块核心价值

该模块实现了端到端的模型训练流程,涵盖了数据预处理、组件初始化、参数迭代、监控保存等所有核心功能,形成了完整的训练闭环。代码逻辑清晰,注释详尽,能够让入门者直观理解模型训练的底层逻辑,同时具备良好的可修改性,可根据实际任务需求调整超参数、优化器、损失函数等组件。

五、训练核心原理总结与优化方向

1. 核心原理回顾

本次模型训练的核心逻辑是 "复用预训练知识,轻量化迭代优化":

  • 数据层:通过标准化数据集接口,为训练提供统一、高质量的样本输入。
  • 模型层:复用 BERT 预训练模型的中文语义提取能力,仅训练下游全连接分类层,降低训练成本。
  • 训练层:通过 "前向计算→损失计算→反向传播→参数更新" 的闭环,实现模型参数的迭代优化,让模型逐步学习文本情感特征。

2. 后续优化方向(基于现有代码,不额外新增代码)

  • 解冻 BERT 部分层微调:当前代码冻结了全部 BERT 层,可尝试解冻最后 2~3 层 BERT 参数参与训练,让模型学习到更贴合情感分类任务的语义特征,提升分类精度。
  • 增加早停策略:在训练中监控验证集损失,若验证集损失连续多轮上升,停止训练,避免模型过拟合。
  • 超参数调优:调整BATCH_SIZELEARNING_RATE、训练轮次等超参数,或尝试不同的优化器(如 Adam、SGD),寻找最优参数组合。
  • 数据增强:对训练集文本进行同义词替换、语句重组等增强操作,增加数据多样性,提升模型泛化能力。

六、全文总结

本文基于一套模块化的完整代码,详细解析了基于 BERT 的中文文本情感分类模型训练全流程,从数据加载到模型构建,再到端到端训练闭环,每个模块都提供了完整代码与详细解释,帮助读者不仅能看懂代码,更能理解背后的技术逻辑。

这套训练框架兼具轻量性与实用性,既适合 NLP 入门者学习模型训练的核心流程,也适合小规模中文情感分类任务的落地。掌握这套框架的核心逻辑,将为后续更复杂的 NLP 任务(如文本摘要、命名实体识别)学习与落地奠定坚实的基础。

相关推荐
Piar1231sdafa3 小时前
蓝莓目标检测——改进YOLO11-C2TSSA-DYT-Mona模型实现
人工智能·目标检测·计算机视觉
愚公搬代码3 小时前
【愚公系列】《AI短视频创作一本通》002-AI引爆短视频创作革命(短视频创作者必备的能力)
人工智能
数据猿视觉3 小时前
新品上市|奢音S5耳夹耳机:3.5g无感佩戴,178.8元全场景适配
人工智能
我有酒两杯3 小时前
引导模型生成具有反思和验证机制的response的指令
深度学习
蚁巡信息巡查系统3 小时前
网站信息发布再巡查机制怎么建立?
大数据·人工智能·数据挖掘·内容运营
AI浩3 小时前
C-RADIOv4(技术报告)
人工智能·目标检测
Purple Coder3 小时前
AI赋予超导材料预测论文初稿
人工智能
Data_Journal3 小时前
Scrapy vs. Crawlee —— 哪个更好?!
运维·人工智能·爬虫·媒体·社媒营销
云边云科技_云网融合3 小时前
AIoT智能物联网平台:架构解析与边缘应用新图景
大数据·网络·人工智能·安全
康康的AI博客3 小时前
什么是API中转服务商?如何低成本高稳定调用海量AI大模型?
人工智能·ai