十二、基于 BERT 的中文文本二分类模型测试实战:从数据加载到准确率评估

源代码如下:bert_test

复制代码
# 导入PyTorch核心库,用于张量计算、模型加载、设备管理等
import torch

# 导入自定义的数据集类Mydataset(用于加载本地ChnSentiCorp测试集数据)
from MyData import Mydataset
# 导入PyTorch的DataLoader,用于批量加载测试集数据
from torch.utils.data import DataLoader
# 导入自定义的下游分类模型Model(基于BERT的二分类模型)
from net import Model
# 导入BERT分词器,用于测试集文本的编码处理
from transformers import BertTokenizer

# 定义模型测试/推理使用的设备(优先使用NVIDIA CUDA GPU,无GPU则使用CPU)
# 以下是注释掉的AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 打印当前使用的计算设备,确认是否成功启用GPU
print(f"使用设备: {DEVICE}")

# 定义本地中文BERT预训练模型的存储路径(需与训练时使用的模型一致)
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 从本地路径加载BERT分词器,与预训练模型配套使用,保证编码规则一致
token = BertTokenizer.from_pretrained(model_name)

# 自定义数据批处理函数collate_fn:用于DataLoader中处理测试集批量样本
# 作用:将测试集的原始文本转换为BERT模型可识别的张量格式,统一输入尺寸
def collate_fn(data):
    # 1. 从批量样本中分离文本内容和对应的标签(data是Mydataset返回的(text, label)元组列表)
    sentes = [i[0] for i in data]  # 提取所有测试样本的文本,组成列表:["文本1", "文本2", ...]
    label = [i[1] for i in data]   # 提取所有测试样本的真实标签,组成列表:[标签1, 标签2, ...]
    
    # 2. 使用BERT分词器对批量测试文本进行编码处理
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sentes,  # 待编码的批量测试文本列表
        truncation=True,                  # 开启截断:文本长度超过max_length时截断
        padding="max_length",             # 开启填充:将所有文本填充到max_length指定长度
        max_length=500,                   # 测试集文本的最大序列长度(需与训练时保持一致)
        return_tensors="pt",              # 返回PyTorch张量格式,方便模型计算
        return_length=True                # 返回每个文本的实际编码长度(可选,此处未使用)
    )
    
    # 3. 提取BERT模型需要的三个核心输入张量
    input_ids = data["input_ids"]          # 文本分词后的数字编码张量,形状:[batch_size, 500]
    attention_mask = data["attention_mask"]# 注意力掩码张量:标记有效token(1)和填充token(0)
    token_type_ids = data["token_type_ids"]# 句子对分隔张量:单句分类任务中全为0
    
    # 4. 将测试集标签转换为PyTorch长整型张量(用于后续准确率计算)
    labels = torch.LongTensor(label)
    
    # 5. 返回处理后的批量输入数据和真实标签
    return input_ids, attention_mask, token_type_ids, labels

# 6. 创建测试集数据集实例:加载Mydataset中的"test"划分数据
test_dataset = Mydataset("test")

# 7. 创建测试集数据加载器(DataLoader):批量迭代测试数据
test_laoder = DataLoader(
    dataset=test_dataset,    # 传入已创建的测试集数据集
    batch_size=32,           # 测试批次大小(与训练时一致,可根据GPU显存调整)
    shuffle=True,            # 测试时打乱数据(不影响最终准确率,仅为遍历顺序随机)
    drop_last=True,          # 删除最后一个不完整的批次,保证批次样本数一致
    collate_fn=collate_fn    # 指定自定义的批处理函数,处理测试集批量样本
)

