【动手学深度学习】读写文件

【动手学深度学习】读写文件

加载和保存张量

对于单个张量 我么可以直接调用load和save函数分别读写,这两个函数要求我们提供一个名称,save要求保存的变量作为输入

py 复制代码
import torch
from torch import nn
from torch.nn import functional as F

# 创建一个长度为4的张量
x = torch.arange(4)
torch.save(x, 'x-file')


x2 = torch.load('x-file')
print(x2)

存储一个张量列表,然后把他们写入内存

py 复制代码
y = torch.zeros(4)
torch.save([x,y],'x-files')
x2,y2 = torch.load('x-files')
(x2,y2)

我们甚至可以写入或者读取从字符串映射到张量的字典,方面读取权重

py 复制代码
# 创建张量字典  保存张量
mydict = {'x':x,'y':y}
torch.save(mydict,'mydict')

mydict2 = torch.load('mydict')
mydict2

加载和保存模型参数

深度学习框架提供内置函数来保存和加载整个网络,这里是保存模型的参数而不是保存整个模型

py 复制代码
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20,256)
        self.output = nn.Linear(256,10)

    def forward(self,X):
        return self.output(F.relu(self.hidden(X)))
    
net = MLP()
X = torch.randn(size= (2,20))
Y = net(X)

取出模型的参数保存在一个mlp.params文件中

py 复制代码
# 取出模型的参数保存在一个mlp.params文件中
torch.save(net.state_dict(),'mlp.params')

恢复模型,实例化原始多层感知机模型的一个备份,我们不需要随机初始化模型参数,而是直接读取文件中存储的参数

py 复制代码
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()

比较两个对象的模型参数,那么输入相同的X 计算的输出应该相同

py 复制代码
Y_clone = clone(X)
Y_clone == Y
相关推荐
辰阳星宇几秒前
【Agent】rStar2-Agent: Agentic Reasoning Technical Report
人工智能·算法·自然语言处理
再__努力1点1 分钟前
【50】OpenCV背景减法技术解析与实现
开发语言·图像处理·人工智能·python·opencv·算法·计算机视觉
serve the people1 分钟前
tensorflow Keras 模型的保存与加载
人工智能·tensorflow·keras
c骑着乌龟追兔子2 分钟前
Day 29 机器学习管道 pipeline
人工智能·机器学习
努力也学不会java4 分钟前
【docker】Docker Image(镜像)
java·运维·人工智能·机器学习·docker·容器
zhangfeng11334 分钟前
suppr.wilddata.cn 文献检索,用中文搜 PubMed 一种基于大语言模型的智能搜索引擎构建方法
人工智能·搜索引擎·语言模型
大千AI助手5 分钟前
高维空间中的高效导航者:球树(Ball Tree)算法深度解析
人工智能·算法·机器学习·数据挖掘·大千ai助手·球树·ball-tree
新知图书6 分钟前
使用FastGPT知识库构建智能客服的示例
人工智能·ai agent·智能体·大模型应用开发·大模型应用
生信大表哥8 分钟前
GPT-5-Codex VS Gemini 3 VS Claude Sonnet 4.5 新手小白入门学习教程
人工智能·gpt·学习·rstudio·数信院生信服务器
子午16 分钟前
【植物识别系统】Python+TensorFlow+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习