[Pytorch案例实践009]基于卷积神经网络和通道注意力机制的草莓生长阶段分类实战

一、项目介绍

项目概述

该项目是一个使用PyTorch构建的深度学习分类模型,用于图像分类任务。模型采用了一种简单的卷积神经网络(CNN),并可以根据配置启用通道注意力机制以提高性能。项目的主要组成部分包括数据预处理、模型定义、训练流程、评估指标可视化以及最终的图像分类应用。

数据集

  • 训练集 :位于 I:\code\pytorch\cnn_channel_attention\datasets\train
  • 验证集 :位于 I:\code\pytorch\cnn_channel_attention\datasets\val。(训练集和验证集记得改为自己电脑的路径,建议也使用绝对路径)
  • 预处理:使用了标准化和缩放操作,将图像大小统一为224x224像素。

模型架构

  • 卷积层:两个卷积层,每个后面跟着批量归一化(Batch Normalization)和ReLU激活函数。
  • 注意力机制(可选):使用了一个简单的通道注意力机制,通过全局平均池化和1x1卷积来生成注意力权重。
  • 全连接层:两个全连接层,用于从卷积特征映射到类别标签。

训练过程

  • 损失函数:交叉熵损失(Cross-Entropy Loss)。
  • 优化器:随机梯度下降(SGD)。
  • 训练策略:使用了批量训练(batch training),并记录了训练损失和验证损失,以及训练准确率和验证准确率。
  • 可视化:训练结束后,绘制了损失和准确率的学习曲线,并保存了这些曲线。

推理应用

  • 模型加载:从训练好的模型文件中加载模型。
  • 图像处理:使用与训练时相同的预处理步骤处理输入图像。
  • 预测:使用加载的模型对图像进行分类,并获取预测类别和置信度。
  • 结果展示:在图像上绘制类别和置信度,并保存结果图像到指定文件夹。

关键点

  • 模型灵活性 :模型可以通过设置 use_attention 参数来选择是否使用注意力机制。
  • 性能监控:通过绘制学习曲线来监控训练过程中的损失和准确率。
  • 应用扩展性:推理脚本可以处理单个图像或整个文件夹中的多个图像。

二、数据集介绍

数据集来源于2021 科大讯飞开发者大赛 (农作物生长情况识别挑战赛)开源数据集,链接目前是百度飞桨官网,数据集下载链接中国农业大学_农作物生长情况识别挑战赛_数据集_数据集-飞桨AI Studio星河社区 (baidu.com)

通过作物不同生长时期的特点可以对作物的生长情况进行识别,给出合理的作物生长阶段。本次大赛提供了大量植株在营养生长阶段的生长情况图片作为样本,参赛选手需基于提供的样本构建模型,对样本生长态势进行检测,判断其生长情况,并将生长情况在csv文件中对应标定出来,给出图像对应的生长阶段。

本次实验对数据集进行了处理,已经划分为训练集验证集和测试集,训练集和验证集下面四个文件夹,草莓幼苗期,开花期,挂果期和成熟期,简单粗暴,直接用拼音拼写的,没用对应英文。

三、通道注意力机制

这里介绍这篇博文,通俗易懂,建议学习通俗易懂理解通道注意力机制(CAM)与空间注意力机制(SAM)-CSDN博客

通道注意力机制是一种用于卷积神经网络(CNNs)的注意力机制,它通过调整不同特征图的重要性来增强网络的表现力。这种机制可以显著提高模型的性能,尤其是在处理视觉任务如图像分类、目标检测等时。

通道注意力机制原理

通道注意力机制的基本思想是让网络学会关注某些特征图而忽略其他特征图。这通常通过全局上下文信息来实现,即网络通过计算所有位置的特征来获得一个全局的特征表示,然后基于该表示生成每个通道的权重。

具体步骤:
  1. 全局平均池化:对每个特征图进行全局平均池化操作,得到每个特征图的一个标量值。
  2. 全连接层:将这些标量值传递给一个或多个全连接层,产生通道权重。
  3. 激活函数:通常使用非线性激活函数(如ReLU、Sigmoid等)来生成最终的通道权重。
  4. 权重乘法:将生成的权重与原始特征图相乘,以调整每个特征图的重要性。

数学公式:

四、代码

训练代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


# 导入之前定义的 CNN 网络
from cnn_model import CNN

# 定义超参数
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001
USE_ATTENTION = True  # 设置是否使用通道注意力机制

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, train_loader, val_loader, num_epochs, learning_rate, use_attention):
    model.to(device)  # 将模型移动到 GPU 或 CPU
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    best_val_accuracy = 0.0
    best_model_state = None

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0

        with tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)  # 移动数据到 GPU 或 CPU

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

                pbar.set_postfix({'Train Loss': train_loss / len(train_loader), 'Train Acc': correct_train / total_train})

        train_losses.append(train_loss / len(train_loader))
        train_accuracies.append(correct_train / total_train)

        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_losses.append(val_loss / len(val_loader))
        val_accuracy = correct_val / total_val
        val_accuracies.append(val_accuracy)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()

    # 绘制损失和准确率曲线
    plot_curves(train_losses, val_losses, train_accuracies, val_accuracies)


    # 保存最好的模型
    save_path = 'best_model.pth'
    torch.save(best_model_state, save_path)
    print(f"Best model saved at {save_path}")