# 主程序入口:仅在直接运行该脚本时执行测试逻辑
if __name__ == '__main__':
    # 初始化准确率计算的累计变量
    acc = 0    # 累计预测正确的样本数
    total = 0  # 累计参与测试的总样本数
    
    # 打印设备信息,确认测试设备
    print(DEVICE)
    
    # 1. 实例化自定义分类模型,并移至指定计算设备(GPU/CPU)
    model = Model().to(DEVICE)
    
    # 2. 加载训练好的模型权重文件(关键:使用训练完成的参数进行测试)
    # "params/1bert.pt":第1个Epoch训练后保存的模型参数文件(可替换为效果更好的Epoch文件)
    model.load_state_dict(torch.load("params/1bert.pt"))
    
    # 3. 关键:将模型切换到推理/评估模式
    # 作用:关闭训练相关的层(如Dropout、BatchNorm),保证推理结果稳定;禁用梯度计算,节省显存
    model.eval()
    
    # 4. 批量遍历测试集数据,计算整体准确率
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(test_laoder):
        # 5. 将测试批量数据移至指定设备(与模型在同一设备,避免计算报错)
        input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
            DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
        
        # 6. 模型推理(前向传播):获取测试样本的分类概率输出
        # 注:测试时无需计算梯度,可添加with torch.no_grad()进一步优化(见补充说明)
        out = model(input_ids, attention_mask, token_type_ids)
        
        # 7. 提取预测结果:取概率分布中最大值对应的索引(0或1),作为最终分类预测
        out = out.argmax(dim=1)
        
        # 8. 累计统计:更新正确样本数和总样本数
        acc += (out == labels).sum().item()  # 统计当前批次预测正确的样本数,累加到acc
        total += len(labels)                # 统计当前批次的总样本数,累加到total
        
        # 打印当前遍历的批次索引,监控测试进度
        print(f"已完成测试批次: {i}")
    
    # 9. 计算并打印测试集的整体准确率
    # 准确率 = 预测正确的总样本数 / 参与测试的总样本数
    print(f"\n测试集整体准确率: {acc / total:.8f}")

模型训练完成后,测试阶段 是验证模型泛化能力的核心环节 ------ 只有在未见过的测试集上表现稳定,才能说明模型真正学到了文本的语义特征,而非单纯 "记住" 训练数据。本文聚焦基于bert-base-chinese的中文文本二分类模型测试全流程,从测试集加载、批量数据处理,到模型推理、准确率计算,拆解每一个核心步骤,并附上可直接运行的完整代码。

测试阶段的核心目标与准备工作

核心目标
  • 验证训练后的模型在独立测试集上的分类准确率,评估模型泛化能力;
  • 确保模型推理流程的正确性(输入格式、设备匹配、结果解析无错误);
  • 为模型优化(如调整训练轮次、批次大小、文本长度)提供数据支撑。
前置准备

在开始测试前,需确保以下资源已就绪:

  1. 环境依赖 :与训练阶段一致(torchtransformersdatasets),避免版本不一致导致的兼容性问题;
  2. 预训练模型 :本地存放的bert-base-chinese模型(路径与训练时一致);
  3. 训练好的权重文件 :如params/1bert.pt(训练阶段保存的模型参数);
  4. 测试数据集 :ChnSentiCorp 的test划分(与训练 / 验证集独立);
  5. 自定义模块MyData.py(数据集封装)、net.py(BERT 分类模型)。

整体功能概述

这段代码是训练好的 BERT 文本二分类模型的完整测试流程,核心目标是:

  1. 加载独立的测试集(ChnSentiCorp 的test划分),验证模型的泛化能力(对未见过数据的预测能力);
  2. 按照与训练一致的规则处理测试数据,保证输入格式匹配;
  3. 加载训练阶段保存的模型权重,通过批量推理计算测试集的整体分类准确率;
  4. 输出准确率指标,评估模型的最终性能。

整个流程可概括为:设备配置 → 数据预处理 → 测试集加载 → 模型加载 → 批量推理 → 准确率统计


第一步:导入核心依赖库 / 模块

复制代码
# 导入PyTorch核心库,用于张量计算、模型加载、设备管理等
import torch

