基于 PyTorch 的模型测试与全局平均池化实践

一、模型测试:整体与类别准确率统计

模型训练完成后,我们需要在测试集上评估其性能,不仅要知道整体的准确率,有时还需要了解模型在各个类别上的表现。

首先来看整体准确率的计算代码:

复制代码
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

这里使用 torch.no_grad() 上下文管理器,因为测试阶段不需要计算梯度,这样可以节省内存并加快计算速度。通过遍历测试数据加载器 testloader,将图像和标签放到指定设备(CPU 或 GPU)上,模型 net 对图像进行前向传播得到输出 outputstorch.max(outputs.data, 1) 会返回每个样本输出中最大值的索引,也就是模型预测的类别。最后通过统计预测正确的样本数与总样本数的比例,得到整体准确率。

接下来是各类别准确率的统计代码:

复制代码
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

我们初始化了两个列表 class_correctclass_total,分别用于记录每个类别预测正确的样本数和每个类别总的样本数。在遍历测试数据时,对于每个样本,判断预测是否正确,并根据标签更新对应类别的统计值。最后遍历每个类别,计算并打印出每个类别的准确率。这样可以帮助我们了解模型在哪些类别上表现较好,哪些类别上还有提升空间。

二、采用全局平均池化优化网络结构

全局平均池化(Global Average Pooling,GAP)是一种常用的网络结构优化手段,它可以替代全连接层的部分功能,减少模型参数数量,同时也有助于缓解过拟合问题。

下面是采用全局平均池化的网络结构定义代码:

复制代码
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 36, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # 全局平均池化层
        self.aap = nn.AdaptiveAvgPool2d(1)
        self.fc3 = nn.Linear(36, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.aap(x)
        x = x.view(x.shape[0], -1)
        x = self.fc3(x)
        return x

net = Net()
net = net.to(device)

print("net_gvp have {} parameters in total".format(sum(x.numel() for x in net.parameters())))

在这个网络结构中,首先通过卷积层 conv1conv2 提取图像特征,然后经过最大池化层 pool1pool2 缩小特征图尺寸。接着使用全局平均池化层 aap,它会将每个特征图平均成一个值,这样就将特征图转换为了固定长度的向量。最后通过一个全连接层 fc3 将特征映射到类别空间。

与使用全连接层直接处理卷积层输出相比,全局平均池化减少了大量的参数。我们可以通过 sum(x.numel() for x in net.parameters()) 统计模型的总参数数量,参数数量的减少有助于模型的部署和推理,同时也降低了过拟合的风险。

三、总结

通过模型测试,我们可以全面了解模型在测试集上的性能表现,包括整体准确率和各类别准确率。而采用全局平均池化则是一种有效的网络结构优化方法,能够在保证模型性能的前提下,减少参数数量,提升模型的泛化能力和推理效率。在实际的深度学习项目中,这些技术都是非常实用的工具,有助于我们开发出更高效、更准确的模型。

相关推荐
TG:@yunlaoda360 云老大12 小时前
腾讯WAIC发布“1+3+N”AI全景图:混元3D世界模型开源,具身智能平台Tairos亮相
人工智能·3d·开源·腾讯云
这张生成的图像能检测吗12 小时前
(论文速读)Fast3R:在一个向前通道中实现1000+图像的3D重建
人工智能·深度学习·计算机视觉·3d重建
兴趣使然黄小黄15 小时前
【AI-agent】LangChain开发智能体工具流程
人工智能·microsoft·langchain
出门吃三碗饭15 小时前
Transformer前世今生——使用pytorch实现多头注意力(八)
人工智能·深度学习·transformer
l1t15 小时前
利用DeepSeek改写SQLite版本的二进制位数独求解SQL
数据库·人工智能·sql·sqlite
说私域15 小时前
开源AI智能名片链动2+1模式S2B2C商城小程序FAQ设计及其意义探究
人工智能·小程序
开利网络16 小时前
合规底线:健康产品营销的红线与避坑指南
大数据·前端·人工智能·云计算·1024程序员节
非著名架构师16 小时前
量化“天气风险”:金融与保险机构如何利用气候大数据实现精准定价与投资决策
大数据·人工智能·新能源风光提高精度·疾风气象大模型4.0
巫婆理发22217 小时前
评估指标+数据不匹配+贝叶斯最优误差(分析方差和偏差)+迁移学习+多任务学习+端到端深度学习
深度学习·学习·迁移学习
熙梦数字化17 小时前
2025汽车零部件行业数字化转型落地方案
大数据·人工智能·汽车