PyTorch入门学习(十五):现有网络模型的使用及修改

目录

一、使用现有网络模型

二、修改现有网络模型


一、使用现有网络模型

PyTorch提供了许多流行的深度学习模型,这些模型在大规模图像数据集上进行了预训练。其中一个著名的模型是VGG16。下面是如何使用VGG16模型的示例代码:

python 复制代码
import torchvision
from torch import nn
from torchvision.models import VGG16

# 使用不带预训练权重的VGG16模型
vgg16_false = torchvision.models.vgg16(pretrained=False)

# 使用预训练权重的VGG16模型
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_false)
print(vgg16_true)

在上述代码中,使用torchvision.models.vgg16来加载VGG16模型。通过pretrained参数,我们可以选择是否加载预训练的权重。vgg16_false代表一个不带预训练权重的VGG16模型,而vgg16_true代表一个带有预训练权重的模型。

二、修改现有网络模型

一旦加载了现有的网络模型,可以对其进行修改,以满足特定任务的需求。下面是如何修改VGG16模型的示例代码:

python 复制代码
import torchvision
from torch import nn
from torchvision.models import VGG16

# 加载带有预训练权重的VGG16模型
vgg16 = torchvision.models.vgg16(pretrained=True)

# 添加一个新的线性层,将输出从1000类修改为10类
vgg16.classifier.add_module('add_linear', nn.Linear(1000, 10))

# 修改VGG16模型的最后一个全连接层
vgg16.classifier[6] = nn.Linear(4096, 10)

print(vgg16)

在上述代码中,加载了一个带有预训练权重的VGG16模型,并通过add_module方法添加了一个新的线性层,将输出从1000类修改为10类。此外,还演示了如何通过修改模型的索引来改变VGG16模型的最后一个全连接层。

这种方法可以帮助您快速构建适用于特定任务的模型,而无需从头开始训练整个网络。

完整代码如下:

python 复制代码
import torchvision
from torch import nn
from torchvision.models import VGG16_Weights

# train_data = torchvision.datasets.ImageNet("D:\\Python_Project\\pytorch\\data_image_net",split="train",download=True,transform=torchvision.transforms.ToTensor())

# 错误原因:参数pretrained自0.13起已弃用,将在0.15后删除,要改用"weights"。
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)

# print(vgg16_true)

# 要想用于 CIFAR10 数据集, 可以在网络下面多加一行,转成10分类的输出,这样输出的结果,跟下面的不一样,位置不一样
# vgg16_true.add_module('add_Linear',nn.Linear(1000,10))
# print(vgg16_true)

vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# 层级不同
# 如何利用现有的网络,改变结构
print(vgg16_true)

# 上面是添加层,下面是如何修改VGG里面的层内容
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)  # 中括号里的内容,是网络输出结果自带的索引,套进这种格式,就可以直接修改那一层的内容
print(vgg16_false)

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关推荐
潮汐退涨月冷风霜2 小时前
机器学习之非监督学习(四)K-means 聚类算法
学习·算法·机器学习
GoppViper2 小时前
golang学习笔记29——golang 中如何将 GitHub 最新提交的版本设置为 v1.0.0
笔记·git·后端·学习·golang·github·源代码管理
B站计算机毕业设计超人2 小时前
计算机毕业设计Python+Flask微博情感分析 微博舆情预测 微博爬虫 微博大数据 舆情分析系统 大数据毕业设计 NLP文本分类 机器学习 深度学习 AI
爬虫·python·深度学习·算法·机器学习·自然语言处理·数据可视化
羊小猪~~2 小时前
深度学习基础案例5--VGG16人脸识别(体验学习的痛苦与乐趣)
人工智能·python·深度学习·学习·算法·机器学习·cnn
Charles Ray3 小时前
C++学习笔记 —— 内存分配 new
c++·笔记·学习
我要吐泡泡了哦4 小时前
GAMES104:15 游戏引擎的玩法系统基础-学习笔记
笔记·学习·游戏引擎
骑鱼过海的猫1234 小时前
【tomcat】tomcat学习笔记
笔记·学习·tomcat
AI大模型知识分享5 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
贾saisai6 小时前
Xilinx系FPGA学习笔记(九)DDR3学习
笔记·学习·fpga开发
北岛寒沫6 小时前
JavaScript(JS)学习笔记 1(简单介绍 注释和输入输出语句 变量 数据类型 运算符 流程控制 数组)
javascript·笔记·学习