# 导入自定义的数据集类Mydataset(用于加载本地ChnSentiCorp测试集数据)
from MyData import Mydataset
# 导入PyTorch的DataLoader,用于批量加载测试集数据
from torch.utils.data import DataLoader
# 导入自定义的下游分类模型Model(基于BERT的二分类模型)
from net import Model
# 导入BERT分词器,用于测试集文本的编码处理
from transformers import BertTokenizer

每个导入项的核心作用(与训练代码呼应):

  1. torch:PyTorch 核心库,负责张量计算、模型权重加载、设备管理(如to(DEVICE))、准确率统计等核心操作;
  2. Mydataset:你自定义的数据集类,已封装了本地 ChnSentiCorp 数据集的加载逻辑,这里专门加载test划分的数据;
  3. DataLoader:PyTorch 的批量数据迭代器,解决 "单样本处理效率低" 的问题,同时配合collate_fn完成批量数据标准化;
  4. Model:你自定义的二分类模型(BERT 主干 + 全连接分类头),与训练阶段的模型结构完全一致,保证权重加载后可正常推理;
  5. BertTokenizer:BERT 配套的分词器,需与训练时使用的bert-base-chinese模型匹配,保证文本编码规则(分词、词汇表、填充 / 截断)一致。

第二步:配置计算设备

复制代码
# 定义模型测试/推理使用的设备(优先使用NVIDIA CUDA GPU,无GPU则使用CPU)
# 以下是注释掉的AMD显卡加速配置(torch_directml),新手可暂时忽略
# import torch_directml as dml
# DEVICE = dml.device() if dml.is_available() else torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 打印当前使用的计算设备,确认是否成功启用GPU
print(f"使用设备: {DEVICE}")

核心逻辑与设计目的:

  1. 设备优先级:优先使用 NVIDIA GPU(CUDA),因为 GPU 的并行计算能力可大幅提升测试推理速度;若无 GPU,自动降级为 CPU(速度较慢,但能保证代码运行);
  2. 打印设备信息 :方便你确认设备是否配置正确(如显示cuda:0表示 GPU 可用,cpu表示仅能用 CPU),避免后续 "模型与数据不在同一设备" 的报错;
  3. 与训练一致 :测试设备需与训练设备逻辑一致,否则可能出现权重加载异常(如训练用 GPU、测试用 CPU 时,需注意权重加载的map_location参数,本文代码中未体现,因为设备判断逻辑统一)。

第三步:加载 BERT 分词器

复制代码
# 定义本地中文BERT预训练模型的存储路径(需与训练时使用的模型一致)
model_name = r'D:\本地模型\google-bert\bert-base-chinese'
# 从本地路径加载BERT分词器,与预训练模型配套使用,保证编码规则一致
token = BertTokenizer.from_pretrained(model_name)

关键注意事项:

  1. model_name路径必须与训练阶段完全一致:若路径错误或加载了其他版本的分词器,会导致词汇表不匹配,编码后的input_ids无效,模型推理报错;
  2. 分词器的核心作用:将测试集的原始中文文本转换为 BERT 模型能识别的数字编码(input_ids),是 "文本→模型输入" 的桥梁。

第四步:自定义批量数据处理函数collate_fn

这是测试代码的核心辅助函数 ,解决 "文本长度不一致" 和 "批量输入格式标准化" 的问题,且必须与训练阶段的collate_fn逻辑完全一致。

