pytorch神经网络训练(AlexNet)

  • 导包

    import os

    import torch

    import torch.nn as nn

    import torch.optim as optim

    from torch.utils.data import Dataset, DataLoader

    from PIL import Image

    from torchvision import models, transforms

  • 定义自定义图像数据集

    class CustomImageDataset(Dataset):

定义一个自定义的图像数据集类,继承自Dataset

复制代码
def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

复制代码
        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

复制代码
        self.transform = transform

图像转换方法,用于对图像进行预处理

复制代码
        self.files = [] 

存储所有图像文件的路径

复制代码
        self.labels = [] 

存储所有图像的标签

复制代码
        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

复制代码
        for index, label in enumerate(os.listdir(main_dir)):

遍历主目录中的所有子目录

复制代码
          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

复制代码
           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

复制代码
def __len__(self):

定义数据集的长度

复制代码
        return len(self.files) 

返回文件列表的长度

复制代码
def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

复制代码
        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换

    transform = transforms.Compose([

    复制代码
      transforms.Resize((227, 227)),  # AlexNet的输入图像大小
    
      transforms.RandomHorizontalFlip(),  # 随机水平翻转
    
      transforms.RandomRotation(10),  # 随机旋转
    
      transforms.ToTensor(),
    
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

    ])

  • 创建数据集

    dataset = CustomImageDataset(main_dir="D:\图像处理、深度学习\flowers", transform=transform)

  • 创建数据加载器

    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

  • 加载预训练的AlexNet模型

    alexnet_model = models.alexnet(pretrained=True)

  • 修改最后几层以适应新的分类任务

    num_ftrs = alexnet_model.classifier[6].in_features

    alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

  • 定义损失函数和优化器

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型

    if torch.cuda.device_count() > 1:

    复制代码
      alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    alexnet_model.to(device)

  • 模型评估

    def evaluate_model(model, data_loader, device):

    复制代码
      model.eval()  # 将模型设置为评估模式
    
      correct = 0
    
      total = 0
    
      with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度
    
          for images, labels in data_loader:
    
              images, labels = images.to(device), labels.to(device)
    
              outputs = model(images)
    
              _, predicted = torch.max(outputs.data, 1)
    
              total += labels.size(0)
    
              correct += (predicted == labels).sum().item()
    
      accuracy = 100 * correct / total
    
      return accuracy
  • 训练模型

    num_epochs = 10

    for epoch in range(num_epochs):

    复制代码
      alexnet_model.train()
    
      running_loss = 0.0
    
      for images, labels in data_loader:
    
          images, labels = images.to(device), labels.to(device)

前向传播

复制代码
        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

复制代码
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

复制代码
    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')
相关推荐
MATLAB代码顾问4 小时前
5大智能算法优化标准测试函数对比(Python实现)
开发语言·python
ting94520004 小时前
Tornado 全栈技术深度指南:从原理到实战
人工智能·python·架构·tornado
果汁华4 小时前
Browserbase Skills:让 Claude Agent 真正“看见“网页世界
人工智能·python
ZhengEnCi4 小时前
04-缩放点积注意力代码实现 💻
人工智能·python
2zcode4 小时前
基于LSTM神经网络的金属材料机器学习本构模型研究(硕士级别)
神经网络·机器学习·lstm·金属材料
DeepReinforce5 小时前
三、AI量化投资:使用akshare获取A股主板20260430所有的涨停股票
python·量化·akshare·龙头战法
段一凡-华北理工大学5 小时前
【高炉炼铁领域炉温监测、预警、调控智能体设计与应用】~系列文章08:多模态数据融合:让数据更聪明
人工智能·python·高炉炼铁·ai赋能·工业智能体·高炉炉温
万粉变现经纪人5 小时前
如何解决 pip install llama-cpp-python 报错 未安装 CMake/Ninja 或 CPU 不支持 AVX 问题
开发语言·python·开源·aigc·pip·ai写作·llama
其实防守也摸鱼5 小时前
CTF密码学综合教学指南--第五章
开发语言·网络·笔记·python·安全·网络安全·密码学
callJJ6 小时前
Spring Data Redis 两种编程模型详解:同步 vs 响应式
java·spring boot·redis·python·spring