动手学人工智能-深度学习计算5-文件读写操作

在深度学习模型的训练过程中,我们往往会遇到需要保存模型参数或训练结果的情况。这样不仅有助于防止因系统崩溃而丢失重要进展,还能让我们在后续的应用中更方便地加载模型,进行预测或继续训练。今天,我们将深入探讨如何在深度学习中读写文件,特别是如何保存和加载张量及模型参数。

1. 读写张量数据

张量是深度学习中的基本数据结构,它是多维数组的泛化。我们可以通过深度学习框架提供的 saveload 函数来存储和加载张量。在这里,我们将以PyTorch为例,演示如何保存和加载张量数据。

1.1 保存张量

首先,我们创建一个简单的张量,并将其保存到文件中。PyTorch提供了 torch.save() 函数,能够将张量存储到指定的文件路径。

python 复制代码
import torch

# 创建一个张量
X = torch.arange(4)

print(X)  # tensor([0, 1, 2, 3])
# 保存张量到文件
torch.save(X, 'x-file')

这里,我们通过 torch.arange(4) 生成了一个包含四个元素的张量,并将其保存到名为 x-file 的文件中。

1.2 加载张量

接下来,我们可以将文件中的张量加载回内存。使用 torch.load() 函数可以轻松地读取保存的张量数据。

python 复制代码
# 加载存储的张量
X2 = torch.load('x-file')
print(X2)

输出结果将会是:

scss 复制代码
tensor([0, 1, 2, 3])

如你所见,加载的张量和原始张量相同。

1.3 保存多个张量

我们不仅可以保存一个张量,也可以保存多个张量。PyTorch允许我们将多个张量一起保存为列表或字典。以下是保存多个张量并加载它们的例子:

python 复制代码
y = torch.zeros(4)
# 保存多个张量到文件
torch.save([X, y], 'x-files')
print(X, y)
print('-' * 50)
x2, y2 = torch.load('x-files')
print(x2, y2)

输出结果是:

scss 复制代码
tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])
---------------------------------------------
tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])

1.4 保存张量字典

除了列表,我们还可以将多个张量存储在字典中,方便按名称访问。在深度学习中,常常需要保存模型的权重和偏置,这时字典格式会特别有用。

python 复制代码
# 创建一个包含张量的字典
mydict = {'x': X, 'y': y}
# 保存字典
torch.save(mydict, 'mydict')
# 加载字典
mydict2 = torch.load('mydict')
print(mydict2)

输出结果为:

scss 复制代码
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

小结:

  • torch.save()torch.load() 是用于保存和加载张量数据的基础函数。
  • 我们可以保存和加载单个张量、多个张量(如列表、字典)等结构。

2. 保存和加载模型参数

虽然保存单个张量非常有用,但在深度学习中,我们更常见的需求是保存和加载整个模型,特别是它的参数。模型的参数包括权重和偏置,它们通常是模型训练的核心内容。

2.1 保存模型参数

在PyTorch中,保存和加载模型的参数是通过 state_dict() 方法来完成的。state_dict() 包含了模型的所有可学习参数,如权重和偏置。我们可以将这些参数保存到文件中,以便后续加载。

首先,定义一个简单的多层感知机(MLP)模型,并将其训练后的参数保存到文件中。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F


# 定义多层感知机模型
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, X):
        return self.output(F.relu(self.hidden(X)))


# 创建模型实例
net = MLP()
# 创建一些随机输入
X = torch.randn(size=(2, 20))
# 前向传播计算
y = net(X)
print(y)
"""
tensor([[ 0.1941, -0.0554,  0.2344,  0.4940,  0.1903,  0.0163,  0.1400,  0.2266,
          0.1028, -0.3156],
        [ 0.0139,  0.1429,  0.2452,  0.2173, -0.0184,  0.1851,  0.2068, -0.0041,
         -0.1291, -0.1377]], grad_fn=<AddmmBackward0>)
"""

torch.save(net.state_dict(), 'mlp.params')

2.2 加载模型参数

当我们需要加载模型时,我们首先实例化一个新的模型对象,然后使用 load_state_dict() 加载保存的参数。注意,加载的只是参数,而不是整个模型的架构。因此,加载时需要确保模型的结构和保存时一致。

python 复制代码
# 创建一个新的模型实例
clone = MLP()
# 加载保存的模型参数
clone.load_state_dict(torch.load('mlp.params', weights_only=True))
# 切换模型为评估模式
clone.eval()

# 验证两个模型输出是否一致
Y_clone = clone(X)

print(Y_clone == Y)

输出结果:

scss 复制代码
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

可以看到,加载的模型和原始模型的输出是完全一致的。

小结

  • 通过 state_dict() 方法,我们可以保存和加载深度学习模型的参数。
  • 加载模型参数时,需要重新构建模型架构,确保参数和模型结构一致。

总结

通过这篇文章,我们学习了如何在深度学习中保存和加载张量及模型参数。具体来说:

  • 使用 torch.save()torch.load() 函数可以轻松地保存和加载张量。
  • 模型的参数保存和加载是通过 state_dict() 方法实现的,保存的只是模型的参数而非模型本身。
  • 在加载模型时,必须重新定义模型结构,并从文件中加载其对应的参数。

这些技巧不仅能帮助我们在训练过程中保存进度,还能在后续的部署或复现中充分利用已有的模型。

相关推荐
漂亮_大男孩41 分钟前
深度学习|表示学习|卷积神经网络|局部链接是什么?|06
深度学习·学习·cnn
lly_csdn1232 小时前
【Image Captioning】DynRefer
python·深度学习·ai·图像分类·多模态·字幕生成·属性识别
TURING.DT3 小时前
模型部署:TF Serving 的使用
深度学习·tensorflow
励志去大厂的菜鸟5 小时前
系统相关类——java.lang.Math (三)(案例详细拆解小白友好)
java·服务器·开发语言·深度学习·学习方法
liuhui2445 小时前
Pytorch深度学习指南 卷I --编程基础(A Beginner‘s Guide) 第1章 一个简单的回归
pytorch·深度学习·回归
睡不着还睡不醒6 小时前
【深度学习】神经网络实战分类与回归任务
深度学习·神经网络·分类
编码浪子6 小时前
Transformer的编码机制
人工智能·深度学习·transformer
IE066 小时前
深度学习系列76:流式tts的一个简单实现
人工智能·深度学习
m0_7431064611 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_7431064611 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学