pytorch学习笔记-加载现有的网络模型(VGG16)、增加/修改其中的网络层(修改为10分类)

写在前面:有些地方和视频里不一样的是因为官方文档更新了,一些参数用法不一样也很正常,包括我现在的也是我这个时间节点最新的,谁知道过段时间会不会更新呢= =建议大家不要一味看视频/博客,多看看官方文档才是正道(

加载现有的网络模型

加载有两种方式加载,一种是直接加载固有的网络结构,这种比较简单,还有一种是将原有的网络训练好的参数也下载下来,这种加载的时候如果原来没有的话会自动下载,如下:

对应的用法如下:

python 复制代码
#只加载网络结构
vgg16_false = torchvision.models.vgg16(weights=None)
print(vgg16_false)

#加载网络结构and参数
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')
print(vgg16_true)

VGG16原有结构(图太长了,开头没截全,重点关注最后的就ok)

在最后加入新层(以修改为10分类为例)

python 复制代码
#在最后加入新层
vgg16_true.add_module('my_add_linear1',nn.Linear(1000,10))
print(vgg16_true)

在原有区域块中加入新层

python 复制代码
#在原有区域块中加入新层
vgg16_true.classifier.add_module('my_add_linear2',nn.Linear(1000,10))
print(vgg16_true)

对原有层进行修改

python 复制代码
#对原有层进行修改
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)