PyTorch_指定运算设备 (包含安装 GPU 的 PyTorch)

PyTorch默认会将张量创建在 CPU 控制的内存中,即:默认的运算设备为 CPU。我们也可以将张量创建在 GPU 上,能够利用对于矩阵计算的优势加快模型训练。

将张量移动到 GPU 上有两种方法:

  1. 使用 cuda 方法
  2. 直接在 GPU 上创建张量
  3. 使用 to 方法指定设备

安装含有 GPU 的 PyTorch

通过这个可以判断电脑里是否已经安装 CUDA

python 复制代码
import torch 

print(torch.cuda.is_available()) # 判断是否有可用的 GPU 设备

如果结果输出是 False,说明设备里没有 GPU 的 PyTorch 的版本。

可以通过 PyTorch官网来安装含有 GPU 的 PyTorch 的版本。

安装 CUDA 12.8 版本的pip命令

复制代码
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

这样就安装好含有 GPU 的 PyTorch 了。


代码

python 复制代码
import torch 

# 使用 cuda 方法
def test01():
    data = torch.tensor([10, 20, 30])
    print("存储设备:", data.device)

    # 将张量移动到 GPU 设备上
    data = data.cuda()
    print("存储设备:", data.device)
    
    # 将张量从 GPU 移动到 CPU 设备上
    data = data.cpu()
    print("存储设备:", data.device)

# 直接将张量创建在指定设备上
def test02():
    data = torch.tensor([10, 20, 30], device='cuda')
    print("存储设备:", data.device)

    # 将张量移动到 CPU 设备上
    data = data.cpu()
    print("存储设备:", data.device)

# 使用 to 方法
def test03():
    data = torch.tensor([10, 20, 30])
    print("存储设备:", data.device)

    # 将张量移动到 GPU 设备上
    data = data.to('cuda')
    print("存储设备:", data.device)
    
    # 将张量从 GPU 移动到 CPU 设备上
    data = data.to('cpu')
    print("存储设备:", data.device)


# 注意点:张量存储在不同设备上的张量不能够直接运算
def test04():
    data1 = torch.tensor([10,20,30])
    data2 = torch.tensor([1,2,3], device='cuda')

    # RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    # data1 = data1.to('cuda') # 这样可以解决了
    data = data1 + data2 
    print(data)

    # 如果你的电脑上安装 pytorch 不是 gpu 版本的,或者电脑本身没有 gpu 设备环境
    # 否则下面的调用 cuda 函数的代码会报错
    # data1 = data1.cuda() 

if __name__ == "__main__":
    test04() 
相关推荐
会飞的老朱2 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º4 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
寻星探路4 小时前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
Codebee6 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º7 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys7 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56787 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子7 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder7 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能7 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算