>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**
📌本周任务:📌
● 任务类型:自主探索⭐⭐
● 任务难度:偏难
●任务描述:
1、请根据J1~J3周的内容自有探索ResNet与DenseNet结合的可能性
2、是否可以根据两种特性构建一个新的模型框架?
3、请用之前的任一图像识别任务验证改进后模型的效果
🏡 我的环境:
- 语言环境:Python3.8
- 编译器:Jupyter Notebook
- 深度学习环境:Pytorch
-
- torch==2.3.1+cu118
-
- torchvision==0.18.1+cu118
一、论文导读
论文:Dual Path Networks
论文链接:https://arxiv.org/abs/1707.01629
代码:https://github.com/cypw/DPNs
MXNet框架下可训练模型的DPN代码:https://github.com/miraclewkf/DPN
残差网络[2]和DenseNet[3]是short-cut系列网络的最为经典的两个基础网络,其中残差网络通过单位加的方式直接将输入加到输出的卷积上,DenseNet则是通过拼接的方式将输出与之后的每一层的输入进行拼接。
DPN(Dual Path Networks)是一种网络结构,它结合了DensNet和ResNetXt两种思想的优点。这种结构的目的是通过不同的路径来利用神经网络的不同特性,从而提高模型的效率和性能。
DenseNet 的特点是其稠密连接路径,使得网络能够在不同层级之间持续地探索新的特征。这种连接方式允许网络在不增加参数的情况下学习到更丰富的特征表示。
ResNeXt(残差分组卷积)则是通过残差路径实现特征的复用,这有助于减少模型的大小和复杂度。
DPN的设计思想在于融合这两种思想,通过两个并行的路径来进行信息传递:
一条路径是通过DenseNet的方式,即通过稠密连接路径,这样可以持续地探索新的特征。
另一条路径是通过ResNeXt的方式,即通过残差路径,可以实现特征的复用。
此外,DPN使用了分组卷积来降低计算量,并且可以在不改变原有网络结构的前提下,提升性能,使其适合用于检测和分割任务作为新的Backbone网络。
总结:DPN可以说是融合了ResNeXt和DenseNet的核心思想:Dual Path Network(DPN)以ResNet为主要框架,保证了特征的低冗余度,并在其基础上添加了一个非常小的DenseNet分支,用于生成新的特征。
那么DPN到底有哪些优点呢?可以看以下两点:
1、关于模型复杂度 ,作者的原文是这么说的:The DPN-92 costs about 15% fewer parameters than ResNeXt-101 (32 4d), while the DPN-98 costs about 26% fewer parameters than ResNeXt-101 (64 4d).
2、关于计算复杂度 ,作者的原文是这么说的:DPN-92 consumes about 19% less FLOPs than ResNeXt-101(32 4d), and the DPN-98 consumes about 25% less FLOPs than ResNeXt-101(64 4d).
由上图可知,其实DPN和ResNeXt(ResNet)的结构很相似。最开始一个7*7的卷积层和max pooling层,然后是4个stage,每个stage包含几个sub-stage(后面会介绍),再接着是一个global average pooling和全连接层,最后是softmax层。重点在于stage里面的内容,也是DPN算法的核心。
因为DPN算法简单讲就是将ResNeXt和DenseNet融合成一个网络,因此在介绍DPN的每个stage里面的结构之前,先简单过一下ResNet(ResNeXt和ResNet的子结构在宏观上是一样的)和DenseNet的核心内容。
下图中的(a)是ResNet的某个stage中的一部分。(a)的左边竖着的大矩形框表示输入输出内容,对一个输入x,分两条线走,一条线还是x本身,另一条线是x经过1×1卷积,3×3卷积,1×1卷积(这三个卷积层的组合又称作bottleneck),然后把这两条线的输出做一个element-wise addition,也就是对应值相加,就是(a)中的加号,得到的结果又变成下一个同样模块的输入,几个这样的模块组合在一起就成了一个stage(比如Table1中的conv3)。
(b)表示DenseNet的核心内容。(b)的左边竖着的多边形框表示输入输出内容,对输入x,只走一条线,那就是经过几层卷积后和x做一个通道的合并(cancat),得到的结果又成了下一个小模块的输入,这样每一个小模块的输入都在不断累加,举个例子:第二个小模块的输入包含第一个小模块的输出和第一个小模块的输入,以此类推。
DPN是怎么做呢?简单讲就是将Residual Network 和 Densely Connected Network融合在一起。下图中的(d)和(e)是一个意思,所以就按(e)来讲吧。(e)中竖着的矩形框和多边形框的含义和前面一样。具体在代码中,对于一个输入x(分两种情况:一种是如果x是整个网络第一个卷积层的输出或者某个stage的输出,会对x做一个卷积,然后做slice,也就是将输出按照channel分成两部分:data_o1和data_o2,可以理解为(e)中竖着的矩形框和多边形框;另一种是在stage内部的某个sub-stage的输出,输出本身就包含两部分:data_o1和data_o2),走两条线,一条线是保持data_o1和data_o2本身,和ResNet类似;另一条线是对x做1×1卷积,3×3卷积,1×1卷积,然后再做slice得到两部分c1和c2,最后c1和data_o1做相加(element-wise addition)得到sum,类似ResNet中的操作;c2和data_o2做通道合并(concat)得到dense(这样下一层就可以得到这一层的输出和这一层的输入),也就是最后返回两个值:sum和dense。
以上这个过程就是DPN中 一个stage中的一个sub-stage。有两个细节,一个是3×3的卷积采用的是group操作,类似ResNeXt,另一个是在每个sub-stage的首尾都会对dense部分做一个通道的加宽操作。
作者在MXNet框架下实现了DPN算法,具体的symbol可以看:https://github.com/cypw/DPNs/tree/master/settings,介绍得非常详细也很容易读懂。
实验结果:
Table2是在ImageNet-1k数据集上和目前最好的几个算法的对比:ResNet,ResNeXt,DenseNet。可以看出在模型大小,GFLOP和准确率方面DPN网络都更胜一筹。不过在这个对比中好像DenseNet的表现不如DenseNet那篇论文介绍的那么喜人,可能是因为DenseNet的需要更多的训练技巧。
Figure3是关于训练速度和存储空间的对比。现在对于模型的改进,可能准确率方面的提升已经很难作为明显的创新点,因为幅度都不大,因此大部分还是在模型大小和计算复杂度上优化,同时只要准确率还能提高一点就算进步了。
总结:
作者提出的DPN网络可以理解为在ResNeXt的基础上引入了DenseNet的核心内容,使得模型对特征的利用更加充分。原理方面并不难理解,而且在跑代码过程中也比较容易训练,同时文章中的实验也表明模型在分类和检测的数据集上都有不错的效果。
参考文章:
DPN(Dual Path Network)算法详解_dpn(dual path network)算法详解-CSDN博客
DPN详解(Dual Path Networks) - 知乎 (zhihu.com)
解读Dual Path Networks(DPN,原创) - 知乎 (zhihu.com)
二、 前期准备
1. 设置GPU
如果设备上支持GPU就使用GPU,否则使用CPU
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
import torch
device=torch.device("cuda" if torch.cuda.is_available() else "CPU")
device
运行结果:
device(type='cuda')
2. 导入数据
import pathlib
data_dir=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos'
data_dir=pathlib.Path(data_dir)
img_count=len(list(data_dir.glob('*/*')))
print("图片总数为:",img_count)
运行结果:
图片总数为: 565
3. 查看数据集分类
data_paths=list(data_dir.glob('*'))
classNames=[str(path).split('\\')[5] for path in data_paths]
classNames
运行结果:
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
import PIL,random
4. 随机查看图片
随机抽取数据集中的10张图片进行查看
import PIL,random
import matplotlib.pyplot as plt
from PIL import Image
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(20,4))
plt.suptitle("OreoCC的案例",fontsize=15)
for i in range(10):
plt.subplot(2,5,i+1)
plt.axis("off")
image=random.choice(data_paths2) #随机选择一个图片
plt.title(image.parts[-2]) #通过glob对象取出他的文件夹名称,即分类名
plt.imshow(Image.open(str(image))) #显示图片
运行结果:
5. 图片预处理
import torchvision.transforms as transforms
from torchvision import transforms,datasets
train_transforms=transforms.Compose([
transforms.Resize([224,224]), #将图片统一尺寸
transforms.RandomHorizontalFlip(), #将图片随机水平翻转
transforms.RandomRotation(0.2), #将图片按照0.2的弧度值随机旋转
transforms.ToTensor(), #将图片转换为tensor
transforms.Normalize( #标准化处理->转换为正态分布,使模型更容易收敛
mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]
)
])
total_data=datasets.ImageFolder(
r"D:\THE MNIST DATABASE\J-series\J1\bird_photos",
transform=train_transforms
)
total_data
运行结果:
Dataset ImageFolder
Number of datapoints: 565
Root location: D:\THE MNIST DATABASE\J-series\J1\bird_photos
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
RandomHorizontalFlip(p=0.5)
RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
将数据集分类情况进行映射输出:
total_data.class_to_idx
运行结果:
{'Bananaquit': 0,
'Black Skimmer': 1,
'Black Throated Bushtiti': 2,
'Cockatoo': 3}
6. 划分数据集
train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(
total_data,[train_size,test_size]
)
train_dataset,test_dataset
运行结果:
(<torch.utils.data.dataset.Subset at 0x270de0de310>,
<torch.utils.data.dataset.Subset at 0x270de0de950>)
查看训练集和测试集的数据数量:
train_size,test_size
运行结果:
(452, 113)
7. 加载数据集
batch_size=16
train_dl=torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1
)
test_dl=torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1
)
查看测试集的情况:
for x,y in train_dl:
print("Shape of x [N,C,H,W]:",x.shape)
print("Shape of y:",y.shape,y.dtype)
break
运行结果:
Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
二、手动搭建DPN模型
1、搭建DPN模型
import torch
import torch.nn as nn
class Block(nn.Module):
"""
param:in_channel--输入通道数
mid_channel--中间经历的通道数
out_channel--ResNet部分使用的通道数(sum操作,这部分输出仍然是out_channel 1个通道)
dense_channel--DenseNet部分使用的通道数(concat操作,这部分输出是2*dense_channel 1个通道)
groups--conv2中的分组卷积参数
is_shortcut--ResNet前是否进行shortcut操作
"""
def __init__(self,in_channel,mid_channel,out_channel,dense_channel,stride,groups,is_shortcut=False):
super(Block,self).__init__()
self.is_shortcut=is_shortcut
self.out_channel=out_channel
self.conv1=nn.Sequential(
nn.Conv2d(in_channel,mid_channel,kernel_size=1,bias=False),
nn.BatchNorm2d(mid_channel),
nn.ReLU()
)
self.conv2=nn.Sequential(
nn.Conv2d(mid_channel,mid_channel,kernel_size=3,stride=stride,padding=1,groups=groups,bias=False),
nn.BatchNorm2d(mid_channel),
nn.ReLU()
)
self.conv3=nn.Sequential(
nn.Conv2d(mid_channel,out_channel+dense_channel,kernel_size=1,bias=False),
nn.BatchNorm2d(out_channel+dense_channel)
)
if self.is_shortcut:
self.shortcut=nn.Sequential(
nn.Conv2d(in_channel,out_channel+dense_channel,kernel_size=3,padding=1,stride=stride,bias=False),
nn.BatchNorm2d(out_channel+dense_channel)
)
self.relu=nn.ReLU(inplace=True)
def forward(self,x):
a=x
x=self.conv1(x)
x=self.conv2(x)
x=self.conv3(x)
if self.is_shortcut:
a=self.shortcut(a)
#a[:,:self.out_channel,:,:]+[:,:self.out_channel,:,:]是使用ResNet的方法,
#即采用sum的方式将特征图进行求和,通道数不变,都是out-channel个通道
#[a[:,self.out_channel,:,:],x[:,self.out_channel:,:,:]]是使用DenseNet的方法,
#即采用concat的方式将特征图在channel维度上直接进行叠加,通道数加倍,即2*dense_channel
x=torch.cat([a[:,:self.out_channel,:,:]+x[:,:self.out_channel,:,:],
a[:,self.out_channel:,:,:],x[:,self.out_channel:,:,:]],dim=1)
x=self.relu(x)
return x
class DPN(nn.Module):
def __init__(self,cfg):
super(DPN,self).__init__()
self.group=cfg['group']
self.in_channel=cfg['in_channel']
mid_channels=cfg['mid_channels']
out_channels=cfg['out_channels']
dense_channels=cfg['dense_channels']
num=cfg['num']
self.conv1=nn.Sequential(
nn.Conv2d(3,self.in_channel,7,stride=2,padding=3,bias=False,padding_mode='zeros'),
nn.BatchNorm2d(self.in_channel),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3,stride=2,padding=0)
)
self.conv2=self._make_layers(mid_channels[0],out_channels[0],dense_channels[0],num[0],stride=1)
self.conv3=self._make_layers(mid_channels[1],out_channels[1],dense_channels[1],num[1],stride=2)
self.conv4=self._make_layers(mid_channels[2],out_channels[2],dense_channels[2],num[2],stride=2)
self.conv5=self._make_layers(mid_channels[3],out_channels[3],dense_channels[3],num[3],stride=2)
self.pool=nn.AdaptiveAvgPool2d((1,1))
self.fc=nn.Linear(cfg['out_channels'][3]+(num[3]+1)*cfg['dense_channels'][3],cfg['classes']) #fc层需要计算
def _make_layers(self,mid_channel,out_channel,dense_channel,num,stride):
layers=[]
"""is_shortcut=True表示进行shortcut操作,则将浅层的特征进行一次卷积后与进行第三次卷积的特征图相加
(ResNet方式)和concat(DenseNet方式)操作"""
"""第一次使用Block可以满足浅层特征的利用,后续重复的Block则不需要浅层特征,因此后续的Block的
is_shortcut=False(默认值)"""
layers.append(Block(self.in_channel,mid_channel,out_channel,dense_channel,
stride=stride,groups=self.group,is_shortcut=True))
self.in_channel=out_channel+dense_channel*2
for i in range(1,num):
layers.append(Block(self.in_channel,mid_channel,out_channel,dense_channel,
stride=1,groups=self.group))
"""由于Block包含DenseNet在叠加特征图,所以第一次是2倍dense_channel,
后面每次都会多出1倍dense_channel"""
self.in_channel+=dense_channel
return nn.Sequential(*layers)
def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=self.conv3(x)
x=self.conv4(x)
x=self.conv5(x)
x=self.pool(x)
x=torch.flatten(x,start_dim=1)
x=self.fc(x)
return x
2、建立DPN92并显示模型结构
def DPN92(n_class=4):
cfg={
"group":32,
"in_channel":64,
"mid_channels":(96,192,384,768),
"out_channels":(256,512,1024,2048),
"dense_channels":(16,32,24,128),
"num":(3,4,20,3),
"classes":(n_class)
}
return DPN(cfg)
def DPN98(n_class4):
cfg={
"group":40,
"in_channel":96,
"mid_channels":(160,320,640,1280),
"out_channels":(256,512,1024,2048),
"dense_channels":(16,32,32,128),
"num":(3,6,20,3),
"classes":(n_class)
}
return DPN(cfg)
model=DPN92().to(device)
model
运行结果:
DPN(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv2): Sequential(
(0): Block(
(conv1): Sequential(
(0): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(96, 272, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(64, 272, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(1): Block(
(conv1): Sequential(
(0): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(96, 272, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Block(
(conv1): Sequential(
(0): Conv2d(304, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(96, 272, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
(conv3): Sequential(
(0): Block(
(conv1): Sequential(
(0): Conv2d(320, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(320, 544, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(1): Block(
(conv1): Sequential(
(0): Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Block(
(conv1): Sequential(
(0): Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(3): Block(
(conv1): Sequential(
(0): Conv2d(640, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
(conv4): Sequential(
(0): Block(
(conv1): Sequential(
(0): Conv2d(672, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(672, 1048, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(1): Block(
(conv1): Sequential(
(0): Conv2d(1072, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Block(
(conv1): Sequential(
(0): Conv2d(1096, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(3): Block(
(conv1): Sequential(
(0): Conv2d(1120, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(4): Block(
(conv1): Sequential(
(0): Conv2d(1144, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(5): Block(
(conv1): Sequential(
(0): Conv2d(1168, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(6): Block(
(conv1): Sequential(
(0): Conv2d(1192, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(7): Block(
(conv1): Sequential(
(0): Conv2d(1216, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(8): Block(
(conv1): Sequential(
(0): Conv2d(1240, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(9): Block(
(conv1): Sequential(
(0): Conv2d(1264, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(10): Block(
(conv1): Sequential(
(0): Conv2d(1288, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(11): Block(
(conv1): Sequential(
(0): Conv2d(1312, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(12): Block(
(conv1): Sequential(
(0): Conv2d(1336, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(13): Block(
(conv1): Sequential(
(0): Conv2d(1360, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(14): Block(
(conv1): Sequential(
(0): Conv2d(1384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(15): Block(
(conv1): Sequential(
(0): Conv2d(1408, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(16): Block(
(conv1): Sequential(
(0): Conv2d(1432, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(17): Block(
(conv1): Sequential(
(0): Conv2d(1456, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(18): Block(
(conv1): Sequential(
(0): Conv2d(1480, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(19): Block(
(conv1): Sequential(
(0): Conv2d(1504, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
(conv5): Sequential(
(0): Block(
(conv1): Sequential(
(0): Conv2d(1528, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(768, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(768, 2176, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(1528, 2176, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(1): Block(
(conv1): Sequential(
(0): Conv2d(2304, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(768, 2176, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Block(
(conv1): Sequential(
(0): Conv2d(2432, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv2): Sequential(
(0): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv3): Sequential(
(0): Conv2d(768, 2176, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
(pool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2560, out_features=4, bias=True)
)
3、查看模型详情
#统计模型参数量以及其他指标
import torchsummary as summary
summary.summary(model,(3,224,224))
运行结果:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 55, 55] 0
Conv2d-5 [-1, 96, 55, 55] 6,144
BatchNorm2d-6 [-1, 96, 55, 55] 192
ReLU-7 [-1, 96, 55, 55] 0
Conv2d-8 [-1, 96, 55, 55] 2,592
BatchNorm2d-9 [-1, 96, 55, 55] 192
ReLU-10 [-1, 96, 55, 55] 0
Conv2d-11 [-1, 272, 55, 55] 26,112
BatchNorm2d-12 [-1, 272, 55, 55] 544
Conv2d-13 [-1, 272, 55, 55] 156,672
BatchNorm2d-14 [-1, 272, 55, 55] 544
ReLU-15 [-1, 288, 55, 55] 0
Block-16 [-1, 288, 55, 55] 0
Conv2d-17 [-1, 96, 55, 55] 27,648
BatchNorm2d-18 [-1, 96, 55, 55] 192
ReLU-19 [-1, 96, 55, 55] 0
Conv2d-20 [-1, 96, 55, 55] 2,592
BatchNorm2d-21 [-1, 96, 55, 55] 192
ReLU-22 [-1, 96, 55, 55] 0
Conv2d-23 [-1, 272, 55, 55] 26,112
BatchNorm2d-24 [-1, 272, 55, 55] 544
ReLU-25 [-1, 304, 55, 55] 0
Block-26 [-1, 304, 55, 55] 0
Conv2d-27 [-1, 96, 55, 55] 29,184
BatchNorm2d-28 [-1, 96, 55, 55] 192
ReLU-29 [-1, 96, 55, 55] 0
Conv2d-30 [-1, 96, 55, 55] 2,592
BatchNorm2d-31 [-1, 96, 55, 55] 192
ReLU-32 [-1, 96, 55, 55] 0
Conv2d-33 [-1, 272, 55, 55] 26,112
BatchNorm2d-34 [-1, 272, 55, 55] 544
ReLU-35 [-1, 320, 55, 55] 0
Block-36 [-1, 320, 55, 55] 0
Conv2d-37 [-1, 192, 55, 55] 61,440
BatchNorm2d-38 [-1, 192, 55, 55] 384
ReLU-39 [-1, 192, 55, 55] 0
Conv2d-40 [-1, 192, 28, 28] 10,368
BatchNorm2d-41 [-1, 192, 28, 28] 384
ReLU-42 [-1, 192, 28, 28] 0
Conv2d-43 [-1, 544, 28, 28] 104,448
BatchNorm2d-44 [-1, 544, 28, 28] 1,088
Conv2d-45 [-1, 544, 28, 28] 1,566,720
BatchNorm2d-46 [-1, 544, 28, 28] 1,088
ReLU-47 [-1, 576, 28, 28] 0
Block-48 [-1, 576, 28, 28] 0
Conv2d-49 [-1, 192, 28, 28] 110,592
BatchNorm2d-50 [-1, 192, 28, 28] 384
ReLU-51 [-1, 192, 28, 28] 0
Conv2d-52 [-1, 192, 28, 28] 10,368
BatchNorm2d-53 [-1, 192, 28, 28] 384
ReLU-54 [-1, 192, 28, 28] 0
Conv2d-55 [-1, 544, 28, 28] 104,448
BatchNorm2d-56 [-1, 544, 28, 28] 1,088
ReLU-57 [-1, 608, 28, 28] 0
Block-58 [-1, 608, 28, 28] 0
Conv2d-59 [-1, 192, 28, 28] 116,736
BatchNorm2d-60 [-1, 192, 28, 28] 384
ReLU-61 [-1, 192, 28, 28] 0
Conv2d-62 [-1, 192, 28, 28] 10,368
BatchNorm2d-63 [-1, 192, 28, 28] 384
ReLU-64 [-1, 192, 28, 28] 0
Conv2d-65 [-1, 544, 28, 28] 104,448
BatchNorm2d-66 [-1, 544, 28, 28] 1,088
ReLU-67 [-1, 640, 28, 28] 0
Block-68 [-1, 640, 28, 28] 0
Conv2d-69 [-1, 192, 28, 28] 122,880
BatchNorm2d-70 [-1, 192, 28, 28] 384
ReLU-71 [-1, 192, 28, 28] 0
Conv2d-72 [-1, 192, 28, 28] 10,368
BatchNorm2d-73 [-1, 192, 28, 28] 384
ReLU-74 [-1, 192, 28, 28] 0
Conv2d-75 [-1, 544, 28, 28] 104,448
BatchNorm2d-76 [-1, 544, 28, 28] 1,088
ReLU-77 [-1, 672, 28, 28] 0
Block-78 [-1, 672, 28, 28] 0
Conv2d-79 [-1, 384, 28, 28] 258,048
BatchNorm2d-80 [-1, 384, 28, 28] 768
ReLU-81 [-1, 384, 28, 28] 0
Conv2d-82 [-1, 384, 14, 14] 41,472
BatchNorm2d-83 [-1, 384, 14, 14] 768
ReLU-84 [-1, 384, 14, 14] 0
Conv2d-85 [-1, 1048, 14, 14] 402,432
BatchNorm2d-86 [-1, 1048, 14, 14] 2,096
Conv2d-87 [-1, 1048, 14, 14] 6,338,304
BatchNorm2d-88 [-1, 1048, 14, 14] 2,096
ReLU-89 [-1, 1072, 14, 14] 0
Block-90 [-1, 1072, 14, 14] 0
Conv2d-91 [-1, 384, 14, 14] 411,648
BatchNorm2d-92 [-1, 384, 14, 14] 768
ReLU-93 [-1, 384, 14, 14] 0
Conv2d-94 [-1, 384, 14, 14] 41,472
BatchNorm2d-95 [-1, 384, 14, 14] 768
ReLU-96 [-1, 384, 14, 14] 0
Conv2d-97 [-1, 1048, 14, 14] 402,432
BatchNorm2d-98 [-1, 1048, 14, 14] 2,096
ReLU-99 [-1, 1096, 14, 14] 0
Block-100 [-1, 1096, 14, 14] 0
Conv2d-101 [-1, 384, 14, 14] 420,864
BatchNorm2d-102 [-1, 384, 14, 14] 768
ReLU-103 [-1, 384, 14, 14] 0
Conv2d-104 [-1, 384, 14, 14] 41,472
BatchNorm2d-105 [-1, 384, 14, 14] 768
ReLU-106 [-1, 384, 14, 14] 0
Conv2d-107 [-1, 1048, 14, 14] 402,432
BatchNorm2d-108 [-1, 1048, 14, 14] 2,096
ReLU-109 [-1, 1120, 14, 14] 0
Block-110 [-1, 1120, 14, 14] 0
Conv2d-111 [-1, 384, 14, 14] 430,080
BatchNorm2d-112 [-1, 384, 14, 14] 768
ReLU-113 [-1, 384, 14, 14] 0
Conv2d-114 [-1, 384, 14, 14] 41,472
BatchNorm2d-115 [-1, 384, 14, 14] 768
ReLU-116 [-1, 384, 14, 14] 0
Conv2d-117 [-1, 1048, 14, 14] 402,432
BatchNorm2d-118 [-1, 1048, 14, 14] 2,096
ReLU-119 [-1, 1144, 14, 14] 0
Block-120 [-1, 1144, 14, 14] 0
Conv2d-121 [-1, 384, 14, 14] 439,296
BatchNorm2d-122 [-1, 384, 14, 14] 768
ReLU-123 [-1, 384, 14, 14] 0
Conv2d-124 [-1, 384, 14, 14] 41,472
BatchNorm2d-125 [-1, 384, 14, 14] 768
ReLU-126 [-1, 384, 14, 14] 0
Conv2d-127 [-1, 1048, 14, 14] 402,432
BatchNorm2d-128 [-1, 1048, 14, 14] 2,096
ReLU-129 [-1, 1168, 14, 14] 0
Block-130 [-1, 1168, 14, 14] 0
Conv2d-131 [-1, 384, 14, 14] 448,512
BatchNorm2d-132 [-1, 384, 14, 14] 768
ReLU-133 [-1, 384, 14, 14] 0
Conv2d-134 [-1, 384, 14, 14] 41,472
BatchNorm2d-135 [-1, 384, 14, 14] 768
ReLU-136 [-1, 384, 14, 14] 0
Conv2d-137 [-1, 1048, 14, 14] 402,432
BatchNorm2d-138 [-1, 1048, 14, 14] 2,096
ReLU-139 [-1, 1192, 14, 14] 0
Block-140 [-1, 1192, 14, 14] 0
Conv2d-141 [-1, 384, 14, 14] 457,728
BatchNorm2d-142 [-1, 384, 14, 14] 768
ReLU-143 [-1, 384, 14, 14] 0
Conv2d-144 [-1, 384, 14, 14] 41,472
BatchNorm2d-145 [-1, 384, 14, 14] 768
ReLU-146 [-1, 384, 14, 14] 0
Conv2d-147 [-1, 1048, 14, 14] 402,432
BatchNorm2d-148 [-1, 1048, 14, 14] 2,096
ReLU-149 [-1, 1216, 14, 14] 0
Block-150 [-1, 1216, 14, 14] 0
Conv2d-151 [-1, 384, 14, 14] 466,944
BatchNorm2d-152 [-1, 384, 14, 14] 768
ReLU-153 [-1, 384, 14, 14] 0
Conv2d-154 [-1, 384, 14, 14] 41,472
BatchNorm2d-155 [-1, 384, 14, 14] 768
ReLU-156 [-1, 384, 14, 14] 0
Conv2d-157 [-1, 1048, 14, 14] 402,432
BatchNorm2d-158 [-1, 1048, 14, 14] 2,096
ReLU-159 [-1, 1240, 14, 14] 0
Block-160 [-1, 1240, 14, 14] 0
Conv2d-161 [-1, 384, 14, 14] 476,160
BatchNorm2d-162 [-1, 384, 14, 14] 768
ReLU-163 [-1, 384, 14, 14] 0
Conv2d-164 [-1, 384, 14, 14] 41,472
BatchNorm2d-165 [-1, 384, 14, 14] 768
ReLU-166 [-1, 384, 14, 14] 0
Conv2d-167 [-1, 1048, 14, 14] 402,432
BatchNorm2d-168 [-1, 1048, 14, 14] 2,096
ReLU-169 [-1, 1264, 14, 14] 0
Block-170 [-1, 1264, 14, 14] 0
Conv2d-171 [-1, 384, 14, 14] 485,376
BatchNorm2d-172 [-1, 384, 14, 14] 768
ReLU-173 [-1, 384, 14, 14] 0
Conv2d-174 [-1, 384, 14, 14] 41,472
BatchNorm2d-175 [-1, 384, 14, 14] 768
ReLU-176 [-1, 384, 14, 14] 0
Conv2d-177 [-1, 1048, 14, 14] 402,432
BatchNorm2d-178 [-1, 1048, 14, 14] 2,096
ReLU-179 [-1, 1288, 14, 14] 0
Block-180 [-1, 1288, 14, 14] 0
Conv2d-181 [-1, 384, 14, 14] 494,592
BatchNorm2d-182 [-1, 384, 14, 14] 768
ReLU-183 [-1, 384, 14, 14] 0
Conv2d-184 [-1, 384, 14, 14] 41,472
BatchNorm2d-185 [-1, 384, 14, 14] 768
ReLU-186 [-1, 384, 14, 14] 0
Conv2d-187 [-1, 1048, 14, 14] 402,432
BatchNorm2d-188 [-1, 1048, 14, 14] 2,096
ReLU-189 [-1, 1312, 14, 14] 0
Block-190 [-1, 1312, 14, 14] 0
Conv2d-191 [-1, 384, 14, 14] 503,808
BatchNorm2d-192 [-1, 384, 14, 14] 768
ReLU-193 [-1, 384, 14, 14] 0
Conv2d-194 [-1, 384, 14, 14] 41,472
BatchNorm2d-195 [-1, 384, 14, 14] 768
ReLU-196 [-1, 384, 14, 14] 0
Conv2d-197 [-1, 1048, 14, 14] 402,432
BatchNorm2d-198 [-1, 1048, 14, 14] 2,096
ReLU-199 [-1, 1336, 14, 14] 0
Block-200 [-1, 1336, 14, 14] 0
Conv2d-201 [-1, 384, 14, 14] 513,024
BatchNorm2d-202 [-1, 384, 14, 14] 768
ReLU-203 [-1, 384, 14, 14] 0
Conv2d-204 [-1, 384, 14, 14] 41,472
BatchNorm2d-205 [-1, 384, 14, 14] 768
ReLU-206 [-1, 384, 14, 14] 0
Conv2d-207 [-1, 1048, 14, 14] 402,432
BatchNorm2d-208 [-1, 1048, 14, 14] 2,096
ReLU-209 [-1, 1360, 14, 14] 0
Block-210 [-1, 1360, 14, 14] 0
Conv2d-211 [-1, 384, 14, 14] 522,240
BatchNorm2d-212 [-1, 384, 14, 14] 768
ReLU-213 [-1, 384, 14, 14] 0
Conv2d-214 [-1, 384, 14, 14] 41,472
BatchNorm2d-215 [-1, 384, 14, 14] 768
ReLU-216 [-1, 384, 14, 14] 0
Conv2d-217 [-1, 1048, 14, 14] 402,432
BatchNorm2d-218 [-1, 1048, 14, 14] 2,096
ReLU-219 [-1, 1384, 14, 14] 0
Block-220 [-1, 1384, 14, 14] 0
Conv2d-221 [-1, 384, 14, 14] 531,456
BatchNorm2d-222 [-1, 384, 14, 14] 768
ReLU-223 [-1, 384, 14, 14] 0
Conv2d-224 [-1, 384, 14, 14] 41,472
BatchNorm2d-225 [-1, 384, 14, 14] 768
ReLU-226 [-1, 384, 14, 14] 0
Conv2d-227 [-1, 1048, 14, 14] 402,432
BatchNorm2d-228 [-1, 1048, 14, 14] 2,096
ReLU-229 [-1, 1408, 14, 14] 0
Block-230 [-1, 1408, 14, 14] 0
Conv2d-231 [-1, 384, 14, 14] 540,672
BatchNorm2d-232 [-1, 384, 14, 14] 768
ReLU-233 [-1, 384, 14, 14] 0
Conv2d-234 [-1, 384, 14, 14] 41,472
BatchNorm2d-235 [-1, 384, 14, 14] 768
ReLU-236 [-1, 384, 14, 14] 0
Conv2d-237 [-1, 1048, 14, 14] 402,432
BatchNorm2d-238 [-1, 1048, 14, 14] 2,096
ReLU-239 [-1, 1432, 14, 14] 0
Block-240 [-1, 1432, 14, 14] 0
Conv2d-241 [-1, 384, 14, 14] 549,888
BatchNorm2d-242 [-1, 384, 14, 14] 768
ReLU-243 [-1, 384, 14, 14] 0
Conv2d-244 [-1, 384, 14, 14] 41,472
BatchNorm2d-245 [-1, 384, 14, 14] 768
ReLU-246 [-1, 384, 14, 14] 0
Conv2d-247 [-1, 1048, 14, 14] 402,432
BatchNorm2d-248 [-1, 1048, 14, 14] 2,096
ReLU-249 [-1, 1456, 14, 14] 0
Block-250 [-1, 1456, 14, 14] 0
Conv2d-251 [-1, 384, 14, 14] 559,104
BatchNorm2d-252 [-1, 384, 14, 14] 768
ReLU-253 [-1, 384, 14, 14] 0
Conv2d-254 [-1, 384, 14, 14] 41,472
BatchNorm2d-255 [-1, 384, 14, 14] 768
ReLU-256 [-1, 384, 14, 14] 0
Conv2d-257 [-1, 1048, 14, 14] 402,432
BatchNorm2d-258 [-1, 1048, 14, 14] 2,096
ReLU-259 [-1, 1480, 14, 14] 0
Block-260 [-1, 1480, 14, 14] 0
Conv2d-261 [-1, 384, 14, 14] 568,320
BatchNorm2d-262 [-1, 384, 14, 14] 768
ReLU-263 [-1, 384, 14, 14] 0
Conv2d-264 [-1, 384, 14, 14] 41,472
BatchNorm2d-265 [-1, 384, 14, 14] 768
ReLU-266 [-1, 384, 14, 14] 0
Conv2d-267 [-1, 1048, 14, 14] 402,432
BatchNorm2d-268 [-1, 1048, 14, 14] 2,096
ReLU-269 [-1, 1504, 14, 14] 0
Block-270 [-1, 1504, 14, 14] 0
Conv2d-271 [-1, 384, 14, 14] 577,536
BatchNorm2d-272 [-1, 384, 14, 14] 768
ReLU-273 [-1, 384, 14, 14] 0
Conv2d-274 [-1, 384, 14, 14] 41,472
BatchNorm2d-275 [-1, 384, 14, 14] 768
ReLU-276 [-1, 384, 14, 14] 0
Conv2d-277 [-1, 1048, 14, 14] 402,432
BatchNorm2d-278 [-1, 1048, 14, 14] 2,096
ReLU-279 [-1, 1528, 14, 14] 0
Block-280 [-1, 1528, 14, 14] 0
Conv2d-281 [-1, 768, 14, 14] 1,173,504
BatchNorm2d-282 [-1, 768, 14, 14] 1,536
ReLU-283 [-1, 768, 14, 14] 0
Conv2d-284 [-1, 768, 7, 7] 165,888
BatchNorm2d-285 [-1, 768, 7, 7] 1,536
ReLU-286 [-1, 768, 7, 7] 0
Conv2d-287 [-1, 2176, 7, 7] 1,671,168
BatchNorm2d-288 [-1, 2176, 7, 7] 4,352
Conv2d-289 [-1, 2176, 7, 7] 29,924,352
BatchNorm2d-290 [-1, 2176, 7, 7] 4,352
ReLU-291 [-1, 2304, 7, 7] 0
Block-292 [-1, 2304, 7, 7] 0
Conv2d-293 [-1, 768, 7, 7] 1,769,472
BatchNorm2d-294 [-1, 768, 7, 7] 1,536
ReLU-295 [-1, 768, 7, 7] 0
Conv2d-296 [-1, 768, 7, 7] 165,888
BatchNorm2d-297 [-1, 768, 7, 7] 1,536
ReLU-298 [-1, 768, 7, 7] 0
Conv2d-299 [-1, 2176, 7, 7] 1,671,168
BatchNorm2d-300 [-1, 2176, 7, 7] 4,352
ReLU-301 [-1, 2432, 7, 7] 0
Block-302 [-1, 2432, 7, 7] 0
Conv2d-303 [-1, 768, 7, 7] 1,867,776
BatchNorm2d-304 [-1, 768, 7, 7] 1,536
ReLU-305 [-1, 768, 7, 7] 0
Conv2d-306 [-1, 768, 7, 7] 165,888
BatchNorm2d-307 [-1, 768, 7, 7] 1,536
ReLU-308 [-1, 768, 7, 7] 0
Conv2d-309 [-1, 2176, 7, 7] 1,671,168
BatchNorm2d-310 [-1, 2176, 7, 7] 4,352
ReLU-311 [-1, 2560, 7, 7] 0
Block-312 [-1, 2560, 7, 7] 0
AdaptiveAvgPool2d-313 [-1, 2560, 1, 1] 0
Linear-314 [-1, 4] 10,244
================================================================
Total params: 67,994,324
Trainable params: 67,994,324
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 489.24
Params size (MB): 259.38
Estimated Total Size (MB): 749.20
----------------------------------------------------------------
三、 训练模型
1. 编写训练函数
#训练循环
def train(dataloader,model,loss_fn,optimizer):
size=len(dataloader.dataset) #训练集的大小
num_batches=len(dataloader) #批次数目,(size/batch_size,向上取整)
train_loss,train_acc=0,0 #初始化训练损失和正确率
for x,y in dataloader: #获取图片及其标签
x,y=x.to(device),y.to(device)
#计算预测误差
pred=model(x) #网络输出
loss=loss_fn(pred,y) #计算网络输出pred和真实值y之间的差距,y为真实值,计算二者误差即为损失
#反向传播
optimizer.zero_grad() #grad属性归零
loss.backward() #反向传播
optimizer.step() #每一步自动更新
#记录acc与loss
train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
train_loss+=loss.item()
train_acc/=size
train_loss/=num_batches
return train_acc,train_loss
2. 编写测试函数
测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
def test(dataloader,model,loss_fn):
size=len(dataloader.dataset) #训练集的大小
num_batches=len(dataloader) #批次数目,(size/batch_size,向上取整)
test_loss,test_acc=0,0 #初始化测试损失和正确率
#当不进行训练时,停止梯度更新,节省计算内存消耗
for imgs,target in dataloader: #获取图片及其标签
with torch.no_grad():
imgs,target=imgs.to(device),target.to(device)
#计算误差
target_pred=model(imgs) #网络输出
#计算网络输出和真实值之间的差距,targets为真实值,计算二者误差即为损失
loss=loss_fn(target_pred,target)
#记录acc和loss
test_loss+=loss.item()
test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
test_acc/=size
test_loss/=num_batches
return test_acc,test_loss
3. 正式训练
import copy
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)
loss_fn=nn.CrossEntropyLoss() #创建损失函数
epochs=40
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0 #设置一个最佳准确率,作为最佳模型的判别指标
#释放未使用的GPU内存,以便其他GPU应用程序可以使用这些资源
if hasattr(torch.cuda,'empty_cache'):
torch.cuda.empty_cache()
for epoch in range(epochs):
model.train()
epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
#scheduler.step() #更新学习率(调用官方动态学习率接口时使用)
model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
#保存最佳模型到best_model
if epoch_test_acc>best_acc:
best_acc=epoch_test_acc
best_model=copy.deepcopy(model)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
#获取当前学习率
lr=optimizer.state_dict()['param_groups'][0]['lr']
template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
epoch_test_acc*100,epoch_test_loss,lr))
PATH=r'D:\THE MNIST DATABASE\J-series\J4_best_model.pth'
torch.save(model.state_dict(),PATH)
print('Done')
运行结果:
Epoch: 1,Train_acc:60.4%,Train_loss:0.994,Test_acc:48.7%,Test_loss:1.224,Lr:1.00E-04
Epoch: 2,Train_acc:70.1%,Train_loss:0.830,Test_acc:79.6%,Test_loss:0.542,Lr:1.00E-04
Epoch: 3,Train_acc:79.2%,Train_loss:0.570,Test_acc:61.9%,Test_loss:1.461,Lr:1.00E-04
Epoch: 4,Train_acc:83.0%,Train_loss:0.491,Test_acc:81.4%,Test_loss:0.896,Lr:1.00E-04
Epoch: 5,Train_acc:84.7%,Train_loss:0.435,Test_acc:83.2%,Test_loss:0.541,Lr:1.00E-04
Epoch: 6,Train_acc:86.3%,Train_loss:0.428,Test_acc:90.3%,Test_loss:0.404,Lr:1.00E-04
Epoch: 7,Train_acc:88.5%,Train_loss:0.301,Test_acc:65.5%,Test_loss:1.831,Lr:1.00E-04
Epoch: 8,Train_acc:85.4%,Train_loss:0.406,Test_acc:73.5%,Test_loss:0.829,Lr:1.00E-04
Epoch: 9,Train_acc:90.3%,Train_loss:0.296,Test_acc:73.5%,Test_loss:1.233,Lr:1.00E-04
Epoch:10,Train_acc:91.2%,Train_loss:0.253,Test_acc:89.4%,Test_loss:0.512,Lr:1.00E-04
Epoch:11,Train_acc:92.9%,Train_loss:0.208,Test_acc:52.2%,Test_loss:2.842,Lr:1.00E-04
Epoch:12,Train_acc:95.6%,Train_loss:0.134,Test_acc:90.3%,Test_loss:0.292,Lr:1.00E-04
Epoch:13,Train_acc:91.2%,Train_loss:0.215,Test_acc:79.6%,Test_loss:0.806,Lr:1.00E-04
Epoch:14,Train_acc:93.4%,Train_loss:0.236,Test_acc:90.3%,Test_loss:0.511,Lr:1.00E-04
Epoch:15,Train_acc:90.0%,Train_loss:0.219,Test_acc:88.5%,Test_loss:0.388,Lr:1.00E-04
Epoch:16,Train_acc:95.8%,Train_loss:0.099,Test_acc:90.3%,Test_loss:0.422,Lr:1.00E-04
Epoch:17,Train_acc:98.0%,Train_loss:0.103,Test_acc:81.4%,Test_loss:0.491,Lr:1.00E-04
Epoch:18,Train_acc:95.8%,Train_loss:0.146,Test_acc:86.7%,Test_loss:0.475,Lr:1.00E-04
Epoch:19,Train_acc:95.1%,Train_loss:0.108,Test_acc:89.4%,Test_loss:1.873,Lr:1.00E-04
Epoch:20,Train_acc:98.5%,Train_loss:0.065,Test_acc:92.0%,Test_loss:0.940,Lr:1.00E-04
Epoch:21,Train_acc:95.4%,Train_loss:0.151,Test_acc:87.6%,Test_loss:0.398,Lr:1.00E-04
Epoch:22,Train_acc:96.7%,Train_loss:0.101,Test_acc:77.0%,Test_loss:0.911,Lr:1.00E-04
Epoch:23,Train_acc:97.8%,Train_loss:0.064,Test_acc:93.8%,Test_loss:0.249,Lr:1.00E-04
Epoch:24,Train_acc:97.6%,Train_loss:0.073,Test_acc:87.6%,Test_loss:0.883,Lr:1.00E-04
Epoch:25,Train_acc:98.2%,Train_loss:0.068,Test_acc:92.9%,Test_loss:0.245,Lr:1.00E-04
Epoch:26,Train_acc:98.0%,Train_loss:0.125,Test_acc:88.5%,Test_loss:0.323,Lr:1.00E-04
Epoch:27,Train_acc:96.0%,Train_loss:0.128,Test_acc:88.5%,Test_loss:0.403,Lr:1.00E-04
Epoch:28,Train_acc:97.3%,Train_loss:0.083,Test_acc:92.9%,Test_loss:0.356,Lr:1.00E-04
Epoch:29,Train_acc:96.9%,Train_loss:0.083,Test_acc:87.6%,Test_loss:0.478,Lr:1.00E-04
Epoch:30,Train_acc:96.5%,Train_loss:0.124,Test_acc:85.0%,Test_loss:0.595,Lr:1.00E-04
Epoch:31,Train_acc:94.9%,Train_loss:0.142,Test_acc:85.0%,Test_loss:0.533,Lr:1.00E-04
Epoch:32,Train_acc:95.6%,Train_loss:0.125,Test_acc:93.8%,Test_loss:0.313,Lr:1.00E-04
Epoch:33,Train_acc:97.8%,Train_loss:0.058,Test_acc:92.0%,Test_loss:0.349,Lr:1.00E-04
Epoch:34,Train_acc:95.4%,Train_loss:0.123,Test_acc:89.4%,Test_loss:0.547,Lr:1.00E-04
Epoch:35,Train_acc:97.6%,Train_loss:0.075,Test_acc:89.4%,Test_loss:0.722,Lr:1.00E-04
Epoch:36,Train_acc:96.5%,Train_loss:0.078,Test_acc:92.0%,Test_loss:0.254,Lr:1.00E-04
Epoch:37,Train_acc:98.9%,Train_loss:0.031,Test_acc:89.4%,Test_loss:0.289,Lr:1.00E-04
Epoch:38,Train_acc:99.1%,Train_loss:0.028,Test_acc:93.8%,Test_loss:0.191,Lr:1.00E-04
Epoch:39,Train_acc:98.7%,Train_loss:0.035,Test_acc:87.6%,Test_loss:0.673,Lr:1.00E-04
Epoch:40,Train_acc:98.9%,Train_loss:0.037,Test_acc:95.6%,Test_loss:0.182,Lr:1.00E-04
Done
四、 结果可视化
1. Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif']=['SimHei'] #正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #正常显示负号
plt.rcParams['figure.dpi']=300 #分辨率
epochs_range=range(epochs)
plt.figure(figsize=(12,3))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
运行结果:
2. 指定图片进行预测
from PIL import Image
classes=list(total_data.class_to_idx)
def predict_one_image(image_path,model,transform,classes):
test_img=Image.open(image_path).convert('RGB')
plt.imshow(test_img) #展示预测的图片
test_img=transform(test_img)
img=test_img.to(device).unsqueeze(0)
model.eval()
output=model(img)
_,pred=torch.max(output,1)
pred_class=classes[pred]
print(f'预测结果是:{pred_class}')
预测图片:
#预测训练集中的某张照片
predict_one_image(image_path=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos\Black Skimmer\001.jpg',
model=model,transform=train_transforms,classes=classes)
运行结果:
预测结果是:Black Skimmer
五、心得体会
在本周项目训练中,体会了在pytorch环境下手动搭建DPN模型的过程,加深了对DPN模型结构的理解。