在深度学习模型的训练过程中,我们往往会遇到需要保存模型参数或训练结果的情况。这样不仅有助于防止因系统崩溃而丢失重要进展,还能让我们在后续的应用中更方便地加载模型,进行预测或继续训练。今天,我们将深入探讨如何在深度学习中读写文件,特别是如何保存和加载张量及模型参数。
1. 读写张量数据
张量是深度学习中的基本数据结构,它是多维数组的泛化。我们可以通过深度学习框架提供的 save
和 load
函数来存储和加载张量。在这里,我们将以PyTorch为例,演示如何保存和加载张量数据。
1.1 保存张量
首先,我们创建一个简单的张量,并将其保存到文件中。PyTorch提供了 torch.save()
函数,能够将张量存储到指定的文件路径。
python
import torch
# 创建一个张量
X = torch.arange(4)
print(X) # tensor([0, 1, 2, 3])
# 保存张量到文件
torch.save(X, 'x-file')
这里,我们通过 torch.arange(4)
生成了一个包含四个元素的张量,并将其保存到名为 x-file 的文件中。
1.2 加载张量
接下来,我们可以将文件中的张量加载回内存。使用 torch.load()
函数可以轻松地读取保存的张量数据。
python
# 加载存储的张量
X2 = torch.load('x-file')
print(X2)
输出结果将会是:
scss
tensor([0, 1, 2, 3])
如你所见,加载的张量和原始张量相同。
1.3 保存多个张量
我们不仅可以保存一个张量,也可以保存多个张量。PyTorch允许我们将多个张量一起保存为列表或字典。以下是保存多个张量并加载它们的例子:
python
y = torch.zeros(4)
# 保存多个张量到文件
torch.save([X, y], 'x-files')
print(X, y)
print('-' * 50)
x2, y2 = torch.load('x-files')
print(x2, y2)
输出结果是:
scss
tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])
---------------------------------------------
tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])
1.4 保存张量字典
除了列表,我们还可以将多个张量存储在字典中,方便按名称访问。在深度学习中,常常需要保存模型的权重和偏置,这时字典格式会特别有用。
python
# 创建一个包含张量的字典
mydict = {'x': X, 'y': y}
# 保存字典
torch.save(mydict, 'mydict')
# 加载字典
mydict2 = torch.load('mydict')
print(mydict2)
输出结果为:
scss
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
小结:
torch.save()
和torch.load()
是用于保存和加载张量数据的基础函数。- 我们可以保存和加载单个张量、多个张量(如列表、字典)等结构。
2. 保存和加载模型参数
虽然保存单个张量非常有用,但在深度学习中,我们更常见的需求是保存和加载整个模型,特别是它的参数。模型的参数包括权重和偏置,它们通常是模型训练的核心内容。
2.1 保存模型参数
在PyTorch中,保存和加载模型的参数是通过 state_dict()
方法来完成的。state_dict()
包含了模型的所有可学习参数,如权重和偏置。我们可以将这些参数保存到文件中,以便后续加载。
首先,定义一个简单的多层感知机(MLP)模型,并将其训练后的参数保存到文件中。
python
import torch
from torch import nn
from torch.nn import functional as F
# 定义多层感知机模型
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, X):
return self.output(F.relu(self.hidden(X)))
# 创建模型实例
net = MLP()
# 创建一些随机输入
X = torch.randn(size=(2, 20))
# 前向传播计算
y = net(X)
print(y)
"""
tensor([[ 0.1941, -0.0554, 0.2344, 0.4940, 0.1903, 0.0163, 0.1400, 0.2266,
0.1028, -0.3156],
[ 0.0139, 0.1429, 0.2452, 0.2173, -0.0184, 0.1851, 0.2068, -0.0041,
-0.1291, -0.1377]], grad_fn=<AddmmBackward0>)
"""
torch.save(net.state_dict(), 'mlp.params')
2.2 加载模型参数
当我们需要加载模型时,我们首先实例化一个新的模型对象,然后使用 load_state_dict()
加载保存的参数。注意,加载的只是参数,而不是整个模型的架构。因此,加载时需要确保模型的结构和保存时一致。
python
# 创建一个新的模型实例
clone = MLP()
# 加载保存的模型参数
clone.load_state_dict(torch.load('mlp.params', weights_only=True))
# 切换模型为评估模式
clone.eval()
# 验证两个模型输出是否一致
Y_clone = clone(X)
print(Y_clone == Y)
输出结果:
scss
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
可以看到,加载的模型和原始模型的输出是完全一致的。
小结
- 通过
state_dict()
方法,我们可以保存和加载深度学习模型的参数。 - 加载模型参数时,需要重新构建模型架构,确保参数和模型结构一致。
总结
通过这篇文章,我们学习了如何在深度学习中保存和加载张量及模型参数。具体来说:
- 使用
torch.save()
和torch.load()
函数可以轻松地保存和加载张量。 - 模型的参数保存和加载是通过
state_dict()
方法实现的,保存的只是模型的参数而非模型本身。 - 在加载模型时,必须重新定义模型结构,并从文件中加载其对应的参数。
这些技巧不仅能帮助我们在训练过程中保存进度,还能在后续的部署或复现中充分利用已有的模型。