源代码如下:bert_test
# 导入PyTorch核心库,用于张量计算、模型加载、设备管理等
import torch
# 导入自定义的数据集类Mydataset(用于加载本地ChnSentiCorp测试集数据)
from MyData import Mydataset
# 导入PyTorch的DataLoader,用于批量加载测试集数据
from torch.utils.data import DataLoader
# 导入自定义的下游分类模型Model(基于BERT的二分类模型)
from net import Model
# 导入BERT分词器,用于测试集文本的编码处理
from transformers import BertTokenizer
# 定义模型测试/推理使用的设备(优先使用NVIDIA CUDA GPU,无GPU则使用CPU)
# 以下是注释掉的AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 打印当前使用的计算设备,确认是否成功启用GPU
print(f"使用设备: {DEVICE}")
# 定义本地中文BERT预训练模型的存储路径(需与训练时使用的模型一致)
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 从本地路径加载BERT分词器,与预训练模型配套使用,保证编码规则一致
token = BertTokenizer.from_pretrained(model_name)
# 自定义数据批处理函数collate_fn:用于DataLoader中处理测试集批量样本
# 作用:将测试集的原始文本转换为BERT模型可识别的张量格式,统一输入尺寸
def collate_fn(data):
# 1. 从批量样本中分离文本内容和对应的标签(data是Mydataset返回的(text, label)元组列表)
sentes = [i[0] for i in data] # 提取所有测试样本的文本,组成列表:["文本1", "文本2", ...]
label = [i[1] for i in data] # 提取所有测试样本的真实标签,组成列表:[标签1, 标签2, ...]
# 2. 使用BERT分词器对批量测试文本进行编码处理
data = token.batch_encode_plus(
batch_text_or_text_pairs=sentes, # 待编码的批量测试文本列表
truncation=True, # 开启截断:文本长度超过max_length时截断
padding="max_length", # 开启填充:将所有文本填充到max_length指定长度
max_length=500, # 测试集文本的最大序列长度(需与训练时保持一致)
return_tensors="pt", # 返回PyTorch张量格式,方便模型计算
return_length=True # 返回每个文本的实际编码长度(可选,此处未使用)
)
# 3. 提取BERT模型需要的三个核心输入张量
input_ids = data["input_ids"] # 文本分词后的数字编码张量,形状:[batch_size, 500]
attention_mask = data["attention_mask"]# 注意力掩码张量:标记有效token(1)和填充token(0)
token_type_ids = data["token_type_ids"]# 句子对分隔张量:单句分类任务中全为0
# 4. 将测试集标签转换为PyTorch长整型张量(用于后续准确率计算)
labels = torch.LongTensor(label)
# 5. 返回处理后的批量输入数据和真实标签
return input_ids, attention_mask, token_type_ids, labels
# 6. 创建测试集数据集实例:加载Mydataset中的"test"划分数据
test_dataset = Mydataset("test")
# 7. 创建测试集数据加载器(DataLoader):批量迭代测试数据
test_laoder = DataLoader(
dataset=test_dataset, # 传入已创建的测试集数据集
batch_size=32, # 测试批次大小(与训练时一致,可根据GPU显存调整)
shuffle=True, # 测试时打乱数据(不影响最终准确率,仅为遍历顺序随机)
drop_last=True, # 删除最后一个不完整的批次,保证批次样本数一致
collate_fn=collate_fn # 指定自定义的批处理函数,处理测试集批量样本
)
# 主程序入口:仅在直接运行该脚本时执行测试逻辑
if __name__ == '__main__':
# 初始化准确率计算的累计变量
acc = 0 # 累计预测正确的样本数
total = 0 # 累计参与测试的总样本数
# 打印设备信息,确认测试设备
print(DEVICE)
# 1. 实例化自定义分类模型,并移至指定计算设备(GPU/CPU)
model = Model().to(DEVICE)
# 2. 加载训练好的模型权重文件(关键:使用训练完成的参数进行测试)
# "params/1bert.pt":第1个Epoch训练后保存的模型参数文件(可替换为效果更好的Epoch文件)
model.load_state_dict(torch.load("params/1bert.pt"))
# 3. 关键:将模型切换到推理/评估模式
# 作用:关闭训练相关的层(如Dropout、BatchNorm),保证推理结果稳定;禁用梯度计算,节省显存
model.eval()
# 4. 批量遍历测试集数据,计算整体准确率
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(test_laoder):
# 5. 将测试批量数据移至指定设备(与模型在同一设备,避免计算报错)
input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
# 6. 模型推理(前向传播):获取测试样本的分类概率输出
# 注:测试时无需计算梯度,可添加with torch.no_grad()进一步优化(见补充说明)
out = model(input_ids, attention_mask, token_type_ids)
# 7. 提取预测结果:取概率分布中最大值对应的索引(0或1),作为最终分类预测
out = out.argmax(dim=1)
# 8. 累计统计:更新正确样本数和总样本数
acc += (out == labels).sum().item() # 统计当前批次预测正确的样本数,累加到acc
total += len(labels) # 统计当前批次的总样本数,累加到total
# 打印当前遍历的批次索引,监控测试进度
print(f"已完成测试批次: {i}")
# 9. 计算并打印测试集的整体准确率
# 准确率 = 预测正确的总样本数 / 参与测试的总样本数
print(f"\n测试集整体准确率: {acc / total:.8f}")
模型训练完成后,测试阶段 是验证模型泛化能力的核心环节 ------ 只有在未见过的测试集上表现稳定,才能说明模型真正学到了文本的语义特征,而非单纯 "记住" 训练数据。本文聚焦基于bert-base-chinese的中文文本二分类模型测试全流程,从测试集加载、批量数据处理,到模型推理、准确率计算,拆解每一个核心步骤,并附上可直接运行的完整代码。
测试阶段的核心目标与准备工作
核心目标
- 验证训练后的模型在独立测试集上的分类准确率,评估模型泛化能力;
- 确保模型推理流程的正确性(输入格式、设备匹配、结果解析无错误);
- 为模型优化(如调整训练轮次、批次大小、文本长度)提供数据支撑。
前置准备
在开始测试前,需确保以下资源已就绪:
- 环境依赖 :与训练阶段一致(
torch、transformers、datasets),避免版本不一致导致的兼容性问题; - 预训练模型 :本地存放的
bert-base-chinese模型(路径与训练时一致); - 训练好的权重文件 :如
params/1bert.pt(训练阶段保存的模型参数); - 测试数据集 :ChnSentiCorp 的
test划分(与训练 / 验证集独立); - 自定义模块 :
MyData.py(数据集封装)、net.py(BERT 分类模型)。
整体功能概述
这段代码是训练好的 BERT 文本二分类模型的完整测试流程,核心目标是:
- 加载独立的测试集(ChnSentiCorp 的
test划分),验证模型的泛化能力(对未见过数据的预测能力); - 按照与训练一致的规则处理测试数据,保证输入格式匹配;
- 加载训练阶段保存的模型权重,通过批量推理计算测试集的整体分类准确率;
- 输出准确率指标,评估模型的最终性能。
整个流程可概括为:设备配置 → 数据预处理 → 测试集加载 → 模型加载 → 批量推理 → 准确率统计。
第一步:导入核心依赖库 / 模块
# 导入PyTorch核心库,用于张量计算、模型加载、设备管理等
import torch
# 导入自定义的数据集类Mydataset(用于加载本地ChnSentiCorp测试集数据)
from MyData import Mydataset
# 导入PyTorch的DataLoader,用于批量加载测试集数据
from torch.utils.data import DataLoader
# 导入自定义的下游分类模型Model(基于BERT的二分类模型)
from net import Model
# 导入BERT分词器,用于测试集文本的编码处理
from transformers import BertTokenizer
每个导入项的核心作用(与训练代码呼应):
torch:PyTorch 核心库,负责张量计算、模型权重加载、设备管理(如to(DEVICE))、准确率统计等核心操作;Mydataset:你自定义的数据集类,已封装了本地 ChnSentiCorp 数据集的加载逻辑,这里专门加载test划分的数据;DataLoader:PyTorch 的批量数据迭代器,解决 "单样本处理效率低" 的问题,同时配合collate_fn完成批量数据标准化;Model:你自定义的二分类模型(BERT 主干 + 全连接分类头),与训练阶段的模型结构完全一致,保证权重加载后可正常推理;BertTokenizer:BERT 配套的分词器,需与训练时使用的bert-base-chinese模型匹配,保证文本编码规则(分词、词汇表、填充 / 截断)一致。
第二步:配置计算设备
# 定义模型测试/推理使用的设备(优先使用NVIDIA CUDA GPU,无GPU则使用CPU)
# 以下是注释掉的AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 打印当前使用的计算设备,确认是否成功启用GPU
print(f"使用设备: {DEVICE}")
核心逻辑与设计目的:
- 设备优先级:优先使用 NVIDIA GPU(CUDA),因为 GPU 的并行计算能力可大幅提升测试推理速度;若无 GPU,自动降级为 CPU(速度较慢,但能保证代码运行);
- 打印设备信息 :方便你确认设备是否配置正确(如显示
cuda:0表示 GPU 可用,cpu表示仅能用 CPU),避免后续 "模型与数据不在同一设备" 的报错; - 与训练一致 :测试设备需与训练设备逻辑一致,否则可能出现权重加载异常(如训练用 GPU、测试用 CPU 时,需注意权重加载的
map_location参数,本文代码中未体现,因为设备判断逻辑统一)。
第三步:加载 BERT 分词器
# 定义本地中文BERT预训练模型的存储路径(需与训练时使用的模型一致)
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 从本地路径加载BERT分词器,与预训练模型配套使用,保证编码规则一致
token = BertTokenizer.from_pretrained(model_name)
关键注意事项:
model_name路径必须与训练阶段完全一致:若路径错误或加载了其他版本的分词器,会导致词汇表不匹配,编码后的input_ids无效,模型推理报错;- 分词器的核心作用:将测试集的原始中文文本转换为 BERT 模型能识别的数字编码(
input_ids),是 "文本→模型输入" 的桥梁。
第四步:自定义批量数据处理函数collate_fn
这是测试代码的核心辅助函数 ,解决 "文本长度不一致" 和 "批量输入格式标准化" 的问题,且必须与训练阶段的collate_fn逻辑完全一致。
# 自定义数据批处理函数collate_fn:用于DataLoader中处理测试集批量样本
# 作用:将测试集的原始文本转换为BERT模型可识别的张量格式,统一输入尺寸
def collate_fn(data):
# 1. 从批量样本中分离文本内容和对应的标签(data是Mydataset返回的(text, label)元组列表)
sentes = [i[0] for i in data] # 提取所有测试样本的文本,组成列表:["文本1", "文本2", ...]
label = [i[1] for i in data] # 提取所有测试样本的真实标签,组成列表:[标签1, 标签2, ...]
# 2. 使用BERT分词器对批量测试文本进行编码处理
data = token.batch_encode_plus(
batch_text_or_text_pairs=sentes, # 待编码的批量测试文本列表
truncation=True, # 开启截断:文本长度超过max_length时截断
padding="max_length", # 开启填充:将所有文本填充到max_length指定长度
max_length=500, # 测试集文本的最大序列长度(需与训练时保持一致)
return_tensors="pt", # 返回PyTorch张量格式,方便模型计算
return_length=True # 返回每个文本的实际编码长度(可选,此处未使用)
)
# 3. 提取BERT模型需要的三个核心输入张量
input_ids = data["input_ids"] # 文本分词后的数字编码张量,形状:[batch_size, 500]
attention_mask = data["attention_mask"]# 注意力掩码张量:标记有效token(1)和填充token(0)
token_type_ids = data["token_type_ids"]# 句子对分隔张量:单句分类任务中全为0
# 4. 将测试集标签转换为PyTorch长整型张量(用于后续准确率计算)
labels = torch.LongTensor(label)
# 5. 返回处理后的批量输入数据和真实标签
return input_ids, attention_mask, token_type_ids, labels
逐步骤解释:
-
分离文本和标签:
data是DataLoader传入的 "批量样本列表",每个元素是Mydataset.__getitem__()返回的(text, label)元组(测试集保留标签,用于计算准确率);- 用列表推导式分离出文本列表
sentes和标签列表label,方便后续分别处理。
-
批量编码
token.batch_encode_plus():这是分词器的核心批量处理方法,参数必须与训练阶段完全一致,否则模型输入维度不匹配:truncation=True:若测试文本分词后长度超过 500,截断到 500 个 token(避免输入过长导致显存溢出);padding="max_length":将所有文本填充到 500 个 token,保证批量内所有样本的输入尺寸一致(神经网络要求固定尺寸输入);max_length=500:必须与训练阶段的数值一致(若训练用 350、测试用 500,会导致输入维度不匹配,直接报错);return_tensors="pt":直接返回 PyTorch 张量,无需手动转换,提升效率。
-
提取 BERT 核心输入张量 :编码后返回的
data是字典,包含 3 个 BERT 必需的张量(形状均为[batch_size, 500]):input_ids:文本的数字编码(如 "好" 对应某个数字),是模型的核心输入;attention_mask:标记哪些 token 是有效文本(1)、哪些是填充(0),让 BERT 忽略填充部分;token_type_ids:单句分类任务中全为 0(仅在句对任务如问答中有用)。
-
标签转换为
LongTensor:- 测试集标签是整数(0/1),转换为 PyTorch 长整型张量,方便后续与模型预测结果(张量)对比,计算准确率。
第五步:加载测试集并创建DataLoader
# 6. 创建测试集数据集实例:加载Mydataset中的"test"划分数据
test_dataset = Mydataset("test")
# 7. 创建测试集数据加载器(DataLoader):批量迭代测试数据
test_laoder = DataLoader(
dataset=test_dataset, # 传入已创建的测试集数据集
batch_size=32, # 测试批次大小(与训练时一致,可根据GPU显存调整)
shuffle=True, # 测试时打乱数据(不影响最终准确率,仅为遍历顺序随机)
drop_last=True, # 删除最后一个不完整的批次,保证批次样本数一致
collate_fn=collate_fn # 指定自定义的批处理函数,处理测试集批量样本
)
核心参数解释:
test_dataset = Mydataset("test"):加载 ChnSentiCorp 的测试集(与训练集、验证集独立),保证测试结果能反映模型的泛化能力;batch_size=32:与训练阶段一致,兼顾推理效率和显存占用(显存不足可改为 16/8);shuffle=True:测试时打乱数据顺序,不影响最终准确率 (仅改变遍历顺序),若想按原顺序测试可改为False;drop_last=True:删除最后一个不足 32 个样本的批次,避免 "批次样本数不一致" 导致的张量维度报错;collate_fn=collate_fn:指定自定义的批量数据处理函数,DataLoader每次获取批量样本后,会自动调用该函数完成编码和格式转换。
第六步:主程序测试逻辑(核心闭环)
if __name__ == '__main__':内的代码是测试流程的核心,实现从 "模型加载" 到 "准确率计算" 的完整闭环。
if __name__ == '__main__':
# 初始化准确率计算的累计变量
acc = 0 # 累计预测正确的样本数
total = 0 # 累计参与测试的总样本数
# 打印设备信息,确认测试设备
print(DEVICE)
# 1. 实例化自定义分类模型,并移至指定计算设备(GPU/CPU)
model = Model().to(DEVICE)
# 2. 加载训练好的模型权重文件(关键:使用训练完成的参数进行测试)
# "params/1bert.pt":第1个Epoch训练后保存的模型参数文件(可替换为效果更好的Epoch文件)
model.load_state_dict(torch.load("params/1bert.pt"))
# 3. 关键:将模型切换到推理/评估模式
# 作用:关闭训练相关的层(如Dropout、BatchNorm),保证推理结果稳定;禁用梯度计算,节省显存
model.eval()
# 4. 批量遍历测试集数据,计算整体准确率
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(test_laoder):
# 5. 将测试批量数据移至指定设备(与模型在同一设备,避免计算报错)
input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
# 6. 模型推理(前向传播):获取测试样本的分类概率输出
# 注:测试时无需计算梯度,可添加with torch.no_grad()进一步优化(见补充说明)
out = model(input_ids, attention_mask, token_type_ids)
# 7. 提取预测结果:取概率分布中最大值对应的索引(0或1),作为最终分类预测
out = out.argmax(dim=1)
# 8. 累计统计:更新正确样本数和总样本数
acc += (out == labels).sum().item() # 统计当前批次预测正确的样本数,累加到acc
total += len(labels) # 统计当前批次的总样本数,累加到total
# 打印当前遍历的批次索引,监控测试进度
print(f"已完成测试批次: {i}")
# 9. 计算并打印测试集的整体准确率
# 准确率 = 预测正确的总样本数 / 参与测试的总样本数
print(f"\n测试集整体准确率: {acc / total:.8f}")
逐步骤解释(核心闭环):
-
初始化统计变量:
acc:累计所有批次中预测正确的样本数,初始为 0;total:累计所有参与测试的样本数,初始为 0;- 最终准确率 =
acc / total,这是评估模型性能的核心指标。
-
模型实例化与设备迁移:
model = Model().to(DEVICE):实例化自定义的二分类模型,并将模型所有参数移至指定设备(GPU/CPU);- 必须保证模型与后续输入数据在同一设备,否则会报 "张量与模型不在同一设备" 的错误。
-
加载训练好的权重:
model.load_state_dict(torch.load("params/1bert.pt")):这是测试的核心关键步骤 :torch.load("params/1bert.pt"):加载训练阶段保存的模型参数文件(state_dict,即模型可训练参数的键值对);model.load_state_dict(...):将加载的参数赋值给模型,替换模型的随机初始化参数;
- 若不加载权重,模型会使用随机参数推理,准确率约 50%(二分类随机猜测),失去测试意义。
-
切换到推理模式
model.eval():- 这是极易忽略但必须执行 的步骤,作用有二:
- 关闭训练相关的层(如 Dropout、BatchNorm):训练时 Dropout 会随机丢弃神经元防止过拟合,测试时需关闭,保证推理结果稳定;
- 隐含禁用梯度计算(部分优化):减少显存占用,提升推理速度。
- 这是极易忽略但必须执行 的步骤,作用有二:
-
批量遍历测试集:
enumerate(test_laoder):遍历DataLoader,返回批次索引i和处理后的批量数据(4 个张量);- 数据设备迁移:将
input_ids、attention_mask、token_type_ids、labels全部移至DEVICE,与模型保持一致。
-
模型推理(前向传播):
-
out = model(input_ids, attention_mask, token_type_ids):调用模型的forward方法,得到分类概率输出; -
out的形状为[32, 2](batch_size=32,二分类),每个元素是对应类别的概率(如[0.05, 0.95]表示 "负面" 概率 0.05,"正面" 概率 0.95); -
优化建议:添加
with torch.no_grad():上下文管理器(包裹整个遍历循环),禁用梯度计算,进一步节省显存、提升推理速度:model.eval() with torch.no_grad(): # 禁用梯度计算 for i, (input_ids, ...) in enumerate(test_laoder): # 后续推理逻辑不变
-
-
解析预测结果
out.argmax(dim=1):argmax(dim=1):在分类维度(第 1 维,对应 2 个类别)取概率最大值的索引,将概率分布转换为具体标签:- 如
out = [[0.05, 0.95], [0.9, 0.1]],out.argmax(dim=1)返回[1, 0],即第一个样本预测为 1(正面),第二个为 0(负面);
- 如
- 这是二分类结果解析的标准方式。
-
累计统计准确率:
(out == labels):逐元素对比预测标签和真实标签,返回布尔张量(True表示预测正确,False错误);.sum().item():统计布尔张量中True的数量(即当前批次正确样本数),并转换为 Python 标量;acc += ...:累加所有批次的正确样本数,total += len(labels)累加总样本数。
-
打印测试结果:
- 每批次打印索引,监控测试进度;
- 最终计算并打印整体准确率,保留 8 位小数,方便精确评估(如
0.89104729表示准确率 89.104729%)。
补充:常见问题与优化
- 权重加载报错 :若报 "key mismatch"(键不匹配),说明测试用的
Model类与训练时的结构不一致(如分类头维度修改),需保证模型结构完全一致; - 显存不足 :减小
batch_size(如 32→16)、减小max_length(如 500→256); - 准确率过低 :
- 检查权重文件是否正确(如加载了错误 Epoch 的权重);
- 检查
collate_fn的max_length是否与训练一致; - 检查测试集是否与训练集独立(避免数据泄露)。
总结
- 核心流程:设备配置 → 批量数据标准化 → 测试集加载 → 模型加载(权重)→ 推理模式切换 → 批量推理 → 准确率统计;
- 关键细节 :
- 测试数据的编码规则(
collate_fn)必须与训练一致; - 必须加载训练好的权重,否则测试无意义;
- 必须执行
model.eval(),保证推理结果稳定; - 模型与数据必须在同一设备,避免计算报错;
- 测试数据的编码规则(
- 核心价值:通过测试集准确率评估模型的泛化能力,验证模型是否真正学到了文本的语义特征,而非 "记住" 训练数据。
代码优化
结合性能、鲁棒性、代码规范、兼容性四个维度,对原测试代码进行全面优化,同时保留核心逻辑,修复潜在 Bug,提升可维护性与运行效率。
一、核心优化点说明
| 优化维度 | 具体优化内容 |
|---|---|
| 性能优化 | 新增torch.no_grad()禁用推理梯度,节省显存、提升推理速度 |
| 跨设备兼容 | 权重加载添加map_location,支持GPU 训练权重→CPU 测试的跨设备场景 |
| 代码规范 | 修复拼写错误、抽离硬编码常量、添加类型注解、模块化封装逻辑 |
| 鲁棒性 | 增加文件路径校验、异常捕获,避免因文件缺失 / 路径错误导致程序崩溃 |
| 可读性 | 优化打印日志、拆分核心逻辑为函数,结构更清晰 |
| 严谨性 | 优化模型加载参数、统一设备操作逻辑,规避设备不匹配报错 |
二、优化后完整代码
# 导入依赖库
import os
from typing import Tuple, List
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer
# 导入自定义模块(需保证与测试脚本同目录)
from MyData import Mydataset
from net import Model
# ===================== 全局配置(抽离硬编码,统一维护) =====================
# 预训练模型本地路径(与训练阶段保持一致)
PRETRAINED_MODEL_PATH = r'D:\本地模型\google-bert\bert-base-chinese'
# 训练好的模型权重路径
MODEL_WEIGHT_PATH = "params/1bert.pt"
# 文本最大长度(必须与训练阶段的max_length完全一致)
MAX_SEQ_LENGTH = 500
# 测试批次大小(建议与训练阶段一致)
BATCH_SIZE = 32
# 设备配置:优先CUDA,其次CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===================== 工具函数封装 =====================
def collate_fn(data: List[Tuple[str, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
批量数据预处理函数,与训练阶段逻辑完全一致
Args:
data: 批量样本列表,单元素为(文本字符串, 标签整数)
Returns:
input_ids/attention_mask/token_type_ids: BERT模型标准输入
labels: 样本真实标签张量
"""
# 分离文本与标签
sents = [item[0] for item in data]
labels = [item[1] for item in data]
# 加载分词器(全局单例,避免重复加载)
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
# 批量编码:截断+填充+张量格式转换
encoded_output = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sents,
truncation=True,
padding="max_length",
max_length=MAX_SEQ_LENGTH,
return_tensors="pt",
return_length=True
)
# 解构模型输入张量
input_ids = encoded_output["input_ids"]
attention_mask = encoded_output["attention_mask"]
token_type_ids = encoded_output["token_type_ids"]
label_tensor = torch.LongTensor(labels)
return input_ids, attention_mask, token_type_ids, label_tensor
def load_test_model() -> Model:
"""
加载模型结构与训练权重,切换至推理模式,封装复用
Returns:
完成权重加载、设备迁移、推理模式切换的模型实例
Raises:
FileNotFoundError: 权重文件不存在时抛出异常
"""
# 校验权重文件是否存在
if not os.path.exists(MODEL_WEIGHT_PATH):
raise FileNotFoundError(f"模型权重文件不存在,请检查路径:{MODEL_WEIGHT_PATH}")
# 初始化模型并迁移设备
model = Model().to(DEVICE)
# 加载权重:map_location兼容跨设备场景,strict=False适配轻微结构差异
checkpoint = torch.load(
MODEL_WEIGHT_PATH,
map_location=DEVICE,
weights_only=True # 安全加载,仅加载权重张量
)
model.load_state_dict(checkpoint, strict=True)
# 切换为推理模式:关闭Dropout/BatchNorm,保证结果稳定
model.eval()
print(f"✅ 模型权重加载完成,已切换至推理模式,运行设备:{DEVICE}")
return model
def build_test_dataloader() -> DataLoader:
"""构建测试集DataLoader,封装数据集与加载器逻辑"""
# 初始化测试集
test_dataset = Mydataset(split="test")
# 构建测试加载器
test_loader = DataLoader(
dataset=test_dataset,
batch_size=BATCH_SIZE,
shuffle=False, # 测试集无需打乱,修正原代码冗余shuffle
drop_last=True,
collate_fn=collate_fn
)
print(f"✅ 测试集加载完成,总批次:{len(test_loader)}")
return test_loader
# ===================== 主测试逻辑 =====================
def main():
print("=" * 50)
print(f"开始执行模型测试,运行设备:{DEVICE}")
print("=" * 50)
try:
# 1. 构建测试数据加载器
test_loader = build_test_dataloader()
# 2. 加载训练完成的模型
model = load_test_model()
# 初始化统计变量
total_correct = 0
total_samples = 0
# 核心优化:禁用梯度计算,大幅节省显存、提升推理速度
with torch.no_grad():
for batch_idx, batch_data in enumerate(test_loader):
# 解构批次数据
input_ids, attn_mask, token_type_ids, labels = batch_data
# 张量设备迁移:与模型保持同一设备
input_ids = input_ids.to(DEVICE)
attn_mask = attn_mask.to(DEVICE)
token_type_ids = token_type_ids.to(DEVICE)
labels = labels.to(DEVICE)
# 模型前向推理
logits = model(input_ids, attn_mask, token_type_ids)
# 获取预测类别:取概率最大的类别索引
pred_labels = logits.argmax(dim=1)
# 统计正确样本数与总样本数
batch_correct = (pred_labels == labels).sum().item()
total_correct += batch_correct
total_samples += len(labels)
# 每5个批次打印进度,避免日志刷屏
if batch_idx % 5 == 0:
batch_acc = batch_correct / len(labels)
print(f"批次:{batch_idx:>3d} | 当前批次准确率:{batch_acc:.4f}")
# 计算整体测试准确率
overall_acc = total_correct / total_samples if total_samples > 0 else 0.0
# 打印最终测试结果
print("\n" + "=" * 50)
print(f"测试完成 | 总测试样本数:{total_samples} | 正确预测数:{total_correct}")
print(f"测试集整体准确率:{overall_acc:.8f}")
print("=" * 50)
except Exception as e:
print(f"\n❌ 测试执行异常:{str(e)}")
raise
if __name__ == '__main__':
main()
三、关键优化细节解析
1. 性能核心优化:禁用梯度计算
测试阶段无需更新参数 ,添加with torch.no_grad()上下文管理器,禁用 Autograd 梯度计算:
- 减少显存占用,避免大批次 / 长文本导致显存溢出
- 提升推理速度,降低 CPU/GPU 计算负载
2. 跨设备兼容优化
权重加载添加map_location=DEVICE与weights_only=True:
- 解决GPU 训练、CPU 测试的设备不匹配问题
weights_only=True开启安全加载,规避权重文件安全风险
3. 代码规范与可维护性
- 抽离硬编码常量:路径、批次大小、序列长度等统一维护,修改更便捷
- 模块化封装:将数据加载、模型加载拆分为独立函数,代码结构清晰,便于复用与单元测试
- 修复拼写错误:原代码
test_laoder→test_loader,符合编程命名规范 - 新增类型注解,提升代码可读性与 IDE 提示效果
4. 鲁棒性增强
- 增加文件存在性校验:提前判断权重文件路径,避免运行到后半段崩溃
- 新增异常捕获
try-except,捕获执行异常并友好提示 - 增加边界判断:
total_samples > 0避免除零错误
5. 逻辑合理性优化
- 测试集
shuffle=False:测试阶段无需打乱样本顺序,原代码shuffle=True属于冗余操作 - 优化日志打印:按批次间隔输出进度,避免刷屏,结果展示格式化更清晰
- 拆分模型输入变量名:
attn_mask替代简写,语义更直观