【深度学习】(8)--神经网络使用最优模型

文章目录

使用最优模型

直接使用最优模型在多个方面都具有显著的好处,尤其是在深度学习和机器学习领域。以下是一些主要的好处:

  1. 节省时间和资源
    • 训练时间:训练一个深度学习模型可能需要数小时、数天甚至数周的时间,特别是当数据集很大或模型很复杂时。直接使用最优模型可以立即开始使用,无需等待长时间的训练过程。
    • 计算资源:训练模型需要大量的计算资源,包括高性能的GPU。直接使用最优模型可以显著减少对这些资源的需求。
  2. 提高性能
    • 更好的泛化能力:最优模型通常是在大型、多样化的数据集上训练的,因此它们能够更好地泛化到新的、未见过的数据上。
    • 调优:许多最优模型已经过仔细的调优,包括超参数调整和架构搜索,以确保它们在特定任务上表现最佳。
  3. 易于实现
    • 快速原型开发:对于研究人员或开发人员来说,使用最优模型可以快速实现原型,以测试想法或验证假设。
    • 减少复杂性:直接使用最优模型可以减少实现和调试新模型的复杂性,特别是在模型架构和数据预处理方面。

直接使用最优模型的两种方法

一、 定义模型结构

首先,你需要有定义模型结构的代码。这通常是一个继承自torch.nn.Module的类。如果你没有保存整个模型实例,而是只保存了模型的状态字典(即模型的参数和缓冲区),那么你需要重新定义模型结构。

注意:定义模型的将结果,务必要与使用的最优模型结构相同,否则参数不匹配。

python 复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential( # 将多个层组合在一起
            nn.Conv2d(         # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据
                in_channels=3, # 图像通道个数,1表示灰度图(确定卷积核 组中的个数)
                out_channels=16, # 要得到多少特征图,卷积核的个数
                kernel_size=5,  # 卷积核大小
                stride=1,   # 步长
                padding=2   # 边界填充大小
            ),
            nn.ReLU(), # relu层,不会改变特征图的大小
            nn.MaxPool2d(kernel_size=2) # 进行池化操作(2*2区域)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.Conv2d(32,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32,128,5,1,2), # 输出(64,7,7)
            nn.ReLU()
        )
        self.out = nn.Linear(128*54*54,4)

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x) # 输出(64,7,7)
        x = x.view(x.size(0),-1) # flatten 操作,结果为:(batch_size,64*7*7)
        output = self.out(x)
        output = torch.sigmoid(output)
        return output

二、 加载模型

加载模型有两种方法:1.加载模型状态字典;2.加载整个模型实例

如果你保存的是整个模型实例(使用torch.save(model, 'model_path.pt') ),你可以直接加载它。但是,更常见的是只保存和加载模型的状态字典(使用torch.save(model.state_dict(), 'model_state_dict.pt'))。

1. 加载模型状态字典
python 复制代码
# 1. 读取参数的方法:
model = CNN().to(device)
model.load_state_dict(torch.load("best.pth"))
model.eval()  # 设置为评估模式,固定模型参数和数据,防止后面被修改
2. 加载完整模型

虽然不需要与保存时完全相同的网络结构代码,但你的环境中必须存在与保存模型时相同的模型类定义。这是因为加载模型时,PyTorch需要知道如何将加载的数据(即模型的参数和结构)映射回Python中的类实例。

因此,该方法也需要提前定义模型。

python 复制代码
# 2. 读取完整模型的方法,需要提前创建model:
model = CNN().to(device)
model = torch.load("best1.pt")
model.eval()    #固定模型参数和数据,防止后面被修改

三、设置设备

python 复制代码
"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

四、使用模型预测

python 复制代码
result = []#保存的预测的结果
lables = []#真实结果
def test_true(dataloader, model):
    with torch.no_grad():   #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)#预测之后的结果。
            result.append(pred.argmax(1).item())
            lables.append(y.item())

test_data = food_dataset(file_path='testda.txt',transform=data_transforms['valid'])
test_dataloader = DataLoader(test_data,batch_size=1,shuffle=True)

test_true(test_dataloader,model)
print(f"预测值 \t:{result}")
print(f"真实值 \t:{lables}")
  • 梯度计算:在评估模式下,使用**torch.no_grad()**上下文管理器可以减少内存消耗并加速计算,因为不需要存储用于反向传播的梯度。

总结

本篇介绍了:

  1. 如何使用保存好的最优模型
  2. 加载模型的两种方法:1.加载模型状态字典;2.加载整个模型实例
  3. 注意:两种方法都需要提前定义神经网络结构。
相关推荐
Leweslyh1 小时前
物理信息神经网络(PINN)八课时教案
人工智能·深度学习·神经网络·物理信息神经网络
love you joyfully2 小时前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
该醒醒了~2 小时前
PaddlePaddle推理模型利用Paddle2ONNX转换成onnx模型
人工智能·paddlepaddle
小树苗1932 小时前
DePIN潜力项目Spheron解读:激活闲置硬件,赋能Web3与AI
人工智能·web3
凡人的AI工具箱2 小时前
每天40分玩转Django:Django测试
数据库·人工智能·后端·python·django·sqlite
大多_C2 小时前
BERT outputs
人工智能·深度学习·bert
Debroon3 小时前
乳腺癌多模态诊断解释框架:CNN + 可解释 AI 可视化
人工智能·神经网络·cnn
反方向的钟儿3 小时前
非结构化数据分析与应用(Unstructured data analysis and applications)(pt3)图像数据分析1
人工智能·计算机视觉·数据分析
Heartsuit3 小时前
LLM大语言模型私有化部署-使用Dify的工作流编排打造专属AI搜索引擎
人工智能·dify·ollama·qwen2.5·ai搜索引擎·tavily search·工作流编排
剑盾云安全专家3 小时前
AI加持,如何让PPT像开挂一键生成?
人工智能·aigc·powerpoint·软件