pytorch神经网络训练(AlexNet)

  • 导包

    import os

    import torch

    import torch.nn as nn

    import torch.optim as optim

    from torch.utils.data import Dataset, DataLoader

    from PIL import Image

    from torchvision import models, transforms

  • 定义自定义图像数据集

    class CustomImageDataset(Dataset):

定义一个自定义的图像数据集类,继承自Dataset

复制代码
def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

复制代码
        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

复制代码
        self.transform = transform

图像转换方法,用于对图像进行预处理

复制代码
        self.files = [] 

存储所有图像文件的路径

复制代码
        self.labels = [] 

存储所有图像的标签

复制代码
        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

复制代码
        for index, label in enumerate(os.listdir(main_dir)):

遍历主目录中的所有子目录

复制代码
          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

复制代码
           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

复制代码
def __len__(self):

定义数据集的长度

复制代码
        return len(self.files) 

返回文件列表的长度

复制代码
def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

复制代码
        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换

    transform = transforms.Compose([

    复制代码
      transforms.Resize((227, 227)),  # AlexNet的输入图像大小
    
      transforms.RandomHorizontalFlip(),  # 随机水平翻转
    
      transforms.RandomRotation(10),  # 随机旋转
    
      transforms.ToTensor(),
    
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

    ])

  • 创建数据集

    dataset = CustomImageDataset(main_dir="D:\图像处理、深度学习\flowers", transform=transform)

  • 创建数据加载器

    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

  • 加载预训练的AlexNet模型

    alexnet_model = models.alexnet(pretrained=True)

  • 修改最后几层以适应新的分类任务

    num_ftrs = alexnet_model.classifier[6].in_features

    alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

  • 定义损失函数和优化器

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型

    if torch.cuda.device_count() > 1:

    复制代码
      alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    alexnet_model.to(device)

  • 模型评估

    def evaluate_model(model, data_loader, device):

    复制代码
      model.eval()  # 将模型设置为评估模式
    
      correct = 0
    
      total = 0
    
      with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度
    
          for images, labels in data_loader:
    
              images, labels = images.to(device), labels.to(device)
    
              outputs = model(images)
    
              _, predicted = torch.max(outputs.data, 1)
    
              total += labels.size(0)
    
              correct += (predicted == labels).sum().item()
    
      accuracy = 100 * correct / total
    
      return accuracy
  • 训练模型

    num_epochs = 10

    for epoch in range(num_epochs):

    复制代码
      alexnet_model.train()
    
      running_loss = 0.0
    
      for images, labels in data_loader:
    
          images, labels = images.to(device), labels.to(device)

前向传播

复制代码
        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

复制代码
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

复制代码
    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')
相关推荐
凛铄linshuo33 分钟前
爬虫简单实操2——以贴吧为例爬取“某吧”前10页的网页代码
爬虫·python·学习
牛客企业服务36 分钟前
2025年AI面试推荐榜单,数字化招聘转型优选
人工智能·python·算法·面试·职场和发展·金融·求职招聘
胡斌附体1 小时前
linux测试端口是否可被外部访问
linux·运维·服务器·python·测试·端口测试·临时服务器
likeGhee1 小时前
python缓存装饰器实现方案
开发语言·python·缓存
项目題供诗2 小时前
黑马python(二十五)
开发语言·python
读书点滴2 小时前
笨方法学python -练习14
java·前端·python
笑衬人心。2 小时前
Ubuntu 22.04 修改默认 Python 版本为 Python3 笔记
笔记·python·ubuntu
蛋仔聊测试2 小时前
Playwright 中 Page 对象的常用方法详解
python
前端付豪2 小时前
17、自动化才是正义:用 Python 接管你的日常琐事
后端·python
jioulongzi2 小时前
记录一次莫名奇妙的跨域502(badgateway)错误
开发语言·python