pytorch神经网络训练(AlexNet)

  • 导包

    import os

    import torch

    import torch.nn as nn

    import torch.optim as optim

    from torch.utils.data import Dataset, DataLoader

    from PIL import Image

    from torchvision import models, transforms

  • 定义自定义图像数据集

    class CustomImageDataset(Dataset):

定义一个自定义的图像数据集类,继承自Dataset

复制代码
def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

复制代码
        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

复制代码
        self.transform = transform

图像转换方法,用于对图像进行预处理

复制代码
        self.files = [] 

存储所有图像文件的路径

复制代码
        self.labels = [] 

存储所有图像的标签

复制代码
        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

复制代码
        for index, label in enumerate(os.listdir(main_dir)):

遍历主目录中的所有子目录

复制代码
          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

复制代码
           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

复制代码
def __len__(self):

定义数据集的长度

复制代码
        return len(self.files) 

返回文件列表的长度

复制代码
def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

复制代码
        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换

    transform = transforms.Compose([

    复制代码
      transforms.Resize((227, 227)),  # AlexNet的输入图像大小
    
      transforms.RandomHorizontalFlip(),  # 随机水平翻转
    
      transforms.RandomRotation(10),  # 随机旋转
    
      transforms.ToTensor(),
    
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

    ])

  • 创建数据集

    dataset = CustomImageDataset(main_dir="D:\图像处理、深度学习\flowers", transform=transform)

  • 创建数据加载器

    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

  • 加载预训练的AlexNet模型

    alexnet_model = models.alexnet(pretrained=True)

  • 修改最后几层以适应新的分类任务

    num_ftrs = alexnet_model.classifier[6].in_features

    alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

  • 定义损失函数和优化器

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型

    if torch.cuda.device_count() > 1:

    复制代码
      alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    alexnet_model.to(device)

  • 模型评估

    def evaluate_model(model, data_loader, device):

    复制代码
      model.eval()  # 将模型设置为评估模式
    
      correct = 0
    
      total = 0
    
      with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度
    
          for images, labels in data_loader:
    
              images, labels = images.to(device), labels.to(device)
    
              outputs = model(images)
    
              _, predicted = torch.max(outputs.data, 1)
    
              total += labels.size(0)
    
              correct += (predicted == labels).sum().item()
    
      accuracy = 100 * correct / total
    
      return accuracy
  • 训练模型

    num_epochs = 10

    for epoch in range(num_epochs):

    复制代码
      alexnet_model.train()
    
      running_loss = 0.0
    
      for images, labels in data_loader:
    
          images, labels = images.to(device), labels.to(device)

前向传播

复制代码
        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

复制代码
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

复制代码
    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')
相关推荐
大飞记Python1 天前
当GitHub不再纯粹:Python自动化测试的未来是AI还是危机?
python·github
eqwaak01 天前
Matplotlib 动画显示进阶:交互式控制、3D 动画与未来趋势
python·tcp/ip·3d·语言模型·matplotlib
GilgameshJSS1 天前
【学习K230-例程23】GT6700-音频FFT柱状图
python·学习·音视频
I'm a winner1 天前
第七章:AI进阶之------输入与输出函数(一)
开发语言·人工智能·python·深度学习·神经网络·microsoft·机器学习
ERP老兵_冷溪虎山1 天前
Python/JS/Go/Java同步学习(第十三篇)四语言“字符串转码解码“对照表: 财务“小南“纸式转码术处理凭证乱码崩溃(附源码/截图/参数表/避坑指南)
java·后端·python
独行soc1 天前
2025年渗透测试面试题总结-67(题目+回答)
网络·python·安全·web安全·网络安全·adb·渗透测试
eybk1 天前
用python的socket写一个局域网传输文件的程序
服务器·网络·python
程序员的世界你不懂1 天前
【Flask】实现一个前后端一体的项目-脚手架
后端·python·flask
花酒锄作田1 天前
[MCP][01]简介与概念
python·llm·mcp
Python私教1 天前
Django全栈班v1.04 Python基础语法 20250912 上午
后端·python·django