深度学习速成:模型的使用与修改,保存与读取

1.使用与修改

以 VGG16 为例

Python 代码:
import torchvision
import torch

# Loading ImageNet itself is left commented out. torchvision's ImageNet dataset
# takes split="train"/"val" (not train=True) and has no download option, so the
# archives must already be present under the root directory:
# train_data = torchvision.datasets.ImageNet("../data_imgnet", split="train",
#                                            transform=torchvision.transforms.ToTensor())

# A randomly initialised VGG16 and a VGG16 with ImageNet-pretrained weights.
# `weights=None` is the current spelling of the deprecated `pretrained=False`.
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
print(vgg16_true)

# Option A: append a new fully connected layer that maps the 1000 ImageNet
# outputs to 10 classes. It is added to `classifier` (an nn.Sequential, which
# runs its children in order). Adding it to the top-level model instead, e.g.
#   vgg16_true.add_module("new_fc", torch.nn.Linear(1000, 10))
# only registers the layer: VGG's forward() runs features/avgpool/classifier,
# so a top-level extra module would never be called.
vgg16_true.classifier.add_module("new_fc", torch.nn.Linear(1000, 10))
print(vgg16_true)

print(vgg16_false)

# Option B: replace classifier layer (6), the final Linear(4096, 1000), so the
# model outputs 10 classes directly.
vgg16_false.classifier[6] = torch.nn.Linear(4096, 10)
print(vgg16_false)

修改完的输出

2.保存与读取

2.1保存

Python 代码(model_save.py):
import torch
import torchvision
import torch.nn as nn
# Build a VGG16 without pretrained weights; only its structure matters here.
# NOTE: these statements run at import time as well, so the loading script's
# `from model_save import *` re-creates both files below.
vgg16 = torchvision.models.vgg16(weights=None)

# Save method 1: pickle the whole module object (architecture + parameters).
# Reading it back requires the defining classes to be importable.
torch.save(obj=vgg16, f="vgg16.pth")


# Save method 2: save only the parameters (the state_dict), which is the form
# PyTorch's documentation recommends. The architecture is rebuilt in code and
# the dict is loaded into it; the tensors themselves are the same as in method 1.
torch.save(obj=vgg16.state_dict(), f="vgg16_params.pth")


class tudui(nn.Module):
    """Small CNN: three conv + max-pool stages followed by two linear layers.

    linear1 expects 1024 = 64 channels * 4 * 4 features. With 5x5 kernels and
    padding=2 the convolutions keep H and W unchanged, and each 2x2 max-pool
    halves them, so e.g. a 32x32 input becomes 32 -> 16 -> 8 -> 4.
    The final layer produces 10 features per sample. There is no activation
    function between the layers.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels, then halve the spatial size.
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = nn.MaxPool2d(2)
        # Stage 2: 32 -> 32 channels, then halve again.
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = nn.MaxPool2d(2)
        # Stage 3: 32 -> 64 channels, then halve again.
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = nn.MaxPool2d(2)
        # Head: flatten everything after the batch dim, then 1024 -> 64 -> 10.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(1024, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        """Apply every layer to *x* in declaration order and return the result."""
        pipeline = (
            self.conv1, self.maxpool1,
            self.conv2, self.maxpool2,
            self.conv3, self.maxpool3,
            self.flatten,
            self.linear1, self.linear2,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out
    
# Instantiate the CNN and save it with method 1 (pickle the whole module).
# NOTE: despite "_params" in the file name, this file holds the full model,
# not a state_dict, so it must be read back with torch.load rather than
# load_state_dict. The name is kept because the loading script reads this path.
tudui_model = tudui()
torch.save(obj=tudui_model, f="tudui_params.pth")

2.2 读取

Python 代码(model_load.py):
import torch
import torchvision
from model_save import *
# Method 1: load the whole pickled model (structure + parameters).
# A full model is an arbitrary pickle, so it needs weights_only=False:
# vgg16 = torch.load("vgg16.pth", weights_only=False)
# print(vgg16)

# Method 2: rebuild the architecture in code, then load only the parameters.
# A state_dict contains just tensors in a dict, so weights_only=True suffices.
# This restricts unpickling to safe types and silences the FutureWarning that
# a bare torch.load emits (see the output below).
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_params.pth", weights_only=True))

print(vgg16)
# Inspecting the raw state_dict (a mapping of parameter names to tensors):
# vgg16_params = torch.load("vgg16_params.pth", weights_only=True)
# print(vgg16_params)



# Pitfall of method 1: unpickling a whole model needs the class definition used
# at save time to be importable here, hence `from model_save import *` above.
# weights_only=False is passed explicitly because PyTorch announced that the
# default will flip to True, which would reject this file. Full unpickling can
# execute arbitrary code, so only do this for files you trust.
model = torch.load("tudui_params.pth", weights_only=False)
print(model)

2.3输出结果

运行 model_load.py 的终端输出:
(base) PS E:\desktop\deeplearning> & D:\miniconda3\envs\pytorch_py312\python.exe e:/desktop/deeplearning/src/model_load.py
e:\desktop\deeplearning\src\model_load.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  vgg16.load_state_dict(torch.load("vgg16_params.pth"))
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
e:\desktop\deeplearning\src\model_load.py:19: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model=torch.load("tudui_params.pth")
tudui(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
相关推荐
小徐xxx2 小时前
感知机(Perceptron)学习记录
深度学习·感知机·二分类
友思特 智能感知2 小时前
友思特案例 | 金属行业视觉检测案例四:挖掘机钢板表面光学字符识别(OCR)检测
人工智能·视觉检测·深度学习视觉检测
爱吃泡芙的小白白2 小时前
CNN激活函数新篇:Sigmoid与Softmax的进化与实战
人工智能·神经网络·cnn·softmax·sigmoid·函数激活层
星爷AG I2 小时前
9-27 视觉表象(AGI基础理论)
人工智能·agi
Coder_Boy_2 小时前
基于SpringAI的在线考试系统-企业级教育考试系统核心架构(完善版)
开发语言·人工智能·spring boot·python·架构·领域驱动
艾莉丝努力练剑2 小时前
【Linux:文件】基础IO:文件操作的系统调用和库函数各个接口汇总及代码演示
linux·运维·服务器·c++·人工智能·centos·io
Leinwin2 小时前
VibeVoice-ASR:突破60分钟长音频处理瓶颈,语音识别进入端到端时代
人工智能·音视频·语音识别
没有不重的名么2 小时前
Multiple Object Tracking as ID Prediction
深度学习·opencv·计算机视觉·目标跟踪
Godspeed Zhao2 小时前
从零开始学AI7——机器学习0
人工智能·机器学习