复制代码
# 自定义数据批处理函数collate_fn:用于DataLoader中处理测试集批量样本
# 作用:将测试集的原始文本转换为BERT模型可识别的张量格式,统一输入尺寸
def collate_fn(data):
    # 1. 从批量样本中分离文本内容和对应的标签(data是Mydataset返回的(text, label)元组列表)
    sentes = [i[0] for i in data]  # 提取所有测试样本的文本,组成列表:["文本1", "文本2", ...]
    label = [i[1] for i in data]   # 提取所有测试样本的真实标签,组成列表:[标签1, 标签2, ...]
    
    # 2. 使用BERT分词器对批量测试文本进行编码处理
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sentes,  # 待编码的批量测试文本列表
        truncation=True,                  # 开启截断:文本长度超过max_length时截断
        padding="max_length",             # 开启填充:将所有文本填充到max_length指定长度
        max_length=500,                   # 测试集文本的最大序列长度(需与训练时保持一致)
        return_tensors="pt",              # 返回PyTorch张量格式,方便模型计算
        return_length=True                # 返回每个文本的实际编码长度(可选,此处未使用)
    )
    
    # 3. 提取BERT模型需要的三个核心输入张量
    input_ids = data["input_ids"]          # 文本分词后的数字编码张量,形状:[batch_size, 500]
    attention_mask = data["attention_mask"]# 注意力掩码张量:标记有效token(1)和填充token(0)
    token_type_ids = data["token_type_ids"]# 句子对分隔张量:单句分类任务中全为0
    
    # 4. 将测试集标签转换为PyTorch长整型张量(用于后续准确率计算)
    labels = torch.LongTensor(label)
    
    # 5. 返回处理后的批量输入数据和真实标签
    return input_ids, attention_mask, token_type_ids, labels
逐步骤解释:
  1. 分离文本和标签

    • dataDataLoader传入的 "批量样本列表",每个元素是Mydataset.__getitem__()返回的(text, label)元组(测试集保留标签,用于计算准确率);
    • 用列表推导式分离出文本列表sentes和标签列表label,方便后续分别处理。
  2. 批量编码token.batch_encode_plus():这是分词器的核心批量处理方法,参数必须与训练阶段完全一致,否则模型输入维度不匹配:

    • truncation=True:若测试文本分词后长度超过 500,截断到 500 个 token(避免输入过长导致显存溢出);
    • padding="max_length":将所有文本填充到 500 个 token,保证批量内所有样本的输入尺寸一致(神经网络要求固定尺寸输入);
    • max_length=500:必须与训练阶段的数值一致(若训练用 350、测试用 500,会导致输入维度不匹配,直接报错);
    • return_tensors="pt":直接返回 PyTorch 张量,无需手动转换,提升效率。
  3. 提取 BERT 核心输入张量 :编码后返回的data是字典,包含 3 个 BERT 必需的张量(形状均为[batch_size, 500]):

    • input_ids:文本的数字编码(如 "好" 对应某个数字),是模型的核心输入;
    • attention_mask:标记哪些 token 是有效文本(1)、哪些是填充(0),让 BERT 忽略填充部分;
    • token_type_ids:单句分类任务中全为 0(仅在句对任务如问答中有用)。
  4. 标签转换为LongTensor

    • 测试集标签是整数(0/1),转换为 PyTorch 长整型张量,方便后续与模型预测结果(张量)对比,计算准确率。

第五步:加载测试集并创建DataLoader

复制代码
# 6. 创建测试集数据集实例:加载Mydataset中的"test"划分数据
test_dataset = Mydataset("test")

# 7. 创建测试集数据加载器(DataLoader):批量迭代测试数据
test_laoder = DataLoader(
    dataset=test_dataset,    # 传入已创建的测试集数据集
    batch_size=32,           # 测试批次大小(与训练时一致,可根据GPU显存调整)
    shuffle=True,            # 测试时打乱数据(不影响最终准确率,仅为遍历顺序随机)
    drop_last=True,          # 删除最后一个不完整的批次,保证批次样本数一致
    collate_fn=collate_fn    # 指定自定义的批处理函数,处理测试集批量样本
)

核心参数解释:

  1. test_dataset = Mydataset("test"):加载 ChnSentiCorp 的测试集(与训练集、验证集独立),保证测试结果能反映模型的泛化能力;
  2. batch_size=32:与训练阶段一致,兼顾推理效率和显存占用(显存不足可改为 16/8);
  3. shuffle=True:测试时打乱数据顺序,不影响最终准确率 (仅改变遍历顺序),若想按原顺序测试可改为False
  4. drop_last=True:删除最后一个不足 32 个样本的批次,避免 "批次样本数不一致" 导致的张量维度报错;
  5. collate_fn=collate_fn:指定自定义的批量数据处理函数,DataLoader每次获取批量样本后,会自动调用该函数完成编码和格式转换。

