写在前面:有些地方和视频里不一样的是因为官方文档更新了,一些参数用法不一样也很正常,包括我现在的也是我这个时间节点最新的,谁知道过段时间会不会更新呢= =建议大家不要一味看视频/博客,多看看官方文档才是正道(
加载现有的网络模型
加载有两种方式加载,一种是直接加载固有的网络结构,这种比较简单,还有一种是将原有的网络训练好的参数也下载下来,这种加载的时候如果原来没有的话会自动下载,如下:
对应的用法如下:
python
#只加载网络结构
vgg16_false = torchvision.models.vgg16(weights=None)
print(vgg16_false)
#加载网络结构and参数
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')
print(vgg16_true)
VGG16原有结构(图太长了,开头没截全,重点关注最后的就ok)
在最后加入新层(以修改为10分类为例)
python
#在最后加入新层
vgg16_true.add_module('my_add_linear1',nn.Linear(1000,10))
print(vgg16_true)

在原有区域块中加入新层
python
#在原有区域块中加入新层
vgg16_true.classifier.add_module('my_add_linear2',nn.Linear(1000,10))
print(vgg16_true)

对原有层进行修改
python
#对原有层进行修改
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
