PyTorch_指定运算设备 (包含安装 GPU 的 PyTorch)

PyTorch默认会将张量创建在 CPU 控制的内存中,即:默认的运算设备为 CPU。我们也可以将张量创建在 GPU 上,能够利用对于矩阵计算的优势加快模型训练。

将张量移动到 GPU 上有两种方法:

  1. 使用 cuda 方法
  2. 直接在 GPU 上创建张量
  3. 使用 to 方法指定设备

安装含有 GPU 的 PyTorch

通过这个可以判断电脑里是否已经安装 CUDA

python 复制代码
import torch 

print(torch.cuda.is_available()) # 判断是否有可用的 GPU 设备

如果结果输出是 False,说明设备里没有 GPU 的 PyTorch 的版本。

可以通过 PyTorch官网来安装含有 GPU 的 PyTorch 的版本。

安装 CUDA 12.8 版本的pip命令

复制代码
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

这样就安装好含有 GPU 的 PyTorch 了。


代码

python 复制代码
import torch 

# 使用 cuda 方法
def test01():
    data = torch.tensor([10, 20, 30])
    print("存储设备:", data.device)

    # 将张量移动到 GPU 设备上
    data = data.cuda()
    print("存储设备:", data.device)
    
    # 将张量从 GPU 移动到 CPU 设备上
    data = data.cpu()
    print("存储设备:", data.device)

# 直接将张量创建在指定设备上
def test02():
    data = torch.tensor([10, 20, 30], device='cuda')
    print("存储设备:", data.device)

    # 将张量移动到 CPU 设备上
    data = data.cpu()
    print("存储设备:", data.device)

# 使用 to 方法
def test03():
    data = torch.tensor([10, 20, 30])
    print("存储设备:", data.device)

    # 将张量移动到 GPU 设备上
    data = data.to('cuda')
    print("存储设备:", data.device)
    
    # 将张量从 GPU 移动到 CPU 设备上
    data = data.to('cpu')
    print("存储设备:", data.device)


# 注意点:张量存储在不同设备上的张量不能够直接运算
def test04():
    data1 = torch.tensor([10,20,30])
    data2 = torch.tensor([1,2,3], device='cuda')

    # RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    # data1 = data1.to('cuda') # 这样可以解决了
    data = data1 + data2 
    print(data)

    # 如果你的电脑上安装 pytorch 不是 gpu 版本的,或者电脑本身没有 gpu 设备环境
    # 否则下面的调用 cuda 函数的代码会报错
    # data1 = data1.cuda() 

if __name__ == "__main__":
    test04() 
相关推荐
张较瘦_6 分钟前
[论文阅读] 人工智能 | 机器学习工作流的“救星”:数据虚拟化服务如何解决数据管理难题?
论文阅读·人工智能·机器学习
蓝卓工业操作系统1 小时前
天铭科技×蓝卓 | “1+2+N”打造AI驱动的汽车零部件行业智能工厂
人工智能·科技·汽车
zzywxc7871 小时前
编程算法在金融、医疗、教育、制造业等领域的落地案例
人工智能·算法·金融·自动化·copilot·ai编程
zzywxc7871 小时前
编程算法在金融、医疗、教育、制造业的落地应用。
人工智能·深度学习·算法·机器学习·金融·架构·开源
修一呀1 小时前
【数据标注】详解使用 Labelimg 进行数据标注的 Conda 环境搭建与操作流程
人工智能·conda
白熊1885 小时前
【大模型LLM】梯度累积(Gradient Accumulation)原理详解
人工智能·大模型·llm
愚戏师5 小时前
机器学习(重学版)基础篇(算法与模型一)
人工智能·算法·机器学习
仰望星空的凡人6 小时前
【JS逆向基础】数据库之MongoDB
javascript·数据库·python·mongodb
F_D_Z6 小时前
【PyTorch】图像多分类项目部署
人工智能·pytorch·python·深度学习·分类
pingzhuyan7 小时前
python入门篇12-虚拟环境conda的安装与使用
python·ai·llm·ocr·conda