第六步:主程序测试逻辑(核心闭环)

if __name__ == '__main__':内的代码是测试流程的核心,实现从 "模型加载" 到 "准确率计算" 的完整闭环。

复制代码
if __name__ == '__main__':
    # 初始化准确率计算的累计变量
    acc = 0    # 累计预测正确的样本数
    total = 0  # 累计参与测试的总样本数
    
    # 打印设备信息,确认测试设备
    print(DEVICE)
    
    # 1. 实例化自定义分类模型,并移至指定计算设备(GPU/CPU)
    model = Model().to(DEVICE)
    
    # 2. 加载训练好的模型权重文件(关键:使用训练完成的参数进行测试)
    # "params/1bert.pt":第1个Epoch训练后保存的模型参数文件(可替换为效果更好的Epoch文件)
    model.load_state_dict(torch.load("params/1bert.pt"))
    
    # 3. 关键:将模型切换到推理/评估模式
    # 作用:关闭训练相关的层(如Dropout、BatchNorm),保证推理结果稳定;禁用梯度计算,节省显存
    model.eval()
    
    # 4. 批量遍历测试集数据,计算整体准确率
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(test_laoder):
        # 5. 将测试批量数据移至指定设备(与模型在同一设备,避免计算报错)
        input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
            DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
        
        # 6. 模型推理(前向传播):获取测试样本的分类概率输出
        # 注:测试时无需计算梯度,可添加with torch.no_grad()进一步优化(见补充说明)
        out = model(input_ids, attention_mask, token_type_ids)
        
        # 7. 提取预测结果:取概率分布中最大值对应的索引(0或1),作为最终分类预测
        out = out.argmax(dim=1)
        
        # 8. 累计统计:更新正确样本数和总样本数
        acc += (out == labels).sum().item()  # 统计当前批次预测正确的样本数,累加到acc
        total += len(labels)                # 统计当前批次的总样本数,累加到total
        
        # 打印当前遍历的批次索引,监控测试进度
        print(f"已完成测试批次: {i}")
    
    # 9. 计算并打印测试集的整体准确率
    # 准确率 = 预测正确的总样本数 / 参与测试的总样本数
    print(f"\n测试集整体准确率: {acc / total:.8f}")
