深度学习速成:模型的使用与修改,保存与读取

1.使用与修改

以 VGG16 为例

python 复制代码
import torchvision
import torch

# NOTE(review): kept for reference only. As far as I know torchvision cannot
# download ImageNet itself; the archives have to be placed under the root
# directory by hand -- confirm against the torchvision dataset docs.
# train_data = torchvision.datasets.ImageNet("../data_imgnet", split="train", transform=torchvision.transforms.ToTensor())

# VGG16 with randomly initialised weights. `weights=None` replaces the
# deprecated `pretrained=False` and matches the `weights=` style used below.
vgg16_false = torchvision.models.vgg16(weights=None)
# VGG16 initialised with the ImageNet-pretrained weights.
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
print(vgg16_true)



# Modification 1: append an extra fully connected layer (1000 -> 10).
# Alternative that registers the layer on the model itself instead of inside
# `classifier`:
# vgg16_true.add_module("new_fc", torch.nn.Linear(1000, 10))
vgg16_true.classifier.add_module("new_fc", torch.nn.Linear(1000, 10))
print(vgg16_true)

print(vgg16_false)

# Modification 2: replace the last classifier layer (index 6) in place so the
# network outputs 10 classes instead of 1000.
vgg16_false.classifier[6] = torch.nn.Linear(4096, 10)
print(vgg16_false)

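修改完的输出:vgg16_true 的 classifier 末尾新增了 (new_fc): Linear(in_features=1000, out_features=10);vgg16_false 的 classifier 中 (6) 被替换为 Linear(in_features=4096, out_features=10)。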

2.保存与读取

2.1保存

python 复制代码
import torch
import torchvision
import torch.nn as nn
# Only build and save VGG16 when this file is run as a script. The loading
# script imports this module (`from model_save import *`) to get the class
# definition; without the guard that import would rebuild VGG16 and overwrite
# both files below with freshly initialised weights.
if __name__ == "__main__":
    vgg16 = torchvision.models.vgg16(weights=None)

    # Save option 1: pickle the whole model object (structure + parameters).
    torch.save(vgg16, "vgg16.pth")

    # Save option 2: store only the parameters (the state_dict). This is the
    # approach the original author notes as officially recommended.
    torch.save(vgg16.state_dict(), "vgg16_params.pth")


class tudui(nn.Module):
    """Small CNN: three (5x5 conv -> 2x2 max-pool) stages, then two linear layers.

    Each convolution uses padding=2 with a 5x5 kernel, so only the pooling
    layers shrink the spatial size (halving it three times). `linear1` expects
    1024 flattened features, which is what a 3x32x32 input produces
    (64 channels * 4 * 4). The network returns 10 values per sample.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order they are applied; attribute names
        # double as the state_dict key prefixes.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(in_features=1024, out_features=64)
        self.linear2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Apply every layer in sequence to `x` and return the final output."""
        pipeline = (
            self.conv1,
            self.maxpool1,
            self.conv2,
            self.maxpool2,
            self.conv3,
            self.maxpool3,
            self.flatten,
            self.linear1,
            self.linear2,
        )
        for layer in pipeline:
            x = layer(x)
        return x
    
# Only instantiate and save when this file is run as a script, so importing
# the module for its class definition (`from model_save import *`) does not
# overwrite the saved file with a freshly initialised model.
if __name__ == "__main__":
    tudui_model = tudui()
    # Save option 1: pickle the whole model object. Despite the "_params" file
    # name, this file holds the full model, not just a state_dict.
    torch.save(tudui_model, "tudui_params.pth")

2.2 读取

python 复制代码
import torch
import torchvision
from model_save import *
# Method 1: load a whole pickled model (kept as a reference snippet).
# A full-model file has to be loaded with weights_only=False, which unpickles
# arbitrary objects -- only do this with files you created or trust.
# vgg16 = torch.load("vgg16.pth", weights_only=False)
# print(vgg16)

# Method 2: rebuild the architecture, then load only the parameters.
vgg16 = torchvision.models.vgg16(weights=None)
# The file holds a state_dict, so the restricted loader is requested
# explicitly; this also avoids the FutureWarning about the implicit
# weights_only default.
vgg16.load_state_dict(torch.load("vgg16_params.pth", weights_only=True))

print(vgg16)
# The saved parameters on their own are a dictionary:
# vgg16_params = torch.load("vgg16_params.pth", weights_only=True)
# print(vgg16_params)



# Pitfall of method 1: the class definition used at save time must be
# importable here, which is why this script imports from model_save.
# weights_only=False is passed explicitly because the file contains a pickled
# model object rather than a state_dict; the restricted loader would reject it.
model = torch.load("tudui_params.pth", weights_only=False)
print(model)

2.3输出结果

python 复制代码
(base) PS E:\desktop\deeplearning> & D:\miniconda3\envs\pytorch_py312\python.exe e:/desktop/deeplearning/src/model_load.py
e:\desktop\deeplearning\src\model_load.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  vgg16.load_state_dict(torch.load("vgg16_params.pth"))
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
e:\desktop\deeplearning\src\model_load.py:19: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model=torch.load("tudui_params.pth")
tudui(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
相关推荐
happyprince7 分钟前
2026年04月07日热门模型
人工智能
IT_陈寒13 分钟前
Vue的这个响应式问题,坑了我整整两小时
前端·人工智能·后端
HIT_Weston14 分钟前
41、【Agent】【OpenCode】本地代理分析(五)
javascript·人工智能·opencode
万添裁35 分钟前
pytorch的张量数据结构以及各种操作函数的底层原理
人工智能·pytorch·python
盘古开天166644 分钟前
Gemma4本地部署,零成本打造私有 AI 助手
人工智能·本地部署·智能体·gemma4·ai私有助理
夜影风1 小时前
算力租赁产业链全景分析:解构AI时代的“算力电厂”
人工智能·算力租赁
MediaTea1 小时前
AI 术语通俗词典:矩阵乘法
人工智能·线性代数·矩阵
NHuan^_^1 小时前
SpringBoot3 整合 SpringAI 实现ai助手(记忆)
java·人工智能·spring boot
Binary_ey1 小时前
光刻技术第22期 | 贝叶斯压缩感知光源优化的优化技术及对比分析
人工智能·深度学习·机器学习