CUDA和torch的安装

cuda的下载及安装

cuda版本

如何判断自己应该下载什么版本的cuda

win+r打开控制台

图中我们可以得到nvidia的驱动版本和GPU支持的最高cuda版本。

网站:

CUDA Toolkit Archive | NVIDIA Developer

进来后我们选择对应的版本即可

cuda安装

然后选择自定义安装

我们一定要取消Visual Studio Integration的勾选,不然安装就会报错。

然后一直点下一步等待安装完成就行。

验证是否安装成功

打开控制台输入nvcc -V

如果能够成功输出cuda的版本就说明我们对应的cuda安装成功了

torch的安装

由于的torch这个第三方库太大了,想要通过pip来安装对于我们网络要求太高了。所以我们就提前下载好对应的本地文件来实现安装。

来到torch官网,在下方选择自己对应的版本。

复制代码
如图我们就能获取对应的一个网站
https://download.pytorch.org/whl/cu126

在这里选择我们想要下载的torch版本,有+cpu代表是无GPU的电脑下载的,cp后面的数字

代表我们的python版本,最后是我们的Windows版本或是linux版本。

pip install torch-2.4.0+cu121-cp310-cp310-win_amd64.whl

只需等待安装完成即可。

第一个torch程序

复制代码
import torch
print(torch.__version__)

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)
print(len(training_data))

from matplotlib import pyplot as plt
fig = plt.figure()
for i in range(9):
    img,label= training_data[i+50000]

    fig.add_subplot(3,3,i+1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(),cmap="gray")
    a = img.squeeze()
plt.show()
复制代码
# 加载训练数据集
training_data = datasets.MNIST(
    root='data',  # 数据存储目录
    train=True,   # 表示这是训练集
    download=True,# 如果本地没有数据则自动下载
    transform=ToTensor(),  # 将图像转换为Tensor格式
)

# 加载测试数据集
test_data = datasets.MNIST(
    root='data',
    train=False,  # 表示这是测试集
    download=True,
    transform=ToTensor(),
)

NumPy 数组

仅支持在CPU上计算,无法直接利用 GPU 进行加速。

Tensor

支持在CPU 和 GPU上计算,利用 CUDA 进行并行加速。

复制代码
fig = plt.figure()  # 创建一个图形

for i in range(9):
    img, label = training_data[i+50000]  # 获取第50000+i个样本的图像和标签
    fig.add_subplot(3, 3, i+1)  # 添加子图,3行3列布局
    plt.title(label)  # 显示标签(即图像对应的数字)
    plt.axis("off")  # 关闭坐标轴显示
    plt.imshow(img.squeeze(), cmap="gray")  # 显示图像,转为灰度图
    a = img.squeeze()  # 降维
    
plt.show()  # 显示图形