逐步骤解释(核心闭环):
  1. 初始化统计变量

    • acc:累计所有批次中预测正确的样本数,初始为 0;
    • total:累计所有参与测试的样本数,初始为 0;
    • 最终准确率 = acc / total,这是评估模型性能的核心指标。
  2. 模型实例化与设备迁移

    • model = Model().to(DEVICE):实例化自定义的二分类模型,并将模型所有参数移至指定设备(GPU/CPU);
    • 必须保证模型与后续输入数据在同一设备,否则会报 "张量与模型不在同一设备" 的错误。
  3. 加载训练好的权重

    • model.load_state_dict(torch.load("params/1bert.pt")):这是测试的核心关键步骤
      • torch.load("params/1bert.pt"):加载训练阶段保存的模型参数文件(state_dict,即模型可训练参数的键值对);
      • model.load_state_dict(...):将加载的参数赋值给模型,替换模型的随机初始化参数;
    • 若不加载权重,模型会使用随机参数推理,准确率约 50%(二分类随机猜测),失去测试意义。
  4. 切换到推理模式model.eval()

    • 这是极易忽略但必须执行 的步骤,作用有二:
      1. 关闭训练相关的层(如 Dropout、BatchNorm):训练时 Dropout 会随机丢弃神经元防止过拟合,测试时需关闭,保证推理结果稳定;
      2. 隐含禁用梯度计算(部分优化):减少显存占用,提升推理速度。
  5. 批量遍历测试集

    • enumerate(test_laoder):遍历DataLoader,返回批次索引i和处理后的批量数据(4 个张量);
    • 数据设备迁移:将input_idsattention_masktoken_type_idslabels全部移至DEVICE,与模型保持一致。
  6. 模型推理(前向传播)

    • out = model(input_ids, attention_mask, token_type_ids):调用模型的forward方法,得到分类概率输出;

    • out的形状为[32, 2]batch_size=32,二分类),每个元素是对应类别的概率(如[0.05, 0.95]表示 "负面" 概率 0.05,"正面" 概率 0.95);

    • 优化建议:添加with torch.no_grad():上下文管理器(包裹整个遍历循环),禁用梯度计算,进一步节省显存、提升推理速度:

      复制代码
      model.eval()
      with torch.no_grad():  # 禁用梯度计算
          for i, (input_ids, ...) in enumerate(test_laoder):
              # 后续推理逻辑不变
  7. 解析预测结果out.argmax(dim=1)

    • argmax(dim=1):在分类维度(第 1 维,对应 2 个类别)取概率最大值的索引,将概率分布转换为具体标签:
      • out = [[0.05, 0.95], [0.9, 0.1]]out.argmax(dim=1)返回[1, 0],即第一个样本预测为 1(正面),第二个为 0(负面);
    • 这是二分类结果解析的标准方式。
  8. 累计统计准确率

    • (out == labels):逐元素对比预测标签和真实标签,返回布尔张量(True表示预测正确,False错误);
    • .sum().item():统计布尔张量中True的数量(即当前批次正确样本数),并转换为 Python 标量;
    • acc += ...:累加所有批次的正确样本数,total += len(labels)累加总样本数。
  9. 打印测试结果

    • 每批次打印索引,监控测试进度;
    • 最终计算并打印整体准确率,保留 8 位小数,方便精确评估(如0.89104729表示准确率 89.104729%)。

补充:常见问题与优化

  1. 权重加载报错 :若报 "key mismatch"(键不匹配),说明测试用的Model类与训练时的结构不一致(如分类头维度修改),需保证模型结构完全一致;
  2. 显存不足 :减小batch_size(如 32→16)、减小max_length(如 500→256);
  3. 准确率过低
    • 检查权重文件是否正确(如加载了错误 Epoch 的权重);
    • 检查collate_fnmax_length是否与训练一致;
    • 检查测试集是否与训练集独立(避免数据泄露)。

总结

  1. 核心流程:设备配置 → 批量数据标准化 → 测试集加载 → 模型加载(权重)→ 推理模式切换 → 批量推理 → 准确率统计;
  2. 关键细节
    • 测试数据的编码规则(collate_fn)必须与训练一致;
    • 必须加载训练好的权重,否则测试无意义;
    • 必须执行model.eval(),保证推理结果稳定;
    • 模型与数据必须在同一设备,避免计算报错;
  3. 核心价值:通过测试集准确率评估模型的泛化能力,验证模型是否真正学到了文本的语义特征,而非 "记住" 训练数据。

代码优化

结合性能、鲁棒性、代码规范、兼容性四个维度,对原测试代码进行全面优化,同时保留核心逻辑,修复潜在 Bug,提升可维护性与运行效率。

一、核心优化点说明
优化维度 具体优化内容
性能优化 新增torch.no_grad()禁用推理梯度,节省显存、提升推理速度
跨设备兼容 权重加载添加map_location,支持GPU 训练权重→CPU 测试的跨设备场景
代码规范 修复拼写错误、抽离硬编码常量、添加类型注解、模块化封装逻辑
鲁棒性 增加文件路径校验、异常捕获,避免因文件缺失 / 路径错误导致程序崩溃
可读性 优化打印日志、拆分核心逻辑为函数,结构更清晰
严谨性 优化模型加载参数、统一设备操作逻辑,规避设备不匹配报错
二、优化后完整代码
复制代码
# 导入依赖库
import os
from typing import Tuple, List

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

