在自然语言处理(NLP)领域,中文文本情感分类是一项兼具实用价值与技术代表性的任务,广泛应用于电商评论分析、舆情监控、客户反馈处理等场景。传统机器学习模型难以捕捉中文的深层语义与语境信息,而预训练语言模型的出现,为该任务提供了高效、高精度的解决方案。
本文将基于一套模块化的完整训练代码,从数据加载、模型构建、端到端训练三个核心环节入手,结合代码片段与详细解释,全方位拆解基于 BERT 的中文文本情感分类模型训练全流程,帮助读者不仅能看懂代码,更能理解背后的技术逻辑。
一、任务概述
本次训练的核心任务是中文文本二分类情感分析,即输入一段中文评价文本(如酒店、餐厅评论),模型输出该文本对应的情感倾向(正面 / 负面,分别用标签 1/0 表示)。
核心依赖与数据集
- 核心依赖:PyTorch(模型构建与训练)、Transformers(BERT 预训练模型加载与分词)。
- 数据集:ChnSentiCorp 中文情感分析数据集,已按
train(训练集)、validation(验证集)、test(测试集)完成拆分,满足模型训练的基本数据要求。 - 代码架构:采用模块化设计,分为 3 个核心文件,职责明确、逻辑解耦:
MyData.py:标准化数据集加载模块。net.py:基于 BERT 的下游分类网络构建模块。trainer.py:端到端训练流程落地模块。
下载训练模型
from datasets import load_from_disk, load_dataset
from huggingface_hub import snapshot_download
# 指定镜像站地址(国内推荐)
repo_id = "google-bert/bert-base-chinese"
local_dir = "./bert-base-chinese" # 本地保存路径
# 下载数据集
snapshot_download(
repo_id=repo_id,
repo_type="model",
local_dir=local_dir,
local_dir_use_symlinks=False,
endpoint="https://hf-mirror.com" # 使用国内镜像
)
print(f"模型已下载到: {local_dir}")
二、模块一:数据加载(MyData.py)------ 标准化数据接口实现
数据是模型训练的源头,高质量的数据加载逻辑是保证训练顺利进行的前提。该模块的核心目标是实现 PyTorch 生态下的自定义数据集加载,为后续训练提供统一、可复用的样本获取接口,解耦数据读取与训练逻辑。
完整代码
# 从PyTorch的工具包中导入Dataset抽象类,自定义数据集必须继承该类并实现核心方法
from torch.utils.data import Dataset
# 从Hugging Face的datasets库中导入从本地磁盘加载数据集的函数
from datasets import load_from_disk
# 定义自定义数据集类,继承自torch.utils.data.Dataset
class Mydataset(Dataset):
# 初始化数据(类的构造方法),用于加载和筛选数据集
def __init__(self, split):
"""
自定义数据集的初始化方法
Args:
split (str): 数据集划分标识,可选值为 "train"(训练集)、"validation"(验证集)、"test"(测试集)
"""
# 从本地指定磁盘路径加载已保存的Hugging Face格式数据集(ChnSentiCorp是中文情感分析数据集)
self.dataset = load_from_disk(r"D:\pyprojecgt\flaskProject\langchainstudy\modelscope\bert\trainstudy\data\ChnSentiCorp")
# 根据传入的split参数,筛选对应的数据集划分
if split == "train":
# 筛选出训练集数据并赋值给实例变量self.dataset
self.dataset = self.dataset["train"]
elif split == "validation":
# 筛选出验证集数据并赋值给实例变量self.dataset
self.dataset = self.dataset["validation"]
elif split == "test":
# 筛选出测试集数据并赋值给实例变量self.dataset
self.dataset = self.dataset["test"]
else:
# 若传入无效的split参数,打印错误提示信息
print("数据集名称错误!请传入有效划分:train/validation/test")
# 获取数据集的样本总数(必须实现的Dataset抽象方法)
def __len__(self):
"""
返回当前数据集划分的样本总数
Returns:
int: 数据集样本数量
"""
# 直接返回加载的数据集划分的长度(即样本总数)
return len(self.dataset)
# 根据索引获取单个样本的定制化数据(必须实现的Dataset抽象方法)
def __getitem__(self, item):
"""
根据索引提取单个样本的文本和标签
Args:
item (int): 样本的索引值(从0开始,不超过__len__()返回的数值-1)
Returns:
tuple: 包含单个样本的(文本内容,情感标签)元组
"""
# 从数据集中根据索引item提取对应样本的文本内容
text = self.dataset[item]["text"]
# 从数据集中根据索引item提取对应样本的情感标签(0表示负面,1表示正面,视数据集标注规则而定)
label = self.dataset[item]["label"]
# 返回处理后的单个样本数据(文本+标签)
return text, label
# 主程序入口,仅在直接运行该脚本时执行(用于测试自定义数据集是否正常工作)
if __name__ == '__main__':
# 实例化自定义数据集类,加载"validation"(验证集)数据
dataset = Mydataset("validation")
# 遍历验证集的所有样本,逐个打印样本内容(文本+标签),验证数据加载是否正常
for data in dataset:
print(data)
代码详细解释
这段代码的核心是自定义一个符合 PyTorch 数据加载规范的数据集类Mydataset ,用于加载本地磁盘上的中文情感分析数据集ChnSentiCorp,并支持按train(训练集)、validation(验证集)、test(测试集)划分数据,最终能通过索引或遍历获取每个样本的「文本内容」和「情感标签」,为后续的深度学习模型(如 BERT)训练做数据准备。
第一步:导入必要的库
# 从PyTorch的工具包中导入Dataset抽象类
from torch.utils.data import Dataset
# 从Hugging Face的datasets库中导入从本地磁盘加载数据集的函数
from datasets import load_from_disk
torch.utils.data.Dataset:这是 PyTorch 中所有自定义数据集的基类(抽象类) ,它规定了自定义数据集必须实现两个核心方法:__len__()和__getitem__(),只有遵循这个规范,后续才能和 PyTorch 的DataLoader(数据加载器,用于批量加载、打乱数据等)配合使用。datasets.load_from_disk:这是 Hugging Facedatasets库提供的函数,用于加载已经提前保存到本地磁盘的 Hugging Face 格式数据集 (通常是通过dataset.save_to_disk()保存的)。这种格式的数据集自带结构化信息,无需手动解析 txt/csv 等文件,使用起来更便捷。
第二步:定义自定义数据集类Mydataset
所有自定义 PyTorch 数据集都必须继承Dataset基类,这里定义的Mydataset就是为了适配ChnSentiCorp数据集的加载和提取。
1. 构造方法:__init__(self, split)
这是类的初始化方法,用于加载数据集、筛选数据划分,在实例化类的时候会自动执行。
def __init__(self, split):
"""
自定义数据集的初始化方法
Args:
split (str): 数据集划分标识,可选值为 "train"(训练集)、"validation"(验证集)、"test"(测试集)
"""
# 1. 从本地指定路径加载Hugging Face格式数据集
self.dataset = load_from_disk(r"D:\pyprojecgt\flaskProject\langchainstudy\modelscope\bert\trainstudy\data\ChnSentiCorp")
# 2. 根据split参数筛选对应的数据集划分
if split == "train":
self.dataset = self.dataset["train"]
elif split == "validation":
self.dataset = self.dataset["validation"]
elif split == "test":
self.dataset = self.dataset["test"]
else:
print("数据集名称错误!请传入有效划分:train/validation/test")
self.dataset:这是类的实例变量 ,整个类的其他方法(__len__、__getitem__)都可以访问它。首先通过load_from_disk加载完整的ChnSentiCorp数据集(该数据集本身已经包含train、validation、test三个划分)。split参数:用于指定要加载的数据集划分,比如传入"validation"就只保留验证集数据,后续操作都只针对这个划分。- 异常提示:如果传入无效的
split值(比如"val"),会打印错误提示,避免程序直接崩溃(新手友好型处理)。
2. 样本数量方法:__len__(self)
这是 PyTorch Dataset 基类要求必须实现 的方法,用于返回当前数据集划分的样本总数。
def __len__(self):
"""
返回当前数据集划分的样本总数
Returns:
int: 数据集样本数量
"""
return len(self.dataset)
- 逻辑非常简单:直接返回
self.dataset的长度(即样本个数)。比如ChnSentiCorp的验证集大概有 1000 个样本,调用len(dataset)(dataset是Mydataset实例)时,就会自动调用这个方法,返回 1000。 - 作用:后续
DataLoader加载数据时,需要通过这个方法知道数据集的总长度,从而确定遍历的边界。
3. 单个样本提取方法:__getitem__(self, item)
这也是 PyTorch Dataset 基类要求必须实现 的核心方法,用于根据索引item提取并返回单个样本的处理后数据。
def __getitem__(self, item):
"""
根据索引提取单个样本的文本和标签
Args:
item (int): 样本的索引值(从0开始,不超过__len__()返回的数值-1)
Returns:
tuple: 包含单个样本的(文本内容,情感标签)元组
"""
# 1. 提取索引对应的样本文本
text = self.dataset[item]["text"]
# 2. 提取索引对应的样本标签
label = self.dataset[item]["label"]
# 3. 返回处理后的单个样本
return text, label
item参数:是样本的索引值 (整数),从 0 开始,最大不超过__len__()-1(比如数据集有 1000 个样本,item的取值范围是 0~999)。self.dataset[item]["text"]:Hugging Face 格式的数据集支持「索引 + 键名」的方式提取数据,"text"是数据集里存储文本内容的键,"label"是存储情感标签的键(通常 0 表示负面情感,1 表示正面情感,由ChnSentiCorp数据集的标注规则决定)。- 返回值:以 ** 元组
(text, label)** 的形式返回单个样本,这样后续遍历或批量加载时,能很方便地拆分文本和标签。
第三步:主程序测试(if __name__ == '__main__')
这部分代码的作用是测试自定义数据集是否能正常工作,只有在直接运行该脚本时才会执行(如果被其他脚本导入,这部分不会执行)。
if __name__ == '__main__':
# 1. 实例化Mydataset类,加载验证集(validation)
dataset = Mydataset("validation")
# 2. 遍历数据集的所有样本,逐个打印
for data in dataset:
print(data)
- 实例化类:
dataset = Mydataset("validation"),此时会调用__init__方法,加载验证集数据。 - 遍历数据集:
for data in dataset,这背后是 PyTorch 的迭代器支持:- 首先调用
__len__()获取数据集总长度。 - 然后自动生成从 0 到
len(dataset)-1的索引,逐个传入__getitem__()方法。 - 把
__getitem__()返回的(text, label)元组赋值给data,并打印。
- 首先调用
- 预期输出:会逐个打印验证集中的样本,格式类似
("这个商品质量很好,推荐购买", 1)、("物流太慢了,体验很差", 0)。
补充:关键细节和使用场景
- 为什么要自定义 Dataset?
- PyTorch 的
DataLoader只能加载符合Dataset规范的数据集。 - 实际项目中,数据集的格式、需要提取的字段(这里是
text和label)各不相同,自定义Dataset可以灵活适配不同的数据集需求。
- PyTorch 的
- 后续扩展 :这段代码只提取了原始文本和标签,后续训练 BERT 等模型时,还需要在
__getitem__()中添加文本编码 (比如用BertTokenizer把文本转换成 token id、attention mask 等),才能输入到模型中。 - 路径注意事项 :代码中的数据集路径是绝对路径(
r"D:\pyprojecgt\..."),如果后续移动脚本或数据集,需要修改该路径,也可以改成相对路径提高可移植性。
总结
- 核心结构:自定义
Mydataset继承torch.utils.data.Dataset,必须实现__init__(加载筛选)、__len__(返回样本数)、__getitem__(提取单个样本)三个方法。 - 数据流转:
load_from_disk加载本地数据集 → 按split筛选划分 → 按索引提取(text, label)样本。 - 核心作用:为 PyTorch 模型训练提供符合规范的数据源,后续可配合
DataLoader实现批量加载、数据打乱等功能。
模块核心价值
该模块将繁琐的数据集读取、格式解析、索引映射等工作封装在独立模块中,后续训练流程无需关注数据的底层存储格式与读取方式,只需通过Mydataset类即可快速获取标准化样本,极大提升了代码的可维护性与可复用性。同时,若后续需要更换数据集,只需修改__init__方法中的数据读取逻辑,无需改动训练核心代码,具备良好的扩展性。
三、模块二:网络模型构建(net.py)------ 基于 BERT 的下游分类网络搭建
模型是训练的核心载体,本次训练依托bert-base-chinese预训练模型,搭建轻量级下游情感分类网络,充分复用预训练模型的中文语义提取能力,降低下游任务的训练成本与难度,兼顾训练效率与分类精度。
完整代码
# 分类系统案例:自然语言任务分类(文本生成、文本分类、问答系统)
# 从transformers库中导入自动分词器类,用于处理文本数据(分词、转token id等)
from transformers import (
AutoTokenizer,
)
# 从transformers库中导入BERT预训练模型类,用于加载预训练的BERT主干网络
from transformers import BertModel
# 导入PyTorch核心库,用于搭建神经网络、张量计算等
import torch
# 1. 加载本地预训练模型和对应的分词器
# 定义本地bert-base-chinese模型的存储路径(该模型是中文通用预训练BERT模型)
model_dir = "D:\\本地模型\\google-bert\\bert-base-chinese"
# 加载分词器:AutoTokenizer会自动匹配预训练模型的分词规则,从本地路径加载
# 作用:将原始中文文本转换为BERT模型能够识别的输入格式(token id、attention mask等)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# 定义模型训练/推理使用的设备(优先使用GPU,无GPU则使用CPU)
# 以下是注释掉的DirectML设备配置(用于支持AMD显卡加速)
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
# 优先使用NVIDIA CUDA显卡加速,若未安装CUDA或无NVIDIA显卡,则使用CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载BERT预训练主干模型:从本地路径加载bert-base-chinese,并将模型移至指定设备(GPU/CPU)
# pretrained:预训练模型实例,用于提取文本的高维语义特征(不包含下游分类头)
pretrained = BertModel.from_pretrained(model_dir).to(DEVICE)
# 以下两行是注释掉的调试代码,用于查看预训练模型的结构和词嵌入层信息
# print(pretrained) # 打印完整的BERT模型结构(网络层、参数等)
# print(pretrained.embeddings.word_embeddings) # 打印BERT的词嵌入层(将token id转换为词向量)
# 2. 定义下游任务模型(继承PyTorch的Module类)
# 作用:基于BERT预训练模型提取的特征,搭建简单的分类头,完成文本分类任务
class Model(torch.nn.Module):
# 模型初始化方法:定义下游分类任务的网络结构
def __init__(self):
# 调用父类torch.nn.Module的初始化方法,必须保留
super().__init__()
# 定义全连接层(分类头):将BERT提取的768维特征映射到2分类输出
# torch.nn.Linear(in_features, out_features)
# in_features=768:bert-base-chinese模型的隐藏层输出维度固定为768
# out_features=2:分类任务的类别数(此处为2分类,如情感分析的正面/负面)
self.fc = torch.nn.Linear(768, 2)
# 模型前向传播方法:定义数据的流转路径和计算逻辑(必须实现)
# 输入:BERT模型要求的三个核心输入参数(input_ids、attention_mask、token_type_ids)
def forward(self, input_ids, attention_mask, token_type_ids):
# 冻结预训练模型(上游BERT主干网络不参与训练,不更新参数)
# with torch.no_grad():上下文管理器,禁用梯度计算,减少显存占用,提升计算速度
with torch.no_grad():
# 调用预训练BERT模型,传入三个输入参数,获取文本特征输出
# out:BERT模型的返回结果,包含last_hidden_state(最后一层隐藏层输出)等信息
out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# 提取有效特征:取last_hidden_state的第0个token(<CLS> token)的768维特征
# BERT模型中,<CLS> token(对应索引0)用于聚合整句文本的语义信息,适合用于分类任务
# out.last_hidden_state.shape 一般为 [batch_size, sequence_length, 768]
# out.last_hidden_state[:, 0] 取所有样本(batch)的第0个token,形状变为 [batch_size, 768]
cls_feature = out.last_hidden_state[:, 0]
# 下游分类任务:将<CLS> token的768维特征输入全连接层,映射到2维输出
out = self.fc(cls_feature)
# 对全连接层输出进行softmax激活,将结果转换为概率分布(各分类的概率值,总和为1)
# dim=1:在分类维度(第1维,对应2个类别)上进行softmax计算
out = out.softmax(dim=1)
# 返回最终的分类概率结果
return out
代码详细解释
第一步:导入必要的库
# 分类系统案例:自然语言任务分类(文本生成、文本分类、问答系统)
from transformers import (
AutoTokenizer,
)
# 从transformers库导入BERT预训练模型类
from transformers import BertModel
# 导入PyTorch核心库,用于搭建神经网络、张量计算、设备管理等
import torch
AutoTokenizer:Hugging Facetransformers库提供的「自动分词器」类,用于将原始中文文本转换为 BERT 模型能够识别的输入格式(包括分词、转换为 token id、生成注意力掩码等),且能自动匹配对应预训练模型的分词规则。BertModel:BERT 模型的主干网络类,不包含下游任务的分类头,仅用于提取文本的语义特征,是预训练模型的核心部分。torch:PyTorch 深度学习框架的核心库,用于搭建自定义分类头、张量计算、设备配置等。
第二步:加载本地预训练模型 / 分词器 & 配置计算设备
# 1. 加载本地模型和分词器
# 定义本地中文BERT模型的存储路径(bert-base-chinese是中文通用预训练模型)
model_dir = "D:\\本地模型\\google-bert\\bert-base-chinese"
# 从本地路径加载分词器,与预训练模型配套使用
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# 定义模型计算设备(优先使用GPU加速,无GPU则使用CPU)
# 注释掉的是AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 从本地路径加载BERT主干预训练模型,并将模型移至指定设备(GPU/CPU)
pretrained = BertModel.from_pretrained(model_dir).to(DEVICE)
# 注释掉的是调试代码,用于查看模型结构和词嵌入层信息
# print(pretrained)
# print(pretrained.embeddings.word_embeddings)
- 本地模型路径
model_dir:这里存放的是bert-base-chinese预训练模型的所有文件(包括配置文件、权重文件等),通过from_pretrained(model_dir)可直接加载,无需重新从网络下载。 - 分词器加载
tokenizer = AutoTokenizer.from_pretrained(model_dir):必须与预训练模型配套加载,保证分词规则、词汇表与 BERT 模型一致,否则会出现输入不匹配的错误。 - 计算设备配置
DEVICE:torch.cuda.is_available():判断是否有可用的 NVIDIA GPU(需安装 CUDA、cuDNN)。- 模型移至设备
.to(DEVICE):将 BERT 模型的所有参数加载到指定设备(GPU/CPU),后续所有计算都将在该设备上进行,避免「张量与模型不在同一设备」的报错。
- 调试代码注释:
print(pretrained)可打印 BERT 的完整网络结构(包含嵌入层、多层 Transformer 编码器),print(pretrained.embeddings.word_embeddings)可查看词嵌入层(将 token id 转换为 768 维词向量)。
第三步:定义下游任务分类模型(Model类)
这是代码的核心,继承自torch.nn.Module(PyTorch 所有神经网络的基类),搭建用于二分类的「分类头」,并实现数据的前向传播逻辑。
1. 初始化方法__init__(self):定义网络结构
class Model(torch.nn.Module):
# 模型结构设计
def __init__(self):
# 调用父类torch.nn.Module的初始化方法,必须保留,否则模型无法正常工作
super().__init__()
# 定义全连接层(分类头):将768维特征映射到2分类输出
self.fc = torch.nn.Linear(768, 2)
super().__init__():必须调用父类的初始化方法,用于初始化torch.nn.Module的内置属性(如模型参数、设备信息等),缺少这行代码会导致后续模型训练 / 推理报错。self.fc = torch.nn.Linear(768, 2):定义一个单层全连接层(作为分类头),核心参数说明:in_features=768:输入特征维度,bert-base-chinese模型的隐藏层输出维度固定为 768,这是预训练模型的固有属性,不能随意修改。out_features=2:输出特征维度,对应二分类任务(如情感分析的「正面 / 负面」、文本的「相关 / 不相关」),若要做多分类,只需修改该数值为类别数即可。- 作用:将 BERT 提取的 768 维文本特征,映射到 2 个分类的得分值,完成从「特征提取」到「分类预测」的转换。
2. 前向传播方法forward(...):定义数据流转逻辑
def forward(self,input_ids,attention_mask,token_type_ids):
# 上游任务不参与训练(冻结BERT预训练模型参数)
with torch.no_grad():
out = pretrained(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
# 下游任务参与训练(提取有效特征 + 分类 + 概率转换)
out = self.fc(out.last_hidden_state[:,0])
out = out.softmax(dim=1)
return out
这是模型的核心方法,定义了输入数据如何通过网络层得到输出结果,调用模型实例时会自动执行该方法,关键步骤拆解:
-
冻结 BERT 预训练模型:
with torch.no_grad()torch.no_grad():上下文管理器,禁用该代码块内的梯度计算和参数更新。- 作用:冻结 BERT 主干模型的所有参数,使其在后续训练中保持不变(仅训练自定义的分类头
self.fc)。这样做的好处是:减少显存占用、提升训练速度、避免破坏 BERT 预训练好的通用语义特征,适合新手入门和小数据集训练。 - 补充:后续若想提升模型效果,可解冻 BERT 的部分层(如最后几层 Transformer),进行「联合微调」。
-
调用 BERT 提取文本特征:
out = pretrained(...)- 输入参数:
input_ids、attention_mask、token_type_ids,这是 BERT 模型要求的三个核心输入(由tokenizer处理原始文本后生成),各自作用: input_ids:原始文本分词后,每个 token 对应的数字编码(张量格式)。attention_mask:标记有效 token(1)和填充 token(0),让 BERT 忽略填充的无效部分。token_type_ids:区分句子对(如问答任务中的问题和答案),单句文本分类任务中该张量全为 0。- 输出结果
out:BERT 模型的返回对象,包含多个属性,其中last_hidden_state是我们需要的「最后一层隐藏层输出」,形状为[batch_size, sequence_length, 768]: batch_size:批次大小(一次处理的样本数)。sequence_length:每个样本的序列长度(分词后的 token 数量,固定为同一长度)。768:每个 token 的特征维度。
- 输入参数:
-
提取有效分类特征:
out.last_hidden_state[:, 0]- BERT 模型在处理文本时,会在每个句子开头添加一个特殊 token
<CLS>(对应索引 0),该 token 的特征会聚合整个句子的语义信息,是专门为分类任务设计的有效特征。 out.last_hidden_state[:, 0]:取所有样本(:对应batch_size)的第 0 个 token(<CLS>)特征,形状从[batch_size, sequence_length, 768]转换为[batch_size, 768],刚好匹配全连接层self.fc的输入维度。
- BERT 模型在处理文本时,会在每个句子开头添加一个特殊 token
-
分类头映射:
out = self.fc(...)- 将
<CLS>token 的 768 维特征输入全连接层self.fc,得到形状为[batch_size, 2]的输出,此时的输出是「未归一化的得分值(logits)」,还不能直接作为分类概率。
- 将
-
概率转换:
out = out.softmax(dim=1)softmax:激活函数,用于将多分类的得分值转换为概率分布(各分类概率之和为 1)。dim=1:指定在「分类维度」(第 1 维,对应 2 个类别)上进行 softmax 计算,最终每个样本会得到两个概率值(如[0.05, 0.95]),分别对应两个类别的预测概率。
-
返回结果:
return out- 返回形状为
[batch_size, 2]的分类概率张量,后续可通过「取概率最大值对应的索引」得到最终的分类结果(如torch.argmax(out, dim=1))。
- 返回形状为
补充:关键细节与后续扩展
- 为什么选择
<CLS>token?- 除了
<CLS>token,也可以取所有 token 特征的平均值作为句子特征,但 BERT 设计时<CLS>token 专门用于聚合整句语义,在分类任务上效果更稳定、更常用。
- 除了
- 后续需要补充的步骤
- 这段代码仅搭建了模型结构,还需要完成:数据加载(结合之前的
Mydataset)、文本编码(用tokenizer处理原始文本生成三个输入参数)、定义损失函数(如CrossEntropyLoss)、优化器(如AdamW)、训练循环、验证与推理,才能完成完整的文本分类任务。
- 这段代码仅搭建了模型结构,还需要完成:数据加载(结合之前的
- 多分类扩展
- 若要实现 3 分类、10 分类,只需修改全连接层的输出维度:
self.fc = torch.nn.Linear(768, 类别数),softmax(dim=1)会自动适配新的类别数,将得分转换为对应概率分布。
- 若要实现 3 分类、10 分类,只需修改全连接层的输出维度:
总结
- 核心思路:冻结 BERT 主干(提取 768 维语义特征)+ 单层全连接层(二分类映射)+ softmax(概率转换),搭建轻量化文本分类模型。
- 关键步骤:文本转 BERT 输入格式 → 冻结 BERT 提取
<CLS>特征 → 全连接层分类 → softmax 转换为概率。 - 核心价值:为中文文本二分类任务提供了基础模型结构,可直接用于后续的训练和推理,且易于扩展到多分类任务。
模块核心价值
该模块实现了 "预训练模型复用 + 下游任务轻量化" 的网络构建思路,既利用了 BERT 的强大语义提取能力,保证了分类任务的精度,又通过冻结预训练层降低了训练成本,让模型在普通硬件设备上也能高效训练,非常适合入门者学习与小规模 NLP 任务落地。
四、模块三:训练主程序(trainer.py)------ 端到端训练流程落地
如果说数据加载与模型构建模块是 "基石",那么训练主程序模块就是 "粘合剂"。该模块整合了前两个模块的成果,配置了各类训练组件,实现了 "前向计算→损失计算→反向传播→参数更新" 的训练闭环,同时包含训练监控与模型保存功能,落地完整的训练流程。
完整代码
# 导入PyTorch核心库,用于张量计算、模型训练、设备管理等
import torch
# 导入自定义的数据集类Mydataset(用于加载本地ChnSentiCorp数据集)
from MyData import Mydataset
# 导入PyTorch的数据加载器DataLoader,用于批量加载、打乱数据集
from torch.utils.data import DataLoader
# 导入自定义的下游分类模型Model(基于BERT的二分类模型)
from net import Model
# 导入BERT分词器,用于文本编码处理
from transformers import BertTokenizer
# 导入AdamW优化器(针对Transformer模型优化的Adam变体,常用在预训练模型微调)
from torch.optim import AdamW
# 定义模型训练/推理使用的设备(优先使用NVIDIA CUDA GPU,无GPU则使用CPU)
# 以下是注释掉的AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 打印当前使用的计算设备,方便确认是否成功启用GPU
print(f"使用设备: {DEVICE}")
# 定义模型训练的总轮次(Epoch):完整遍历训练集的次数
EPOCH = 30
# 定义本地中文BERT预训练模型的存储路径
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 从本地路径加载BERT分词器,与预训练模型配套使用,用于文本编码
token = BertTokenizer.from_pretrained(model_name)
# 自定义数据批处理函数collate_fn:用于DataLoader中处理批量样本,完成文本编码和格式转换
# 为什么需要collate_fn?
# 1. 神经网络要求输入尺寸固定,但原始文本长度各不相同,需要统一填充到相同长度
# 2. 批量处理可利用GPU并行计算,大幅提升训练效率,避免单样本处理的低效问题
# 3. 集中管理所有数据预处理逻辑(分词、填充、张量转换),让代码更整洁,减少重复
def collate_fn(data):
# 1. 从批量样本中分离文本内容和对应的标签(data是Mydataset返回的(text, label)元组列表)
sentes = [i[0] for i in data] # 提取所有样本的文本,组成列表:["文本1", "文本2", ..., "文本N"]
label = [i[1] for i in data] # 提取所有样本的标签,组成列表:[标签1, 标签2, ..., 标签N](0或1,二分类)
# 2. 使用BERT分词器对批量文本进行编码处理,转换为模型可识别的输入格式
data = token.batch_encode_plus(
batch_text_or_text_pairs=sentes, # 待编码的批量文本列表
truncation=True, # 开启截断:若文本长度超过max_length,截断到max_length
padding="max_length", # 开启填充:将所有文本填充到max_length指定的长度
max_length=350, # 文本的最大序列长度(统一输入尺寸,根据数据集调整)
return_tensors="pt", # 返回PyTorch的张量(Tensor)格式,方便后续模型计算
return_length=True, # 返回每个文本编码后的实际长度(可选,此处未使用,保留用于调试)
)
# 3. 从编码结果中提取BERT模型需要的三个核心输入张量
input_ids = data["input_ids"] # 文本分词后的数字编码张量,形状:[batch_size, max_length]
attention_mask = data["attention_mask"]# 注意力掩码张量:标记有效token(1)和填充token(0),形状:[batch_size, max_length]
token_type_ids = data["token_type_ids"]# 句子对分隔张量:单句分类任务中全为0,形状:[batch_size, max_length]
# 4. 将标签列表转换为PyTorch的长整型张量(LongTensor),适配CrossEntropyLoss损失函数
labels = torch.LongTensor(label)
# 5. 返回处理后的批量输入数据和标签,供模型训练使用
return input_ids, attention_mask, token_type_ids, labels
# 6. 创建训练集数据集实例:加载Mydataset中的"train"划分数据
train_dataset = Mydataset("train")
# 7. 创建训练集数据加载器(DataLoader):用于批量迭代训练数据
train_laoder = DataLoader(
dataset=train_dataset, # 传入已创建的训练集数据集
batch_size=32, # 批次大小:每次迭代处理32个样本(根据GPU显存调整,显存不足可减小)
shuffle=True, # 开启数据打乱:每个Epoch前打乱训练集顺序,避免模型过拟合到数据顺序
drop_last=True, # 删除最后一个不完整的批次:若数据集总数无法被batch_size整除,丢弃最后一个不足32个样本的批次
collate_fn=collate_fn # 指定自定义的批处理函数,处理批量样本的编码和格式转换
)
# 主程序入口:仅在直接运行该脚本时执行训练逻辑(被其他脚本导入时不执行)
if __name__ == '__main__':
# 打印设备信息,确认训练设备
print(DEVICE)
# 1. 实例化自定义的下游分类模型(基于BERT的二分类模型)
model = Model()
print(f"\n3. 移动模型到 {DEVICE}...")
# 2. 关键步骤:将模型移至指定计算设备(GPU/CPU),后续模型计算均在该设备上进行
# 避免「模型参数与输入张量不在同一设备」的报错
model = model.to(DEVICE)
print(" 模型移动成功")
# 3. 初始化AdamW优化器:用于更新模型参数(仅更新自定义分类头的参数,BERT主干已冻结)
# model.parameters():获取模型中可训练的参数(此处仅全连接层fc的参数)
# lr=5e-4:学习率(控制参数更新的步长,Transformer模型微调常用5e-5~1e-4)
optimizer = AdamW(model.parameters(), lr=5e-4)
print(" 优化器创建成功")
# 4. 定义损失函数:交叉熵损失(CrossEntropyLoss),适用于分类任务
# 自动整合了log_softmax和nll_loss,适合直接处理模型输出的概率分布和标签
loss_func = torch.nn.CrossEntropyLoss()
# 5. 开启模型训练模式:启用Dropout、BatchNorm等训练相关层(此处模型无此类层,仅为规范操作)
model.train()
# 6. 开始训练循环:遍历所有Epoch
for epoch in range(EPOCH):
# 7. 遍历DataLoader中的批量数据:enumerate返回批次索引i和批量数据
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_laoder):
# 8. 关键步骤:将批量输入数据和标签移至指定计算设备(与模型在同一设备)
input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
# 9. 模型前向传播:传入三个核心输入张量,得到分类概率输出
out = model(input_ids, attention_mask, token_type_ids)
print(f" 前向计算成功,输出形状: {out.shape}") # 输出形状应为 [32, 2](batch_size=32,二分类)
# 10. 重新实例化损失函数(注:此处可优化,无需在循环内重复实例化,外部定义一次即可)
loss_func = torch.nn.CrossEntropyLoss()
# 11. 计算批次损失:传入模型输出out和真实标签labels
loss = loss_func(out, labels)
print(f" 损失计算成功: {loss.item():.4f}") # loss.item()提取张量中的标量值,保留4位小数
# 12. 反向传播前:清零优化器的梯度(PyTorch梯度会累积,需手动清零)
optimizer.zero_grad()
# 13. 反向传播:计算损失对模型可训练参数的梯度
loss.backward()
print(" 反向传播成功")
# 14. 调试:检查是否存在有效梯度(确认模型参数是否在更新)
grad_found = False
for name, param in model.named_parameters():
if param.grad is not None:
grad_found = True
grad_norm = param.grad.norm().item() # 计算梯度的范数,用于观察梯度大小
print(f" {name} 梯度范数: {grad_norm:.6f}")
break # 仅打印第一个有梯度的参数,避免输出过多
if not grad_found:
print(" 警告: 没有找到梯度")
# 15. 优化器更新参数:根据计算出的梯度,更新模型可训练参数(此处为全连接层fc的参数)
optimizer.step()
print(" 参数更新成功")
# 16. 每5个批次打印一次训练信息(Epoch、批次索引、损失值、准确率),方便监控训练进度
if i % 5 == 0:
# 提取预测结果:取概率分布中最大值对应的索引(0或1),作为模型的分类预测
out = out.argmax(dim=1)
# 计算批次准确率:预测正确的样本数 / 总样本数
acc = (out == labels).sum().item() / len(labels)
# 打印训练监控信息
print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
# 17. 每个Epoch训练完成后,保存模型参数(state_dict),方便后续加载继续训练或推理
# state_dict:保存模型中可训练参数的字典,是PyTorch保存模型的标准方式
torch.save(model.state_dict(), f"params/{epoch}bert.pt")
print(f"Epoch {epoch} 训练完成,参数保存成功!")
代码详细解释
这段代码是端到端的 BERT 文本二分类模型训练代码 ,承接之前搭建的Mydataset(自定义数据集)和Model(自定义分类模型),实现了从「批量数据预处理」到「模型参数保存」的完整训练流程,核心特点是:
- 采用「批量处理」提升训练效率,利用 GPU 并行计算优势。
- 仅训练自定义的全连接层分类头,冻结 BERT 主干模型参数。
- 包含完整的训练闭环:前向计算→损失计算→反向传播→参数更新→进度监控→模型保存。
- 最终输出训练过程中的损失和准确率,并保存每一轮的模型参数,方便后续推理和继续训练。
第一步:导入依赖库 / 模块
import torch
from MyData import Mydataset
from torch.utils.data import DataLoader
from net import Model
from transformers import BertTokenizer
from torch.optim import AdamW
每个导入项的核心作用(呼应之前的代码,形成闭环):
torch:PyTorch 核心库,负责张量计算、模型训练、设备管理、损失函数定义等。Mydataset:你自定义的数据集类,用于加载本地ChnSentiCorp训练集数据,返回(text, label)元组。DataLoader:PyTorch 的批量数据加载器,用于将Mydataset的样本打包成批次,支持打乱、并行加载等功能。Model:你自定义的下游分类模型,基于 BERT 提取特征 + 单层全连接层实现二分类。BertTokenizer:BERT 配套分词器,用于将原始文本转换为模型可识别的张量格式。AdamW:针对 Transformer 模型优化的 Adam 变体优化器(修正了权重衰减逻辑),是预训练模型微调的首选优化器。
第二步:配置计算设备 & 定义基础参数
# 定义训练/推理使用的设备(优先GPU,无GPU则CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {DEVICE}")
# 定义训练总轮次(Epoch):完整遍历训练集的次数
EPOCH = 30
# 本地中文BERT预训练模型路径
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 加载与BERT配套的分词器
token = BertTokenizer.from_pretrained(model_name)
-
设备配置
DEVICE:- 核心逻辑:优先使用 NVIDIA CUDA GPU(需安装 CUDA/cuDNN),无 GPU 则降级为 CPU。
- 关键作用:后续模型和所有输入数据都要移至该设备,避免「模型参数与张量不在同一设备」的报错(PyTorch 中,张量和模型必须在同一设备才能进行计算)。
- 打印设备信息:方便你确认是否成功启用 GPU(若显示
cuda:0则表示 GPU 可用,显示cpu则表示使用 CPU)。
-
超参数与分词器:
EPOCH=30:训练轮次,即完整遍历训练集 30 次。轮次过多可能过拟合,过少可能欠拟合,后续可根据验证集效果调整。token = BertTokenizer.from_pretrained(model_name):加载与 BERT 预训练模型配套的分词器,保证分词规则、词汇表与 BERT 一致,否则会出现输入不匹配错误。
第三步:核心自定义函数collate_fn(批量数据处理)
这是训练代码的关键辅助函数 ,用于给DataLoader提供批量数据的处理逻辑,解决「文本长度不一致」和「批量编码效率低」的问题。
def collate_fn(data):
# 1. 分离文本和标签(data是Mydataset返回的(text, label)元组列表)
sentes = [i[0] for i in data]# ["文本1", "文本2", ...]
label = [i[1] for i in data] # [标签1, 标签2, ...]
# 2. 批量编码文本:转换为BERT所需的张量格式
data = token.batch_encode_plus(
batch_text_or_text_pairs=sentes,
truncation=True,
padding="max_length",
max_length=350,
return_tensors="pt", # 返回 PyTorch tensor
return_length=True,
)
# 3. 提取BERT核心输入张量
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
# 4. 标签转换为长整型张量
labels = torch.LongTensor(label)
return input_ids,attention_mask,token_type_ids,labels
逐步骤解释:
-
分离文本和标签:
data是DataLoader传入的「批量样本列表」,每个元素是Mydataset.__getitem__()返回的(text, label)元组。- 用列表推导式分离出文本列表
sentes和标签列表label,方便后续分别处理。
-
批量编码
token.batch_encode_plus():这是BertTokenizer的批量编码方法,比单样本编码效率高 10 倍以上,核心参数解释:batch_text_or_text_pairs=sentes:传入待编码的批量文本列表。truncation=True:开启截断,若文本分词后长度超过max_length=350,则截断到 350 个 token。padding="max_length":开启「填充到最大长度」,将所有文本填充到350个 token,保证批量内所有文本输入尺寸一致(神经网络要求固定尺寸输入)。max_length=350:文本的最大序列长度,可根据你的数据集调整(显存不足可减小,如 128、256)。return_tensors="pt":返回 PyTorch 张量格式,无需手动转换,可直接用于模型计算。return_length=True:返回每个文本的实际编码长度(可选,此处未使用,仅用于调试)。
-
提取 BERT 核心输入张量 :编码后返回的
data是一个字典,包含 BERT 模型必需的 3 个张量,形状均为[batch_size, 350](batch_size=32):input_ids:文本分词后的数字编码(每个 token 对应一个唯一数字)。attention_mask:注意力掩码,1表示有效 token,0表示填充 token,让 BERT 忽略填充部分。token_type_ids:句子对分隔符,单句分类任务中全为0(仅在问答、句对匹配等任务中有用)。
-
标签转换为
torch.LongTensor:- 将标签列表转换为长整型张量,适配
CrossEntropyLoss损失函数的输入要求 (该损失函数要求标签是LongTensor类型,不能是普通列表或浮点型张量)。 - 最终返回 4 个张量,供模型训练使用。
- 将标签列表转换为长整型张量,适配
为什么必须用collate_fn?
- 解决「文本长度不一致」的问题,统一输入尺寸。
- 批量处理利用 GPU 并行计算,提升训练效率。
- 集中管理数据预处理逻辑,让代码更整洁,避免重复编码。
- 若不使用
collate_fn,DataLoader会直接返回原始元组列表,无法直接输入模型。
第四步:创建数据集和DataLoader(批量数据迭代器)
#创建数据集
train_dataset = Mydataset("train")
#创建DataLoader
train_laoder = DataLoader(
dataset=train_dataset,
batch_size=32,
shuffle=True,
drop_last=True,
collate_fn=collate_fn
)
train_dataset = Mydataset("train"):实例化自定义数据集,加载ChnSentiCorp的训练集数据。DataLoader:创建批量数据迭代器,核心参数解释:dataset=train_dataset:传入已加载的训练集。batch_size=32:批次大小,每次迭代处理 32 个样本(根据 GPU 显存调整,显存不足可减小为 16、8)。shuffle=True:每个 Epoch 前打乱训练集顺序,避免模型过拟合到数据顺序(若不打乱,模型可能记住样本顺序,而非学习文本语义)。drop_last=True:删除最后一个不完整的批次(若训练集总数无法被 32 整除,丢弃最后一个不足 32 个样本的批次),保证每个批次的样本数一致,避免训练报错。collate_fn=collate_fn:指定自定义的批量数据处理函数,DataLoader每次获取批量样本后,会调用该函数进行处理。
第五步:主程序训练逻辑(核心闭环)
if __name__ == '__main__':内的代码是完整的训练闭环,也是这段代码的核心,实现了从模型初始化到参数保存的全流程。
if __name__ == '__main__':
# 1. 实例化模型并移至指定设备
model = Model()
model = model.to(DEVICE)
print(f"模型已移动到 {DEVICE}")
# 2. 初始化优化器(AdamW)
optimizer = AdamW(model.parameters(),lr=5e-4)
print(" 优化器创建成功")
# 3. 定义损失函数(交叉熵损失)
loss_func = torch.nn.CrossEntropyLoss()
# 4. 开启模型训练模式
model.train()
# 5. 多轮训练循环(Epoch)
for epoch in range(EPOCH):
# 6. 批量迭代训练数据(Batch)
for i,(input_ids,attention_mask,token_type_ids,labels) in enumerate(train_laoder):
# 7. 数据移至指定设备(与模型同一设备)
input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),labels.to(DEVICE)
# 8. 前向传播:获取模型预测结果
out = model(input_ids, attention_mask, token_type_ids)
print(f" 前向计算成功,输出形状: {out.shape}")
# 9. 计算批次损失
loss = loss_func(out, labels)
print(f" 损失计算成功: {loss.item():.4f}")
# 10. 反向传播三部曲:清零梯度→计算梯度→更新参数
optimizer.zero_grad() # 清零梯度(PyTorch梯度会累积)
loss.backward() # 反向传播:计算损失对可训练参数的梯度
optimizer.step() # 优化器更新参数:根据梯度调整分类头权重
# 11. 调试:检查是否存在有效梯度
grad_found = False
for name, param in model.named_parameters():
if param.grad is not None:
grad_found = True
grad_norm = param.grad.norm().item()
print(f" {name} 梯度范数: {grad_norm:.6f}")
break
if not grad_found:
print(" 警告: 没有找到梯度")
# 12. 每5个批次打印训练进度(损失+准确率)
if i%5 ==0:
out = out.argmax(dim=1) # 提取概率最大值对应的分类索引(0/1)
acc = (out == labels).sum().item()/len(labels) # 计算批次准确率
print(epoch,i,loss.item(),acc)
# 13. 每个Epoch结束后保存模型参数
torch.save(model.state_dict(),f"params/{epoch}bert.pt")
print(epoch,"参数保存成功!")
逐步骤解释(训练闭环核心):
-
模型实例化与设备迁移:
model = Model():实例化你自定义的二分类模型(包含 BERT 主干 + 全连接层分类头)。model = model.to(DEVICE):将模型所有参数移至指定设备(GPU/CPU),这是必须步骤,否则后续模型计算会报错。
-
初始化优化器
AdamW:optimizer = AdamW(model.parameters(), lr=5e-4):model.parameters():获取模型中所有「可训练参数」(此处仅全连接层self.fc的参数,因为 BERT 主干已被torch.no_grad()冻结)。lr=5e-4:学习率,控制参数更新的步长(Transformer 模型微调常用5e-5~1e-4,步长过大会导致训练不稳定,过小会导致训练缓慢)。- 为什么用
AdamW?:相比普通Adam,AdamW修正了权重衰减的计算逻辑,更适合 Transformer 模型,能有效防止过拟合。
-
定义损失函数
CrossEntropyLoss:- 适用于分类任务的经典损失函数,自动整合了
log_softmax(概率归一化)和nll_loss(负对数似然损失),无需手动对模型输出做额外处理。 - 输入要求:模型输出
out(形状[32,2])和真实标签labels(形状[32],LongTensor类型)。
- 适用于分类任务的经典损失函数,自动整合了
-
开启训练模式
model.train():- 启用模型中的「训练相关层」(如
Dropout、BatchNorm),虽然你的模型中没有这些层,但这是 PyTorch 训练的规范操作,为后续扩展模型(添加 Dropout 防止过拟合)做准备。 - 对应评估模式
model.eval():验证 / 推理时使用,关闭Dropout等层,保证结果稳定。
- 启用模型中的「训练相关层」(如
-
外层
Epoch循环(多轮训练):- 遍历
30个 Epoch,每个 Epoch 完整遍历一次训练集,目的是让模型充分学习数据特征。 - 每个 Epoch 结束后保存模型参数,方便后续选择「损失最低、准确率最高」的模型进行推理。
- 遍历
-
内层
Batch循环(批量迭代):enumerate(train_laoder):遍历DataLoader,返回批次索引i和处理后的批量数据(4 个张量)。- 数据设备迁移:将
input_ids、attention_mask、token_type_ids、labels全部移至DEVICE,保证与模型在同一设备,这是避免报错的关键步骤。
-
前向传播(模型预测):
out = model(input_ids, attention_mask, token_type_ids):调用模型的forward()方法,得到分类概率输出。- 输出形状
out.shape为[32, 2](batch_size=32,二分类),每个元素是对应类别的概率(总和为 1)。
-
损失计算:
loss = loss_func(out, labels):计算该批次的平均损失,损失值越小表示模型预测结果越接近真实标签。loss.item():提取张量中的标量值(去掉张量封装),方便打印和后续分析(避免打印多余的张量信息)。
-
反向传播三部曲(核心:参数更新) :这是 PyTorch 模型训练的核心流程,实现模型参数的自动优化:
optimizer.zero_grad():清零优化器的梯度缓存(PyTorch 中梯度会累积,若不清零,会导致后续批次的梯度叠加,影响参数更新)。loss.backward():反向传播,从损失值出发,计算损失对所有可训练参数(此处为全连接层参数)的梯度(梯度表示参数更新的方向和幅度)。optimizer.step():优化器更新参数 ,根据计算出的梯度,按照AdamW的优化逻辑,调整可训练参数的取值,让损失值逐渐降低。
-
梯度调试(可选):
- 遍历模型参数,检查是否存在有效梯度,目的是确认模型参数是否在正常更新。
- 若显示「没有找到梯度」,可能的原因:BERT 主干冻结且分类头参数不可训练、输入输出不匹配、学习率为 0 等,需要排查问题。
-
训练进度监控(每 5 个批次打印):
out.argmax(dim=1):取概率分布中最大值对应的索引(0或1),作为模型的最终分类预测结果(如[0.05, 0.95]对应索引1,即正面情感)。acc = (out == labels).sum().item()/len(labels):计算批次准确率,即预测正确的样本数占总样本数的比例,准确率越高表示模型效果越好。- 每 5 个批次打印一次,避免输出过多信息,方便监控训练趋势(损失是否逐渐降低,准确率是否逐渐提升)。
-
模型参数保存:
torch.save(model.state_dict(), f"params/{epoch}bert.pt"):保存模型的「状态字典」(state_dict),即模型中所有可训练参数的键值对(此处仅全连接层参数)。state_dict是 PyTorch 保存模型的标准方式,后续可通过model.load_state_dict(torch.load("xxx.pt"))加载模型参数,进行推理或继续训练。- 注意:需要提前创建
params文件夹,否则会报「文件不存在」错误。
补充:关键细节与优化点
- 代码小优化 :
loss_func = torch.nn.CrossEntropyLoss()无需在Batch循环内重复实例化,在循环外定义一次即可,减少冗余。 - 显存不足解决 :减小
batch_size(如从 32 改为 16)、减小max_length(如从 350 改为 256)、使用 CPU 训练(速度较慢)。 - 过拟合预防:后续可添加验证集,监控验证集损失和准确率,若验证集损失上升,说明模型过拟合,可提前停止训练(早停法)。
- 梯度裁剪 :若梯度范数过大(如超过 10),可添加
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0),防止梯度爆炸。
总结
- 核心流程:配置设备→批量数据处理→初始化模型 / 优化器 / 损失函数→多轮训练(前向→损失→反向→更新)→监控进度→保存参数,形成完整的训练闭环。
- 关键细节:
collate_fn统一输入尺寸,to(DEVICE)保证设备一致,反向传播三部曲实现参数更新,state_dict保存模型参数。 - 核心价值:实现了基于 BERT 的文本二分类模型的可落地训练,输出可监控的训练指标和可复用的模型参数,为后续推理和优化提供基础。
模块核心价值
该模块实现了端到端的模型训练流程,涵盖了数据预处理、组件初始化、参数迭代、监控保存等所有核心功能,形成了完整的训练闭环。代码逻辑清晰,注释详尽,能够让入门者直观理解模型训练的底层逻辑,同时具备良好的可修改性,可根据实际任务需求调整超参数、优化器、损失函数等组件。
五、训练核心原理总结与优化方向
1. 核心原理回顾
本次模型训练的核心逻辑是 "复用预训练知识,轻量化迭代优化":
- 数据层:通过标准化数据集接口,为训练提供统一、高质量的样本输入。
- 模型层:复用 BERT 预训练模型的中文语义提取能力,仅训练下游全连接分类层,降低训练成本。
- 训练层:通过 "前向计算→损失计算→反向传播→参数更新" 的闭环,实现模型参数的迭代优化,让模型逐步学习文本情感特征。
2. 后续优化方向(基于现有代码,不额外新增代码)
- 解冻 BERT 部分层微调:当前代码冻结了全部 BERT 层,可尝试解冻最后 2~3 层 BERT 参数参与训练,让模型学习到更贴合情感分类任务的语义特征,提升分类精度。
- 增加早停策略:在训练中监控验证集损失,若验证集损失连续多轮上升,停止训练,避免模型过拟合。
- 超参数调优:调整
BATCH_SIZE、LEARNING_RATE、训练轮次等超参数,或尝试不同的优化器(如 Adam、SGD),寻找最优参数组合。 - 数据增强:对训练集文本进行同义词替换、语句重组等增强操作,增加数据多样性,提升模型泛化能力。
六、全文总结
本文基于一套模块化的完整代码,详细解析了基于 BERT 的中文文本情感分类模型训练全流程,从数据加载到模型构建,再到端到端训练闭环,每个模块都提供了完整代码与详细解释,帮助读者不仅能看懂代码,更能理解背后的技术逻辑。
这套训练框架兼具轻量性与实用性,既适合 NLP 入门者学习模型训练的核心流程,也适合小规模中文情感分类任务的落地。掌握这套框架的核心逻辑,将为后续更复杂的 NLP 任务(如文本摘要、命名实体识别)学习与落地奠定坚实的基础。