pytorch神经网络训练(AlexNet)

  • 导包

    import os

    import torch

    import torch.nn as nn

    import torch.optim as optim

    from torch.utils.data import Dataset, DataLoader

    from PIL import Image

    from torchvision import models, transforms

  • 定义自定义图像数据集

    class CustomImageDataset(Dataset):

定义一个自定义的图像数据集类,继承自Dataset

复制代码
def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

复制代码
        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

复制代码
        self.transform = transform

图像转换方法,用于对图像进行预处理

复制代码
        self.files = [] 

存储所有图像文件的路径

复制代码
        self.labels = [] 

存储所有图像的标签

复制代码
        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

复制代码
        for index, label in enumerate(os.listdir(main_dir)):

遍历主目录中的所有子目录

复制代码
          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

复制代码
           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

复制代码
def __len__(self):

定义数据集的长度

复制代码
        return len(self.files) 

返回文件列表的长度

复制代码
def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

复制代码
        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换

    transform = transforms.Compose([

    复制代码
      transforms.Resize((227, 227)),  # AlexNet的输入图像大小
    
      transforms.RandomHorizontalFlip(),  # 随机水平翻转
    
      transforms.RandomRotation(10),  # 随机旋转
    
      transforms.ToTensor(),
    
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

    ])

  • 创建数据集

    dataset = CustomImageDataset(main_dir="D:\图像处理、深度学习\flowers", transform=transform)

  • 创建数据加载器

    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

  • 加载预训练的AlexNet模型

    alexnet_model = models.alexnet(pretrained=True)

  • 修改最后几层以适应新的分类任务

    num_ftrs = alexnet_model.classifier[6].in_features

    alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

  • 定义损失函数和优化器

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型

    if torch.cuda.device_count() > 1:

    复制代码
      alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    alexnet_model.to(device)

  • 模型评估

    def evaluate_model(model, data_loader, device):

    复制代码
      model.eval()  # 将模型设置为评估模式
    
      correct = 0
    
      total = 0
    
      with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度
    
          for images, labels in data_loader:
    
              images, labels = images.to(device), labels.to(device)
    
              outputs = model(images)
    
              _, predicted = torch.max(outputs.data, 1)
    
              total += labels.size(0)
    
              correct += (predicted == labels).sum().item()
    
      accuracy = 100 * correct / total
    
      return accuracy
  • 训练模型

    num_epochs = 10

    for epoch in range(num_epochs):

    复制代码
      alexnet_model.train()
    
      running_loss = 0.0
    
      for images, labels in data_loader:
    
          images, labels = images.to(device), labels.to(device)

前向传播

复制代码
        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

复制代码
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

复制代码
    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')
相关推荐
meilindehuzi_a24 分钟前
深入浅出数据结构:Python 字典(Dict)与集合(Set)的哈希表底层全链路追踪
数据结构·python·散列表
Lucas凉皮28 分钟前
20243408 2025-2026-2 《Python程序设计》综合实践报告
python·实验报告
键盘上的猫头鹰39 分钟前
【MySQL 教程(八)】索引、事务、用户管理、导入导出与分页查询
数据库·python·mysql
薛定谔的猫-菜鸟程序员1 小时前
2小时智能体开发一个智能体?我用CodeArts Agent 和 AtomCode 开发了一个适老化智能体。
人工智能·python·agent
bigfootyazi2 小时前
python爬虫-基本库-urllib库(常用速查)
开发语言·爬虫·python
瑶总迷弟2 小时前
使用 mis-tei 在昇腾310P上部署 bge-m3模型
pytorch·python·华为·语言模型·自然语言处理·cnn·unix
belong_my_offer3 小时前
认识到精通函数
开发语言·python
卡次卡次14 小时前
vibecoding起步注意点:插件、Skills、MCP、Hooks
服务器·数据库·python·oracle
我的xiaodoujiao4 小时前
API 接口自动化测试详细图文教程学习系列24--如何用Pytest去设计接口测试用例并执行
python·学习·测试工具·pytest
zhangfeng11334 小时前
ai 模型加密,强化版终极防盗方案 支持烧录的显卡列表
人工智能·pytorch·python