Pytorch如何将嵌套的dict类型数据加载到GPU

在PyTorch中,您可以使用.to(device)方法将嵌套的字典中的所有支持的Tensor对象转移到GPU。以下是一个简单的例子

python 复制代码
import torch
 
# 假设您已经有了一个名为device的GPU设备对象
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
# 嵌套的字典,其中包含一些Tensors
nested_dict = {
    'a': torch.randn(2, 2),
    'b': {
        'b1': torch.randn(2, 2),
        'b2': torch.randn(2, 2)
    },
    'c': torch.randn(2, 2)
}
 
# 将嵌套字典中的所有Tensors移动到GPU
def to_gpu(data):
    if isinstance(data, dict):
        return {k: to_gpu(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [to_gpu(i) for i in data]
    elif isinstance(data, tuple):
        return tuple([to_gpu(i) for i in data])
    elif torch.is_tensor(data) and data.device != device:
        return data.to(device)
    else:
        return data
 
nested_dict_gpu = to_gpu(nested_dict)
 
# 检查是否所有Tensors都已移动到GPU
for k, v in nested_dict_gpu.items():
    if torch.is_tensor(v):
        assert v.device == device

这个函数to_gpu会递归地检查字典中的每个元素,如果是Tensor类型并且不在GPU上,就会使用.to(device)方法转移它。您需要先设置device变量指向您的GPU设备。如果没有GPU可用,它会默认使用CPU。

相关推荐
IMER SIMPLE12 小时前
人工智能-python-深度学习-经典神经网络AlexNet
人工智能·python·深度学习
UQI-LIUWJ14 小时前
unsloth笔记:运行&微调 gemma
人工智能·笔记·深度学习
THMAIL14 小时前
深度学习从入门到精通 - 生成对抗网络(GAN)实战:创造逼真图像的魔法艺术
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·cnn
北京地铁1号线14 小时前
GPT(Generative Pre-trained Transformer)模型架构与损失函数介绍
gpt·深度学习·transformer
fantasy_arch15 小时前
9.3深度循环神经网络
人工智能·rnn·深度学习
Shiyuan717 小时前
【检索通知】2025年IEEE第二届深度学习与计算机视觉国际会议检索
人工智能·深度学习·计算机视觉
cyyt19 小时前
深度学习周报(9.1~9.7)
人工智能·深度学习
max50060019 小时前
图像处理:实现多图点重叠效果
开发语言·图像处理·人工智能·python·深度学习·音视频
西猫雷婶1 天前
scikit-learn/sklearn学习|广义线性回归损失函数的基本表达式
深度学习·神经网络·学习·机器学习·线性回归·scikit-learn·概率论
IMER SIMPLE1 天前
人工智能-python-深度学习-神经网络-MobileNet V1&V2
人工智能·python·深度学习