深度学习速成:模型的使用与修改,保存与读取

1.使用与修改

以VGG16为例

python 复制代码
import torchvision
import torch

# Only VGG16's ImageNet-pretrained weights are used here; the ImageNet dataset
# itself is not needed for this demo.

# Untrained VGG16 (random initialization). `weights=None` is the current API;
# the old `pretrained=False` flag is deprecated.
vgg16_false = torchvision.models.vgg16(weights=None)
# VGG16 with ImageNet-1k pretrained weights (fetched by torchvision if not cached).
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
print(vgg16_true)

# Option A: append a new fully connected layer to the end of `classifier`,
# mapping the original 1000 outputs to 10 classes.
# Note: `vgg16_true.add_module(...)` would instead register the layer on the
# top-level model, where VGG's forward() never calls it.
vgg16_true.classifier.add_module("new_fc", torch.nn.Linear(1000, 10))
print(vgg16_true)

print(vgg16_false)

# Option B: replace classifier[6] (Linear 4096 -> 1000) with Linear 4096 -> 10.
vgg16_false.classifier[6] = torch.nn.Linear(4096, 10)
print(vgg16_false)

修改完的输出

2.保存与读取

2.1保存

python 复制代码
import torch
import torchvision
import torch.nn as nn
if __name__ == "__main__":
    # Guarded so that `from model_save import *` (used by the loading script to
    # obtain class definitions) does not rebuild VGG16 and rewrite checkpoints
    # on every import.
    vgg16 = torchvision.models.vgg16(weights=None)

    # Method 1: save the whole model (architecture + parameters) via pickle.
    torch.save(vgg16, "vgg16.pth")

    # Method 2: save only the parameters (state_dict). This is the approach
    # recommended by PyTorch: the file holds tensors only, so loading it does
    # not depend on unpickling the model class.
    torch.save(vgg16.state_dict(), "vgg16_params.pth")


class tudui(nn.Module):
    """Small CNN: three conv + max-pool stages followed by two linear layers.

    ``linear1`` expects 1024 = 64 * 4 * 4 features. The 5x5 convolutions use
    padding=2 and so keep H and W unchanged, and each MaxPool2d(2) halves them.
    A 32x32 input therefore reaches the flatten step as 64 x 4 x 4.
    No activation functions are applied between layers.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels, then halve H and W.
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = nn.MaxPool2d(2)
        # Stage 2: 32 -> 32 channels, then halve again.
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = nn.MaxPool2d(2)
        # Stage 3: 32 -> 64 channels, then halve again.
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = nn.MaxPool2d(2)
        # Head: flatten to (N, features), then 1024 -> 64 -> 10.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(1024, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        """Apply conv1 ... linear2 in sequence and return the 10-wide output."""
        stages = (
            self.conv1, self.maxpool1,
            self.conv2, self.maxpool2,
            self.conv3, self.maxpool3,
            self.flatten,
            self.linear1, self.linear2,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out
    
if __name__ == "__main__":
    # Guarded so that `from model_save import *` only brings in the `tudui`
    # class and does not re-save the model on every import.
    tudui_model = tudui()
    # Method 1: pickle the whole model. Despite the "_params" file name this
    # stores the full module, so loading it requires the `tudui` class to be
    # importable in the loading script.
    torch.save(tudui_model, "tudui_params.pth")

2.2 读取

python 复制代码
import torch
import torchvision
from model_save import *
# Method 1: load a whole pickled model (requires the original class to be importable).
# vgg16 = torch.load("vgg16.pth", weights_only=False)
# print(vgg16)

# Method 2: rebuild the architecture, then load the saved parameters into it.
# A state_dict contains only tensors, so weights_only=True is sufficient; it
# avoids arbitrary-code unpickling and silences the `weights_only` FutureWarning.
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_params.pth", weights_only=True))

print(vgg16)
# To inspect the raw parameter dict instead:
# vgg16_params = torch.load("vgg16_params.pth", weights_only=True)
# print(vgg16_params)

# Method 1 pitfall: unpickling a whole model needs access to the class that was
# used at save time, which is why `tudui` is star-imported from model_save.
# A full-model pickle needs weights_only=False. Since PyTorch 2.6 the default is
# True, and this load would otherwise fail. Only load such files from trusted
# sources.
model = torch.load("tudui_params.pth", weights_only=False)
print(model)

2.3输出结果

python 复制代码
(base) PS E:\desktop\deeplearning> & D:\miniconda3\envs\pytorch_py312\python.exe e:/desktop/deeplearning/src/model_load.py
e:\desktop\deeplearning\src\model_load.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  vgg16.load_state_dict(torch.load("vgg16_params.pth"))
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
e:\desktop\deeplearning\src\model_load.py:19: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model=torch.load("tudui_params.pth")
tudui(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
相关推荐
IT_陈寒12 分钟前
Redis持久化这个坑,我爬了一整天才出来
前端·人工智能·后端
kimi-22217 分钟前
LangChain 里的 chatmodel.bind_tools 和 ReAct Agent
人工智能
zhangfeng113334 分钟前
计算机视觉vc 3D 希尔伯特曲线 基础介绍,人工智能
人工智能·计算机视觉·3d
没事别瞎琢磨38 分钟前
十一、审计与 Run Session——每一步操作都被记录
人工智能·node.js
没事别瞎琢磨38 分钟前
十六、AgentSandbox——把所有模块串起来的编排类
人工智能·node.js
George37541 分钟前
当 Loop Engineering 成为行业共识,我发现自己的开源项目已经实践了 3 个月
人工智能
没事别瞎琢磨43 分钟前
十二、网络代理与白名单规则引擎
人工智能·node.js
马士兵教育1 小时前
Java还有前景吗?Java+AI大模型学习路线及项目?
java·人工智能·python·学习·机器学习
没事别瞎琢磨1 小时前
十四、Git Worktree 隔离执行
人工智能·node.js
安全指北针1 小时前
大模型时代,谁在领跑中国AI安全赛道?中国AI安全产品市场分析
人工智能