pytorch2.5实例教程

以下是再次为你提供的一个详细的PyTorch使用教程:

一、安装PyTorch

  • 环境准备
    • 确保系统已安装合适版本的Python(推荐3.10及以上)。
  • 安装方式
    • CPU版本
      • 对于Linux和macOS:
        • 使用命令 pip install torch torchvision torchaudio
      • 对于Windows:
        • 先处理好依赖项,然后使用类似的pip命令安装。
    • GPU版本(依赖于CUDA)
      • 依据CUDA版本在官网查找对应命令。例如,若CUDA为12.4:
        • 执行 conda install pytorch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.4 -c pytorch -c nvidia。

二、PyTorch基础概念

  • 张量(Tensors)
    • 核心数据结构,类似NumPy数组且可在GPU加速计算。
    • 创建方式
      • 从列表创建:
        • 示例:
          • import torch
          • my_list = [1, 2, 3]
          • tensor = torch.tensor(my_list)
      • 创建随机张量:
        • 例如:random_tensor = torch.randn(3, 3)(创建3x3随机正态分布张量)。
  • 计算图与自动微分
    • 计算基于构建计算图,操作张量时自动构建。
    • 示例:
      • 计算 y = x^2 + 3x 的梯度。
        • x = torch.tensor([2.0], requires_grad = True)
        • y = x ** 2+3 * x
        • y.backward()
        • print(x.grad)

三、创建神经网络模型

  • 定义网络结构

    • 使用 nn.Module 类。

    • 示例(全连接神经网络):

      • import torch.nn as nn
      复制代码
        class MyNet(nn.Module):
            def __init__(self):
                super(MyNet, self).__init__()
                self.fc1 = nn.Linear(10, 5)
                self.fc2 = nn.Linear(5, 1)
            def forward(self, x):
                x = torch.relu(self.fc1(x))
                x = self.fc2(x)
                return x
  • 模型初始化与参数查看

    • 初始化:model = MyNet()

    • 参数查看:
      *

      复制代码
        for name, param in model.named_parameters():
            print(name, param.size())

四、数据处理

  • 数据加载

    • 使用 DataLoader 类,需先创建数据集类(继承 torch.utils.data.Dataset)。

    • 示例:
      *

      复制代码
        from torch.utils.data import Dataset, DataLoader
        class MyDataset(Dataset):
            def __init__(self):
                self.data = torch.randn(100, 10)
                self.labels = torch.randint(0, 2, (100,))
            def __getitem__(self, index):
                return self.data[index], self.labels[index]
            def __len__(self):
                return len(self.data)
        dataset = MyDataset()
        dataloader = DataLoader(dataset, batch_size = 10, shuffle = True)
  • 数据预处理

    • 以图像数据为例,使用 torchvision.transforms

    • 示例:
      *

      复制代码
        import torchvision.transforms as transforms
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

五、训练模型

  • 定义损失函数和优化器
    • 损失函数
      • 例如回归问题用均方误差(MSE):criterion = nn.MSELoss()
    • 优化器
      • 如随机梯度下降(SGD):optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
  • 训练循环
    • 多轮训练:
      *

      复制代码
        num_epochs = 10
        for epoch in range(num_epochs):
            for batch_data, batch_labels in dataloader:
                optimizer.zero_grad()
                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)
                loss.backward()
                optimizer.step()
            print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

六、模型评估与预测

  • 模型评估

    • 以分类问题计算准确率为例:
      *

      复制代码
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_data, batch_labels in dataloader:
                outputs = model(batch_data)
                _, predicted = torch.max(outputs.data, 1)
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
        accuracy = correct / total
        print(f'Accuracy: {accuracy}')
  • 预测新数据

    • 示例:
      *

      复制代码
        new_data = torch.randn(1, 10)
        with torch.no_grad():
            prediction = model(new_data)
        print(f'Prediction: {prediction}')
相关推荐
向qian看_-_9 分钟前
Linux 使用pip报错(error: externally-managed-environment )解决方案
linux·python·pip
Nicole-----34 分钟前
Python - Union联合类型注解
开发语言·python
Eric.5653 小时前
python advance -----object-oriented
python
云天徽上4 小时前
【数据可视化-107】2025年1-7月全国出口总额Top 10省市数据分析:用Python和Pyecharts打造炫酷可视化大屏
开发语言·python·信息可视化·数据挖掘·数据分析·pyecharts
THMAIL4 小时前
机器学习从入门到精通 - 数据预处理实战秘籍:清洗、转换与特征工程入门
人工智能·python·算法·机器学习·数据挖掘·逻辑回归
@HNUSTer4 小时前
Python数据可视化科技图表绘制系列教程(六)
python·数据可视化·科技论文·专业制图·科研图表
THMAIL5 小时前
深度学习从入门到精通 - AutoML与神经网络搜索(NAS):自动化模型设计未来
人工智能·python·深度学习·神经网络·算法·机器学习·逻辑回归
山烛5 小时前
深度学习:残差网络ResNet与迁移学习
人工智能·python·深度学习·残差网络·resnet·迁移学习
eleqi5 小时前
Python+DRVT 从外部调用 Revit:批量创建梁(2)
python·系统集成·revit·自动化生产流水线·外部访问
BYSJMG6 小时前
计算机毕设大数据方向:基于Spark+Hadoop的餐饮外卖平台数据分析系统【源码+文档+调试】
大数据·hadoop·分布式·python·spark·django·课程设计