def plot_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Val Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.tight_layout()  # 调整布局以适应标题
    plt.savefig('loss_and_accuracy_curves.png')  # 保存图表到文件
    plt.show()  # 显示图表




if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((224, 224))
    ])

    train_dataset = datasets.ImageFolder(root=r'I:\code\pytorch\cnn_channel_attention\datasets\train', transform=transform)
    val_dataset = datasets.ImageFolder(root=r'I:\code\pytorch\cnn_channel_attention\datasets\val', transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    num_classes = len(train_dataset.classes)
    model = CNN(num_classes, USE_ATTENTION)

    train(model, train_loader, val_loader, NUM_EPOCHS, LEARNING_RATE, USE_ATTENTION)

测试代码:

python 复制代码
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import os
from cnn_model import CNN

def load_model(model_path, num_classes, device):
    model = CNN(num_classes=4, use_attention=True)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict(model, image, device, transform):
    image = transform(image).unsqueeze(0).to(device)
    outputs = model(image)
    _, preds = torch.max(outputs, 1)
    confidence = nn.functional.softmax(outputs, dim=1)[0][preds].item()
    return preds.item(), confidence


def draw_label(image, label, confidence):
    draw = ImageDraw.Draw(image)

    # 使用更大的字体
    try:
        font = ImageFont.truetype("arial", 30)  # 使用 Arial 字体,字号为 36
    except IOError:
        font = ImageFont.load_default()  # 如果 Arial 字体不可用,使用默认字体

    text = f"{label}: {confidence:.2f}"

    # 使用 textbbox 获取文本边界框
    text_bbox = draw.textbbox((20, 20), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # 矩形背景框
    position = (20, 20)
    draw.rectangle([position, (position[0] + text_width, position[1] + text_height)], fill="black")
    draw.text(position, text, fill="white", font=font)
    return image


def process_image(model, image_path, output_dir, device, transform, class_names):
    image = Image.open(image_path).convert("RGB")
    pred, confidence = predict(model, image, device, transform)
    label = class_names[pred]
    image_with_label = draw_label(image, label, confidence)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    image_with_label.save(output_path)
    image_with_label.show()  # 显示处理后的图像

def process_folder(model, folder_path, output_dir, device, transform, class_names):
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            image_path = os.path.join(folder_path, filename)
            process_image(model, image_path, output_dir, device, transform, class_names)

def main():
    # 硬编码参数
    model_path = r'I:\code\pytorch\cnn_channel_attention\best_model.pth'  # 模型权重文件路径
    input_path = r'I:\code\pytorch\cnn_channel_attention\datasets\test\testA\test_13.jpg'  # 输入图片或文件夹路径
    output_dir = r'I:\code\pytorch\cnn_channel_attention\result'  # 输出保存路径
    num_classes = 4  # 分类任务的类别数
    class_names = ['chengshu', 'guaguo', 'kaihua', 'youmiao']  # 类别名称

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, num_classes, device)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((224, 224))
    ])

    if os.path.isfile(input_path):
        process_image(model, input_path, output_dir, device, transform, class_names)
    elif os.path.isdir(input_path):
        process_folder(model, input_path, output_dir, device, transform, class_names)
    else:
        print("Invalid input path. Must be a file or directory.")

if __name__ == "__main__":
    main()

五、结果展示

模型迭代50轮之后,在训练集和验证集上损失函数以及准确率曲线如下:

测试集上的测试结果:

六、总结

此次我们更换了新的数据集,构建了基于通道注意力机制的卷积神经网络用于草莓生长阶段的识别,效果较好。希望能通过这些实际的例子,逐步掌握python和pytorch的使用,在实践中学习和进步,养成独立思考的能力,求知若渴,虚心若愚,一起共勉。

相关推荐
神奇夜光杯6 分钟前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠8 分钟前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
千天夜17 分钟前
使用UDP协议传输视频流!(分片、缓存)
python·网络协议·udp·视频流
Debroon18 分钟前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
测试界的酸菜鱼21 分钟前
Python 大数据展示屏实例
大数据·开发语言·python
羊小猪~~25 分钟前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨26 分钟前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测
晨曦_子画30 分钟前
编程语言之战:AI 之后的 Kotlin 与 Java
android·java·开发语言·人工智能·kotlin
道可云32 分钟前
道可云人工智能&元宇宙每日资讯|2024国际虚拟现实创新大会将在青岛举办
大数据·人工智能·3d·机器人·ar·vr
人工智能培训咨询叶梓41 分钟前
探索开放资源上指令微调语言模型的现状
人工智能·语言模型·自然语言处理·性能优化·调优·大模型微调·指令微调