深度学习速成:模型的使用与修改,保存与读取

1.使用与修改

VGG16为例

python 复制代码
import torchvision
import torch

#trian_data=torchvision.datasets.imagenet("../data_imgnet",train=True,transform=torchvision.transforms.ToTensor(),download=True)

vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
print(vgg16_true)



#vgg16_true.add_module("new_fc",torch.nn.Linear(1000,10))#追加全結合層
vgg16_true.classifier.add_module("new_fc",torch.nn.Linear(1000,10))#追加全結合層
print(vgg16_true)

print(vgg16_false)

vgg16_false.classifier[6]=torch.nn.Linear(4096,10)#在(6)那里修改全连接层
print(vgg16_false)

修改完的输出

2.保存与读取

2.1保存

python 复制代码
import torch
import torchvision
import torch.nn as nn
vgg16=torchvision.models.vgg16(weights=None)
#保存1 保存整个模型(结构+参数)
torch.save(vgg16,"vgg16.pth")


#保存2 只保存模型参数(官方推荐,内存小)
torch.save(vgg16.state_dict(),"vgg16_params.pth")


class tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(3,32,5,padding=2)# 
        self.maxpool1=nn.MaxPool2d(2)#
        self.conv2=nn.Conv2d(32,32,5,padding=2)
        self.maxpool2=nn.MaxPool2d(2)
        self.conv3=nn.Conv2d(32,64,5,padding=2)
        self.maxpool3=nn.MaxPool2d(2)
        self.flatten=nn.Flatten()#
        self.linear1=nn.Linear(1024,64)#
        self.linear2=nn.Linear(64,10)#

   
    def forward(self,x):
        x=self.conv1(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.maxpool2(x)
        x=self.conv3(x)
        x=self.maxpool3(x)
        x=self.flatten(x)
        x=self.linear1(x)
        x=self.linear2(x)
         
        return x
    
tudui_model=tudui()
torch.save(tudui_model,"tudui_params.pth")#方法一保存

2.2 读取

python 复制代码
import torch
import torchvision
from model_save import *
#方式一 加载模型
""" vgg16=torch.load("vgg16.pth")
print(vgg16)
 """
#方式二 加载模型参数
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_params.pth"))

print(vgg16)
""" vgg16_params=torch.load("vgg16_params.pth")#字典形式
print(vgg16_params)    """



#方式一有陷阱  需要能访问到save时的类定义 im
model=torch.load("tudui_params.pth")
print(model)

2.3输出结果

python 复制代码
(base) PS E:\desktop\deeplearning> & D:\miniconda3\envs\pytorch_py312\python.exe e:/desktop/deeplearning/src/model_load.py
e:\desktop\deeplearning\src\model_load.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  vgg16.load_state_dict(torch.load("vgg16_params.pth"))
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
e:\desktop\deeplearning\src\model_load.py:19: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model=torch.load("tudui_params.pth")
tudui(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
相关推荐
向哆哆11 小时前
PCB电路板缺陷检测数据集(近千张图片已划分、已标注)适用于YOLO系列深度学习检测任务
人工智能·深度学习·yolo
F_D_Z11 小时前
Stacked Generalization 堆叠泛化
人工智能·机器学习·集成学习·stacking
焦耳热科技前沿11 小时前
复旦大学Nat. Commun.:等离子体辅助碳热闪烧合成突破Hume-Rothery极限的亚5纳米高熵合金
人工智能·科技·自动化·能源·材料工程
未来之窗软件服务11 小时前
vosk-ASR freeswitch调用[AI人工智能(五十六)]—东方仙盟
人工智能·仙盟创梦ide·东方仙盟
踏浪无痕11 小时前
聊聊最近很火的"小龙虾":AI 应用的下一步是什么?
人工智能
chaofan98011 小时前
2026 轻量模型三国杀:Flash-Lite vs GPT-4.1 Nano vs Haiku,技术选型到底该站谁?
前端·人工智能·microsoft
BB学长11 小时前
LBM vs FVM:谁才是 CFD 的未来?
人工智能·算法·机器学习
AIDF202611 小时前
AI 芯片推理适配踩坑记:从 GPU 到国产算力的迁移思路
人工智能
Zzj_tju12 小时前
AI+医疗实战:影像+文本报告怎么结合?从单模态分类到多模态医疗 AI 系统设计
人工智能·分类·数据挖掘
智能交通技术12 小时前
iTSTech:自动驾驶、无人机与机器人在物流中的协同应用场景分析 2026
人工智能·机器学习·机器人·自动驾驶·无人机