pytorch 指定GPU设备

使用os.environ["CUDA_VISIBLE_DEVICES"]

这种方法是通过环境变量限制可见的CUDA设备,从而在多个GPU的机器上只让PyTorch看到并使用指定的GPU。这种方式的好处是所有后续的CUDA调用都会使用这个GPU,并且代码中不需要显式地指定设备索引。

python 复制代码
import os

# 设置只使用2号GPU
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import torch
import torch.nn as nn

# 检查PyTorch是否检测到GPU
if torch.cuda.is_available():
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")  # 注意这里是0,因为只有一个可见的GPU
else:
    print("No GPU available, using CPU instead.")

# 定义模型
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.layer = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.layer(x)

# 创建模型并移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YourModel().to(device)

# 示例数据和前向传播
input_data = torch.randn(5, 10).to(device)
output = model(input_data)
print(output)

直接指定设备索引

这种方法是在代码中直接指定要使用的设备索引,无需修改环境变量。这种方式更加显式,并且可以在同一脚本中使用多个不同的GPU。

python 复制代码
import torch
import torch.nn as nn

# 检查设备是否可用并打印设备名称
if torch.cuda.is_available():
    device = torch.device("cuda:2")  # 直接指定设备索引
    print(f"Using GPU: {torch.cuda.get_device_name(2)}")
else:
    device = torch.device("cpu")
    print("No GPU available, using CPU instead.")

# 定义模型
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.layer = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.layer(x)

# 创建模型并移动到指定的GPU
model = YourModel().to(device)

# 示例数据和前向传播
input_data = torch.randn(5, 10).to(device)
output = model(input_data)
print(output)
相关推荐
RoboWizard5 分钟前
PCIe 5.0 SSD有无独立缓存对性能影响大吗?Kingston FURY Renegade G5!
人工智能·缓存·电脑·金士顿
自学互联网5 分钟前
使用Python构建钢铁行业生产监控系统:从理论到实践
开发语言·python
无心水10 分钟前
【Python实战进阶】7、Python条件与循环实战详解:从基础语法到高级技巧
android·java·python·python列表推导式·python条件语句·python循环语句·python实战案例
霍格沃兹测试开发学社-小明22 分钟前
测试左移2.0:在开发周期前端筑起质量防线
前端·javascript·网络·人工智能·测试工具·easyui
懒麻蛇22 分钟前
从矩阵相关到矩阵回归:曼特尔检验与 MRQAP
人工智能·线性代数·矩阵·数据挖掘·回归
xwill*26 分钟前
RDT-1B: A DIFFUSION FOUNDATION MODEL FOR BIMANUAL MANIPULATION
人工智能·pytorch·python·深度学习
网安INF33 分钟前
机器学习入门:深入理解线性回归
人工智能·机器学习·线性回归
陈奕昆34 分钟前
n8n实战营Day2课时2:Loop+Merge节点进阶·Excel批量校验实操
人工智能·python·excel·n8n
程序猿追37 分钟前
PyTorch算子模板库技术解读:无缝衔接PyTorch模型与Ascend硬件的桥梁
人工智能·pytorch·python·深度学习·机器学习
程序小旭38 分钟前
Kaggle平台的使用
人工智能