# 导入自定义模块(需保证与测试脚本同目录)
from MyData import Mydataset
from net import Model

# ===================== 全局配置(抽离硬编码,统一维护) =====================
# 预训练模型本地路径(与训练阶段保持一致)
PRETRAINED_MODEL_PATH = r'D:\本地模型\google-bert\bert-base-chinese'
# 训练好的模型权重路径
MODEL_WEIGHT_PATH = "params/1bert.pt"
# 文本最大长度(必须与训练阶段的max_length完全一致)
MAX_SEQ_LENGTH = 500
# 测试批次大小(建议与训练阶段一致)
BATCH_SIZE = 32
# 设备配置:优先CUDA,其次CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===================== 工具函数封装 =====================
def collate_fn(data: List[Tuple[str, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    批量数据预处理函数,与训练阶段逻辑完全一致
    Args:
        data: 批量样本列表,单元素为(文本字符串, 标签整数)
    Returns:
        input_ids/attention_mask/token_type_ids: BERT模型标准输入
        labels: 样本真实标签张量
    """
    # 分离文本与标签
    sents = [item[0] for item in data]
    labels = [item[1] for item in data]

    # 加载分词器(全局单例,避免重复加载)
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)

    # 批量编码:截断+填充+张量格式转换
    encoded_output = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sents,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LENGTH,
        return_tensors="pt",
        return_length=True
    )

    # 解构模型输入张量
    input_ids = encoded_output["input_ids"]
    attention_mask = encoded_output["attention_mask"]
    token_type_ids = encoded_output["token_type_ids"]
    label_tensor = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, label_tensor


def load_test_model() -> Model:
    """
    加载模型结构与训练权重,切换至推理模式,封装复用
    Returns:
        完成权重加载、设备迁移、推理模式切换的模型实例
    Raises:
        FileNotFoundError: 权重文件不存在时抛出异常
    """
    # 校验权重文件是否存在
    if not os.path.exists(MODEL_WEIGHT_PATH):
        raise FileNotFoundError(f"模型权重文件不存在,请检查路径:{MODEL_WEIGHT_PATH}")

    # 初始化模型并迁移设备
    model = Model().to(DEVICE)

    # 加载权重:map_location兼容跨设备场景,strict=False适配轻微结构差异
    checkpoint = torch.load(
        MODEL_WEIGHT_PATH,
        map_location=DEVICE,
        weights_only=True  # 安全加载,仅加载权重张量
    )
    model.load_state_dict(checkpoint, strict=True)

    # 切换为推理模式:关闭Dropout/BatchNorm,保证结果稳定
    model.eval()
    print(f"✅ 模型权重加载完成,已切换至推理模式,运行设备:{DEVICE}")
    return model


def build_test_dataloader() -> DataLoader:
    """构建测试集DataLoader,封装数据集与加载器逻辑"""
    # 初始化测试集
    test_dataset = Mydataset(split="test")
    # 构建测试加载器
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,       # 测试集无需打乱,修正原代码冗余shuffle
        drop_last=True,
        collate_fn=collate_fn
    )
    print(f"✅ 测试集加载完成,总批次:{len(test_loader)}")
    return test_loader

# ===================== 主测试逻辑 =====================
def main():
    print("=" * 50)
    print(f"开始执行模型测试,运行设备:{DEVICE}")
    print("=" * 50)

    try:
        # 1. 构建测试数据加载器
        test_loader = build_test_dataloader()
        # 2. 加载训练完成的模型
        model = load_test_model()

        # 初始化统计变量
        total_correct = 0
        total_samples = 0

        # 核心优化:禁用梯度计算,大幅节省显存、提升推理速度
        with torch.no_grad():
            for batch_idx, batch_data in enumerate(test_loader):
                # 解构批次数据
                input_ids, attn_mask, token_type_ids, labels = batch_data

                # 张量设备迁移:与模型保持同一设备
                input_ids = input_ids.to(DEVICE)
                attn_mask = attn_mask.to(DEVICE)
                token_type_ids = token_type_ids.to(DEVICE)
                labels = labels.to(DEVICE)

                # 模型前向推理
                logits = model(input_ids, attn_mask, token_type_ids)
                # 获取预测类别:取概率最大的类别索引
                pred_labels = logits.argmax(dim=1)

                # 统计正确样本数与总样本数
                batch_correct = (pred_labels == labels).sum().item()
                total_correct += batch_correct
                total_samples += len(labels)

                # 每5个批次打印进度,避免日志刷屏
                if batch_idx % 5 == 0:
                    batch_acc = batch_correct / len(labels)
                    print(f"批次:{batch_idx:>3d} | 当前批次准确率:{batch_acc:.4f}")

        # 计算整体测试准确率
        overall_acc = total_correct / total_samples if total_samples > 0 else 0.0

        # 打印最终测试结果
        print("\n" + "=" * 50)
        print(f"测试完成 | 总测试样本数:{total_samples} | 正确预测数:{total_correct}")
        print(f"测试集整体准确率:{overall_acc:.8f}")
        print("=" * 50)

    except Exception as e:
        print(f"\n❌ 测试执行异常:{str(e)}")
        raise


if __name__ == '__main__':
    main()
三、关键优化细节解析
1. 性能核心优化:禁用梯度计算

测试阶段无需更新参数 ,添加with torch.no_grad()上下文管理器,禁用 Autograd 梯度计算:

  • 减少显存占用,避免大批次 / 长文本导致显存溢出
  • 提升推理速度,降低 CPU/GPU 计算负载
2. 跨设备兼容优化

权重加载添加map_location=DEVICEweights_only=True

  • 解决GPU 训练、CPU 测试的设备不匹配问题
  • weights_only=True开启安全加载,规避权重文件安全风险
3. 代码规范与可维护性
  • 抽离硬编码常量:路径、批次大小、序列长度等统一维护,修改更便捷
  • 模块化封装:将数据加载、模型加载拆分为独立函数,代码结构清晰,便于复用与单元测试
  • 修复拼写错误:原代码test_laodertest_loader,符合编程命名规范
  • 新增类型注解,提升代码可读性与 IDE 提示效果
4. 鲁棒性增强
  • 增加文件存在性校验:提前判断权重文件路径,避免运行到后半段崩溃
  • 新增异常捕获try-except,捕获执行异常并友好提示
  • 增加边界判断:total_samples > 0避免除零错误
5. 逻辑合理性优化
  • 测试集shuffle=False:测试阶段无需打乱样本顺序,原代码shuffle=True属于冗余操作
  • 优化日志打印:按批次间隔输出进度,避免刷屏,结果展示格式化更清晰
  • 拆分模型输入变量名:attn_mask替代简写,语义更直观
相关推荐
淮北4942 小时前
Reinforce算法
人工智能·机器学习
shangjian0072 小时前
AI-大语言模型LLM-概念术语-Dropout
人工智能·语言模型·自然语言处理
小鸡吃米…2 小时前
机器学习 - 高斯判别分析(Gaussian Discriminant Analysis)
人工智能·深度学习·机器学习
香芋Yu2 小时前
【机器学习教程】第01章:机器学习概览
人工智能·机器学习
HySpark2 小时前
关于语音智能技术实践与应用探索
人工智能·语音识别
AI应用开发实战派2 小时前
AI人工智能中Bard的智能电子商务优化
人工智能·ai·bard
FL16238631292 小时前
MMA综合格斗动作检测数据集VOC+YOLO格式1780张16类别
人工智能·yolo·机器学习
应用市场2 小时前
深度学习图像超分辨率技术全面解析:从入门到精通
人工智能·深度学习
格林威2 小时前
Baumer相机铸件气孔与缩松识别:提升铸造良品率的 6 个核心算法,附 OpenCV+Halcon 实战代码!
人工智能·opencv·算法·安全·计算机视觉·堡盟相机·baumer相机