PyTorch 2.0 中设置默认使用 GPU 的方法

PyTorch 2.0 中设置默认使用 GPU 的方法

在 PyTorch 2.0 中,默认情况下仍然是使用 CPU 进行计算,除非明确指定使用 GPU。torch.set_default_device 是 PyTorch 2.0 引入的新功能,用于设置默认设备,使得所有后续张量和模块在没有明确指定设备的情况下,会被创建在这个默认设备上。这在代码中提供了一种更简洁的方式来指定设备,而无需在每次创建张量或模型时手动指定。

  1. 检查 PyTorch 版本

    确保使用的是 PyTorch 2.0 或更高版本:

    python 复制代码
    import torch
    print(torch.__version__)  # 必须是 2.0 或更高版本
  2. 检查 CUDA 是否可用

    在设置 GPU 为默认设备之前,确认 CUDA 可用性:

    python 复制代码
    print(torch.cuda.is_available())  # True 表示可用
  3. 设置默认设备为 GPU

    使用 torch.set_default_device 将默认设备设置为 GPU:

    python 复制代码
    import torch
    
    # 确保 CUDA 可用
    if torch.cuda.is_available():
        # 设置默认设备为 GPU
        torch.set_default_device('cuda')
        print("默认设备已设置为 GPU")
    else:
        print("CUDA 不可用,无法设置 GPU 为默认设备")
  4. 验证默认设备设置

    创建一个张量,验证其是否在 GPU 上:

    python 复制代码
    x = torch.tensor([1.0, 2.0, 3.0])
    print(x.device)  # 输出:cuda:0
  5. 模型自动加载到 GPU

    如果设置了默认设备,模型的参数和新建的张量会自动加载到 GPU:

    python 复制代码
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 1)
    
        def forward(self, x):
            return self.linear(x)
    
    model = MyModel()
    print(next(model.parameters()).device)  # 输出:cuda:0
全局设置代码示例

以下代码展示如何在脚本中全局设置默认设备为 GPU:

python 复制代码
import torch

# 检查并设置默认设备
if torch.cuda.is_available():
    torch.set_default_device('cuda')
    print("默认设备已设置为 GPU")
else:
    raise RuntimeError("CUDA 不可用,请检查环境配置")

# 示例:自动使用 GPU 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
print(f"x device: {x.device}")  # 输出:cuda:0

# 示例:自动将模型参数放到 GPU
model = torch.nn.Linear(5, 2)
print(f"Model parameters device: {next(model.parameters()).device}")  # 输出:cuda:0
注意事项
  1. 与设备显式管理的代码兼容性

    如果代码中显式指定了设备(如 tensor.to(device)),torch.set_default_device 不会影响这些张量。建议在全局设置后,尽量减少显式设备管理操作。

  2. 多 GPU 环境

    如果有多个 GPU,可以指定具体设备,比如 'cuda:1'。示例:

    python 复制代码
    torch.set_default_device('cuda:1')  # 使用第二块 GPU
  3. 性能调优

    默认将所有操作转移到 GPU 可能并不适合所有场景,尤其是小规模任务时,GPU 的初始化开销可能超过性能提升。根据需求灵活调整设备。

相关推荐
2401_89719055几秒前
怎样使用Navicat高级特权进行还原时解决字符集冲突_企业数据保护
jvm·数据库·python
ZC跨境爬虫3 分钟前
3D 地球卫星轨道可视化平台开发 Day5(简介接口对接+规划AI自动化卫星数据生成工作流)
前端·人工智能·3d·ai·自动化
木卫二号Coding3 分钟前
第八十四篇-V100-32G+Easyclaw+Ollama+Qwopus3.5-27B-V3
人工智能
weixin_580614004 分钟前
c++文件锁使用方法 c++如何实现多进程文件同步
jvm·数据库·python
qq_330037994 分钟前
如何转换数据文件字节序_CONVERT DATAFILE用于跨OS平台数据库迁移
jvm·数据库·python
白日梦想家6815 分钟前
博客二:递归实战避坑指南,从入门到熟练运用
开发语言·python
djjdjdjdjjdj6 分钟前
SQL窗口函数解决多维排名问题_组合排序实战
jvm·数据库·python
xiaoxiang96097 分钟前
TDD测试驱动开发:从理论到实战的完整指南(含AI增强工作流)
人工智能·驱动开发·tdd
AC赳赳老秦7 分钟前
OpenClaw与系统环境冲突:Windows/Mac系统兼容问题解决指南
开发语言·python·产品经理·策略模式·pygame·deepseek·openclaw
小张同学8248 分钟前
Python 封神技巧:1 行代码搞定 90% 日常数据处理,效率直接拉满
开发语言·人工智能·python