动手学人工智能-深度学习计算5-文件读写操作

在深度学习模型的训练过程中,我们往往会遇到需要保存模型参数或训练结果的情况。这样不仅有助于防止因系统崩溃而丢失重要进展,还能让我们在后续的应用中更方便地加载模型,进行预测或继续训练。今天,我们将深入探讨如何在深度学习中读写文件,特别是如何保存和加载张量及模型参数。

1. 读写张量数据

张量是深度学习中的基本数据结构,它是多维数组的泛化。我们可以通过深度学习框架提供的 saveload 函数来存储和加载张量。在这里,我们将以PyTorch为例,演示如何保存和加载张量数据。

1.1 保存张量

首先,我们创建一个简单的张量,并将其保存到文件中。PyTorch提供了 torch.save() 函数,能够将张量存储到指定的文件路径。

python 复制代码
import torch

# 创建一个张量
X = torch.arange(4)

print(X)  # tensor([0, 1, 2, 3])
# 保存张量到文件
torch.save(X, 'x-file')

这里,我们通过 torch.arange(4) 生成了一个包含四个元素的张量,并将其保存到名为 x-file 的文件中。

1.2 加载张量

接下来,我们可以将文件中的张量加载回内存。使用 torch.load() 函数可以轻松地读取保存的张量数据。

python 复制代码
# 加载存储的张量
X2 = torch.load('x-file')
print(X2)

输出结果将会是:

scss 复制代码
tensor([0, 1, 2, 3])

如你所见,加载的张量和原始张量相同。

1.3 保存多个张量

我们不仅可以保存一个张量,也可以保存多个张量。PyTorch允许我们将多个张量一起保存为列表或字典。以下是保存多个张量并加载它们的例子:

python 复制代码
y = torch.zeros(4)
# 保存多个张量到文件
torch.save([X, y], 'x-files')
print(X, y)
print('-' * 50)
x2, y2 = torch.load('x-files')
print(x2, y2)

输出结果是:

scss 复制代码
tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])
---------------------------------------------
tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])

1.4 保存张量字典

除了列表,我们还可以将多个张量存储在字典中,方便按名称访问。在深度学习中,常常需要保存模型的权重和偏置,这时字典格式会特别有用。

python 复制代码
# 创建一个包含张量的字典
mydict = {'x': X, 'y': y}
# 保存字典
torch.save(mydict, 'mydict')
# 加载字典
mydict2 = torch.load('mydict')
print(mydict2)

输出结果为:

scss 复制代码
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

小结:

  • torch.save()torch.load() 是用于保存和加载张量数据的基础函数。
  • 我们可以保存和加载单个张量、多个张量(如列表、字典)等结构。

2. 保存和加载模型参数

虽然保存单个张量非常有用,但在深度学习中,我们更常见的需求是保存和加载整个模型,特别是它的参数。模型的参数包括权重和偏置,它们通常是模型训练的核心内容。

2.1 保存模型参数

在PyTorch中,保存和加载模型的参数是通过 state_dict() 方法来完成的。state_dict() 包含了模型的所有可学习参数,如权重和偏置。我们可以将这些参数保存到文件中,以便后续加载。

首先,定义一个简单的多层感知机(MLP)模型,并将其训练后的参数保存到文件中。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F


# 定义多层感知机模型
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, X):
        return self.output(F.relu(self.hidden(X)))


# 创建模型实例
net = MLP()
# 创建一些随机输入
X = torch.randn(size=(2, 20))
# 前向传播计算
y = net(X)
print(y)
"""
tensor([[ 0.1941, -0.0554,  0.2344,  0.4940,  0.1903,  0.0163,  0.1400,  0.2266,
          0.1028, -0.3156],
        [ 0.0139,  0.1429,  0.2452,  0.2173, -0.0184,  0.1851,  0.2068, -0.0041,
         -0.1291, -0.1377]], grad_fn=<AddmmBackward0>)
"""

torch.save(net.state_dict(), 'mlp.params')

2.2 加载模型参数

当我们需要加载模型时,我们首先实例化一个新的模型对象,然后使用 load_state_dict() 加载保存的参数。注意,加载的只是参数,而不是整个模型的架构。因此,加载时需要确保模型的结构和保存时一致。

python 复制代码
# 创建一个新的模型实例
clone = MLP()
# 加载保存的模型参数
clone.load_state_dict(torch.load('mlp.params', weights_only=True))
# 切换模型为评估模式
clone.eval()

# 验证两个模型输出是否一致
Y_clone = clone(X)

print(Y_clone == Y)

输出结果:

scss 复制代码
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

可以看到,加载的模型和原始模型的输出是完全一致的。

小结

  • 通过 state_dict() 方法,我们可以保存和加载深度学习模型的参数。
  • 加载模型参数时,需要重新构建模型架构,确保参数和模型结构一致。

总结

通过这篇文章,我们学习了如何在深度学习中保存和加载张量及模型参数。具体来说:

  • 使用 torch.save()torch.load() 函数可以轻松地保存和加载张量。
  • 模型的参数保存和加载是通过 state_dict() 方法实现的,保存的只是模型的参数而非模型本身。
  • 在加载模型时,必须重新定义模型结构,并从文件中加载其对应的参数。

这些技巧不仅能帮助我们在训练过程中保存进度,还能在后续的部署或复现中充分利用已有的模型。

相关推荐
paixiaoxin2 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
weixin_515202492 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习
吕小明么3 小时前
OpenAI o3 “震撼” 发布后回归技术本身的审视与进一步思考
人工智能·深度学习·算法·aigc·agi
CSBLOG4 小时前
深度学习试题及答案解析(一)
人工智能·深度学习
小陈phd5 小时前
深度学习之超分辨率算法——SRCNN
python·深度学习·tensorflow·卷积
威化饼的一隅7 小时前
【多模态】swift-3框架使用
人工智能·深度学习·大模型·swift·多模态
机器学习之心7 小时前
BiTCN-BiGRU基于双向时间卷积网络结合双向门控循环单元的数据多特征分类预测(多输入单输出)
深度学习·分类·gru
MorleyOlsen8 小时前
【Trick】解决服务器cuda报错——RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
运维·服务器·深度学习
愚者大大10 小时前
1. 深度学习介绍
人工智能·深度学习