深入解析torch.load中的【map_location】参数

深入解析torch.load中的map_location参数


🌵文章目录🌵


🌳引言🌳

在PyTorch中,torch.load()函数是用于加载保存模型或张量数据的重要工具。当我们训练好一个深度学习模型后,通常需要将模型的参数(或称为状态字典,state_dict)保存下来,以便后续进行模型评估、继续训练或部署到其他环境中。在加载这些保存的数据时,map_location参数为我们提供了极大的灵活性,以决定这些数据应该被加载到哪个设备上 。本文将详细解析map_location参数的功能和使用方法,并通过实战案例来展示其在不同场景下的应用。

🌳map_location参数详解🌳

map_location参数在torch.load()函数中扮演着至关重要的角色。它决定了从保存的文件中加载数据时应将它们映射到哪个设备上 。在PyTorch中,设备可以是CPU或GPU,而GPU可以有多个,每个都有其独立的索引。map_location的灵活使用能够让我们轻松地在不同设备之间迁移模型,从而充分利用不同设备的计算优势。

map_location参数的数据类型

map_location参数的数据类型可以是:

参数类型 描述 示例
字符串(str) 预定义的设备字符串,指定目标设备。 1. 'cpu':加载到CPU上; 2. 'cuda:X':加载到索引为X的GPU上。
torch.device对象 一个表示目标设备的torch.device对象。 1.torch.device('cpu'):加载到CPU上; 2. torch.device('cuda:1'):加载到索引为1的GPU上。
可调用对象(callable) 一个接收存储路径并返回新位置的函数。 lambda storage, loc: storage.cuda(1):将每个存储对象移动到索引为1的GPU上。
字典(dict) 一个将存储路径映射到新位置的字典。 {'cuda:1':'cuda:0'}:将原本在GPU 1上的张量加载到GPU 0上。

map_location参数的使用场景

  1. CPU加载 :当你想在CPU上加载模型时,可以设置map_location='cpu'。这适用于那些不需要GPU加速的推理任务,或者在没有GPU的环境中部署模型。

  2. 指定GPU加载 :如果你有多个GPU,并且想将模型加载到特定的GPU上,可以使用'cuda:X'格式的字符串,其中X是GPU的索引。这在多GPU环境中非常有用,可以确保模型加载到指定的设备上。

  3. 自动选择GPU :如果你只想在GPU上加载模型,但不关心具体是哪一个GPU,可以设置map_location=torch.device('cuda')。这会自动选择第一个可用的GPU来加载模型。

  4. 保持原始设备 :如果你想保持模型在加载时的原始设备(即如果模型原先是在GPU上训练的,就仍然在GPU上加载;如果是在CPU上,就在CPU上加载),可以使用map_location=Nonemap_location=torch.device('cpu')(对于CPU模型)和map_location=torch.device('cuda')(对于GPU模型)。

  5. 自定义映射逻辑:通过传递一个可调用对象,你可以实现更复杂的映射逻辑。例如,你可以编写一个函数,根据存储路径或模型结构来决定将模型加载到哪个设备上。这在需要根据特定条件动态选择加载设备时非常有用。

🌳代码实战(详细注释)🌳

下面将通过几个实战案例来展示map_location参数在不同场景下的应用。

案例1:从文件加载张量到CPU

python 复制代码
# 案例1:从文件加载张量到CPU
# 使用torch.load()函数加载tensors.pt文件中的所有张量到CPU上
tensors = torch.load('tensors.pt')

案例2:指定设备加载张量

python 复制代码
# 案例2:指定设备加载张量
# 使用torch.load()函数并指定map_location参数为CPU设备,加载tensors.pt文件中的所有张量到CPU上
tensors_on_cpu = torch.load('tensors.pt', map_location=torch.device('cpu'))

案例3:使用匿名函数指定加载位置

