pytorch神经网络训练(AlexNet)

  • 导包

    import os

    import torch

    import torch.nn as nn

    import torch.optim as optim

    from torch.utils.data import Dataset, DataLoader

    from PIL import Image

    from torchvision import models, transforms

  • 定义自定义图像数据集

    class CustomImageDataset(Dataset):

定义一个自定义的图像数据集类,继承自Dataset

复制代码
def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

复制代码
        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

复制代码
        self.transform = transform

图像转换方法,用于对图像进行预处理

复制代码
        self.files = [] 

存储所有图像文件的路径

复制代码
        self.labels = [] 

存储所有图像的标签

复制代码
        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

复制代码
        for index, label in enumerate(os.listdir(main_dir)):

遍历主目录中的所有子目录

复制代码
          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

复制代码
           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

复制代码
def __len__(self):

定义数据集的长度

复制代码
        return len(self.files) 

返回文件列表的长度

复制代码
def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

复制代码
        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换

    transform = transforms.Compose([

    复制代码
      transforms.Resize((227, 227)),  # AlexNet的输入图像大小
    
      transforms.RandomHorizontalFlip(),  # 随机水平翻转
    
      transforms.RandomRotation(10),  # 随机旋转
    
      transforms.ToTensor(),
    
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

    ])

  • 创建数据集

    dataset = CustomImageDataset(main_dir="D:\图像处理、深度学习\flowers", transform=transform)

  • 创建数据加载器

    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

  • 加载预训练的AlexNet模型

    alexnet_model = models.alexnet(pretrained=True)

  • 修改最后几层以适应新的分类任务

    num_ftrs = alexnet_model.classifier[6].in_features

    alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

  • 定义损失函数和优化器

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型

    if torch.cuda.device_count() > 1:

    复制代码
      alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    alexnet_model.to(device)

  • 模型评估

    def evaluate_model(model, data_loader, device):

    复制代码
      model.eval()  # 将模型设置为评估模式
    
      correct = 0
    
      total = 0
    
      with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度
    
          for images, labels in data_loader:
    
              images, labels = images.to(device), labels.to(device)
    
              outputs = model(images)
    
              _, predicted = torch.max(outputs.data, 1)
    
              total += labels.size(0)
    
              correct += (predicted == labels).sum().item()
    
      accuracy = 100 * correct / total
    
      return accuracy
  • 训练模型

    num_epochs = 10

    for epoch in range(num_epochs):

    复制代码
      alexnet_model.train()
    
      running_loss = 0.0
    
      for images, labels in data_loader:
    
          images, labels = images.to(device), labels.to(device)

前向传播

复制代码
        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

复制代码
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

复制代码
    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')
相关推荐
程序猿追9 分钟前
深度解读 CANN HCCL:揭秘昇腾高性能集体通信的同步机制
神经网络·架构
历程里程碑20 分钟前
普通数组----合并区间
java·数据结构·python·算法·leetcode·职场和发展·tornado
weixin_3954489121 分钟前
mult_yolov5_post_copy.c_cursor_0205
c语言·python·yolo
User_芊芊君子25 分钟前
CANN数学计算基石ops-math深度解析:高性能科学计算与AI模型加速的核心引擎
人工智能·深度学习·神经网络·ai
执风挽^37 分钟前
Python基础编程题2
开发语言·python·算法·visual studio code
纤纡.1 小时前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
kjkdd1 小时前
6.1 核心组件(Agent)
python·ai·语言模型·langchain·ai编程
小镇敲码人1 小时前
剖析CANN框架中Samples仓库:从示例到实战的AI开发指南
c++·人工智能·python·华为·acl·cann
萧鼎1 小时前
Python 包管理的“超音速”革命:全面上手 uv 工具链
开发语言·python·uv
摘星编程1 小时前
CANN ops-nn Pooling算子解读:CNN模型下采样与特征提取的核心
人工智能·神经网络·cnn