python 复制代码
# 案例3:使用函数指定加载位置
# 使用torch.load()函数和map_location参数为一个lambda函数,该函数不做任何改变,保持张量原始位置(通常是CPU)
tensors_original_location = torch.load('tensors.pt', map_location=lambda storage, loc: storage)

案例4:将张量加载到指定GPU

python 复制代码
# 案例4:将张量加载到指定GPU
# 使用torch.load()函数和map_location参数为一个lambda函数,该函数将张量移动到索引为1的GPU上
tensors_on_gpu1 = torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))

案例5:张量从一个GPU映射到另一个GPU

python 复制代码
# 案例5:张量从一个GPU映射到另一个GPU
# 使用torch.load()函数和map_location参数为一个字典,将原本在GPU 1上的张量映射到GPU 0上
tensors_mapped = torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

案例6:从io.BytesIO对象加载张量

python 复制代码
# 案例6:从io.BytesIO对象加载张量
# 打开tensor.pt文件并读取内容到BytesIO缓冲区
with open('tensor.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())
    
# 使用torch.load()函数从BytesIO缓冲区加载张量
tensors_from_buffer = torch.load(buffer)

案例7:使用ASCII编码加载模块

python 复制代码
# 案例7:使用ASCII编码加载模块
# 使用torch.load()函数和encoding参数为'ascii',加载module.pt文件中的模块(如神经网络模型)
model = torch.load('module.pt', encoding='ascii')

这些案例代码和注释展示了如何使用torch.load()函数的不同map_location参数和编码设置来加载张量和模型。这些设置对于控制数据加载的位置和格式非常重要,特别是在跨设备或跨平台加载数据时。


🌳参考文档🌳

[1] PyTorch官方文档


🌳结尾🌳

亲爱的读者,首先感谢抽出宝贵的时间来阅读我们的博客。我们真诚地欢迎您留下评论和意见💬。

俗话说,当局者迷,旁观者清。的客观视角对于我们发现博文的不足、提升内容质量起着不可替代的作用。

如果博文给您带来了些许帮助,那么,希望能为我们点个免费的赞👍👍/收藏👇👇,您的支持和鼓励👏👏是我们持续创作✍️✍️的动力。

我们会持续努力创作✍️✍️,并不断优化博文质量👨‍💻👨‍💻,只为给带来更佳的阅读体验。

如果有任何疑问或建议,请随时在评论区留言,我们将竭诚为你解答~

愿我们共同成长🌱🌳,共享智慧的果实🍎🍏!


万分感谢🙏🙏点赞 👍👍、收藏 ⭐🌟、评论 💬🗯️、关注❤️💚~

相关推荐
昨日之日20061 小时前
Moonshine - 新型开源ASR(语音识别)模型,体积小,速度快,比OpenAI Whisper快五倍 本地一键整合包下载
人工智能·whisper·语音识别
浮生如梦_1 小时前
Halcon基于laws纹理特征的SVM分类
图像处理·人工智能·算法·支持向量机·计算机视觉·分类·视觉检测
深度学习lover1 小时前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别
热爱跑步的恒川2 小时前
【论文复现】基于图卷积网络的轻量化推荐模型
网络·人工智能·开源·aigc·ai编程
API快乐传递者2 小时前
淘宝反爬虫机制的主要手段有哪些?
爬虫·python
阡之尘埃4 小时前
Python数据分析案例61——信贷风控评分卡模型(A卡)(scorecardpy 全面解析)
人工智能·python·机器学习·数据分析·智能风控·信贷风控
孙同学要努力6 小时前
全连接神经网络案例——手写数字识别
人工智能·深度学习·神经网络
Eric.Lee20216 小时前
yolo v5 开源项目
人工智能·yolo·目标检测·计算机视觉
其实吧37 小时前
基于Matlab的图像融合研究设计
人工智能·计算机视觉·matlab
丕羽7 小时前
【Pytorch】基本语法
人工智能·pytorch·python