[Pytorch案例实践009]基于卷积神经网络和通道注意力机制的草莓生长阶段分类实战

一、项目介绍

项目概述

该项目是一个使用PyTorch构建的深度学习分类模型,用于图像分类任务。模型采用了一种简单的卷积神经网络(CNN),并可以根据配置启用通道注意力机制以提高性能。项目的主要组成部分包括数据预处理、模型定义、训练流程、评估指标可视化以及最终的图像分类应用。

数据集

  • 训练集 :位于 I:\code\pytorch\cnn_channel_attention\datasets\train
  • 验证集 :位于 I:\code\pytorch\cnn_channel_attention\datasets\val。(训练集和验证集记得改为自己电脑的路径,建议也使用绝对路径)
  • 预处理:使用了标准化和缩放操作,将图像大小统一为224x224像素。

模型架构

  • 卷积层:两个卷积层,每个后面跟着批量归一化(Batch Normalization)和ReLU激活函数。
  • 注意力机制(可选):使用了一个简单的通道注意力机制,通过全局平均池化和1x1卷积来生成注意力权重。
  • 全连接层:两个全连接层,用于从卷积特征映射到类别标签。

训练过程

  • 损失函数:交叉熵损失(Cross-Entropy Loss)。
  • 优化器:随机梯度下降(SGD)。
  • 训练策略:使用了批量训练(batch training),并记录了训练损失和验证损失,以及训练准确率和验证准确率。
  • 可视化:训练结束后,绘制了损失和准确率的学习曲线,并保存了这些曲线。

推理应用

  • 模型加载:从训练好的模型文件中加载模型。
  • 图像处理:使用与训练时相同的预处理步骤处理输入图像。
  • 预测:使用加载的模型对图像进行分类,并获取预测类别和置信度。
  • 结果展示:在图像上绘制类别和置信度,并保存结果图像到指定文件夹。

关键点

  • 模型灵活性 :模型可以通过设置 use_attention 参数来选择是否使用注意力机制。
  • 性能监控:通过绘制学习曲线来监控训练过程中的损失和准确率。
  • 应用扩展性:推理脚本可以处理单个图像或整个文件夹中的多个图像。

二、数据集介绍

数据集来源于2021 科大讯飞开发者大赛 (农作物生长情况识别挑战赛)开源数据集,链接目前是百度飞桨官网,数据集下载链接中国农业大学_农作物生长情况识别挑战赛_数据集_数据集-飞桨AI Studio星河社区 (baidu.com)

通过作物不同生长时期的特点可以对作物的生长情况进行识别,给出合理的作物生长阶段。本次大赛提供了大量植株在营养生长阶段的生长情况图片作为样本,参赛选手需基于提供的样本构建模型,对样本生长态势进行检测,判断其生长情况,并将生长情况在csv文件中对应标定出来,给出图像对应的生长阶段。

本次实验对数据集进行了处理,已经划分为训练集验证集和测试集,训练集和验证集下面四个文件夹,草莓幼苗期,开花期,挂果期和成熟期,简单粗暴,直接用拼音拼写的,没用对应英文。

三、通道注意力机制

这里介绍这篇博文,通俗易懂,建议学习通俗易懂理解通道注意力机制(CAM)与空间注意力机制(SAM)-CSDN博客

通道注意力机制是一种用于卷积神经网络(CNNs)的注意力机制,它通过调整不同特征图的重要性来增强网络的表现力。这种机制可以显著提高模型的性能,尤其是在处理视觉任务如图像分类、目标检测等时。

通道注意力机制原理

通道注意力机制的基本思想是让网络学会关注某些特征图而忽略其他特征图。这通常通过全局上下文信息来实现,即网络通过计算所有位置的特征来获得一个全局的特征表示,然后基于该表示生成每个通道的权重。

具体步骤:
  1. 全局平均池化:对每个特征图进行全局平均池化操作,得到每个特征图的一个标量值。
  2. 全连接层:将这些标量值传递给一个或多个全连接层,产生通道权重。
  3. 激活函数:通常使用非线性激活函数(如ReLU、Sigmoid等)来生成最终的通道权重。
  4. 权重乘法:将生成的权重与原始特征图相乘,以调整每个特征图的重要性。

数学公式:

四、代码

训练代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


# 导入之前定义的 CNN 网络
from cnn_model import CNN

# 定义超参数
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001
USE_ATTENTION = True  # 设置是否使用通道注意力机制

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, train_loader, val_loader, num_epochs, learning_rate, use_attention):
    model.to(device)  # 将模型移动到 GPU 或 CPU
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    best_val_accuracy = 0.0
    best_model_state = None

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0

        with tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)  # 移动数据到 GPU 或 CPU

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

                pbar.set_postfix({'Train Loss': train_loss / len(train_loader), 'Train Acc': correct_train / total_train})

        train_losses.append(train_loss / len(train_loader))
        train_accuracies.append(correct_train / total_train)

        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_losses.append(val_loss / len(val_loader))
        val_accuracy = correct_val / total_val
        val_accuracies.append(val_accuracy)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()

    # 绘制损失和准确率曲线
    plot_curves(train_losses, val_losses, train_accuracies, val_accuracies)


    # 保存最好的模型
    save_path = 'best_model.pth'
    torch.save(best_model_state, save_path)
    print(f"Best model saved at {save_path}")




def plot_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Val Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.tight_layout()  # 调整布局以适应标题
    plt.savefig('loss_and_accuracy_curves.png')  # 保存图表到文件
    plt.show()  # 显示图表




if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((224, 224))
    ])

    train_dataset = datasets.ImageFolder(root=r'I:\code\pytorch\cnn_channel_attention\datasets\train', transform=transform)
    val_dataset = datasets.ImageFolder(root=r'I:\code\pytorch\cnn_channel_attention\datasets\val', transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    num_classes = len(train_dataset.classes)
    model = CNN(num_classes, USE_ATTENTION)

    train(model, train_loader, val_loader, NUM_EPOCHS, LEARNING_RATE, USE_ATTENTION)

测试代码:

python 复制代码
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import os
from cnn_model import CNN

def load_model(model_path, num_classes, device):
    model = CNN(num_classes=4, use_attention=True)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict(model, image, device, transform):
    image = transform(image).unsqueeze(0).to(device)
    outputs = model(image)
    _, preds = torch.max(outputs, 1)
    confidence = nn.functional.softmax(outputs, dim=1)[0][preds].item()
    return preds.item(), confidence


def draw_label(image, label, confidence):
    draw = ImageDraw.Draw(image)

    # 使用更大的字体
    try:
        font = ImageFont.truetype("arial", 30)  # 使用 Arial 字体,字号为 36
    except IOError:
        font = ImageFont.load_default()  # 如果 Arial 字体不可用,使用默认字体

    text = f"{label}: {confidence:.2f}"

    # 使用 textbbox 获取文本边界框
    text_bbox = draw.textbbox((20, 20), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # 矩形背景框
    position = (20, 20)
    draw.rectangle([position, (position[0] + text_width, position[1] + text_height)], fill="black")
    draw.text(position, text, fill="white", font=font)
    return image


def process_image(model, image_path, output_dir, device, transform, class_names):
    image = Image.open(image_path).convert("RGB")
    pred, confidence = predict(model, image, device, transform)
    label = class_names[pred]
    image_with_label = draw_label(image, label, confidence)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    image_with_label.save(output_path)
    image_with_label.show()  # 显示处理后的图像

def process_folder(model, folder_path, output_dir, device, transform, class_names):
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            image_path = os.path.join(folder_path, filename)
            process_image(model, image_path, output_dir, device, transform, class_names)

def main():
    # 硬编码参数
    model_path = r'I:\code\pytorch\cnn_channel_attention\best_model.pth'  # 模型权重文件路径
    input_path = r'I:\code\pytorch\cnn_channel_attention\datasets\test\testA\test_13.jpg'  # 输入图片或文件夹路径
    output_dir = r'I:\code\pytorch\cnn_channel_attention\result'  # 输出保存路径
    num_classes = 4  # 分类任务的类别数
    class_names = ['chengshu', 'guaguo', 'kaihua', 'youmiao']  # 类别名称

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, num_classes, device)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((224, 224))
    ])

    if os.path.isfile(input_path):
        process_image(model, input_path, output_dir, device, transform, class_names)
    elif os.path.isdir(input_path):
        process_folder(model, input_path, output_dir, device, transform, class_names)
    else:
        print("Invalid input path. Must be a file or directory.")

if __name__ == "__main__":
    main()

五、结果展示

模型迭代50轮之后,在训练集和验证集上损失函数以及准确率曲线如下:

测试集上的测试结果:

六、总结

此次我们更换了新的数据集,构建了基于通道注意力机制的卷积神经网络用于草莓生长阶段的识别,效果较好。希望能通过这些实际的例子,逐步掌握python和pytorch的使用,在实践中学习和进步,养成独立思考的能力,求知若渴,虚心若愚,一起共勉。

相关推荐
字节数据平台几秒前
火山引擎数据飞轮探索零售企业大促新场景:下放营销活动权限
大数据·人工智能
啊啊啊六子13 分钟前
windows下安装wsl的ubuntu,同时配置深度学习环境
windows·深度学习·ubuntu
刘天远29 分钟前
django实现paypal订阅记录
后端·python·django
努力学习的啊张32 分钟前
消息称三星正与 OpenAI 洽谈,有望令 Galaxy AI 整合ChatGPT,三星都要和chatgpt合作了,你会使用chatgpt了吗?
人工智能·chatgpt
Together_CZ32 分钟前
GPT-4 Technical Report——GPT-4技术报告
人工智能·gpt-4
菜鸟小贤贤2 小时前
python+pytest+allure利用fix实现接口关联
python·macos·自动化·pytest
huaqianzkh2 小时前
人工智能大趋势下软件开发的未来
人工智能
vvvae12342 小时前
Python 网络爬虫操作指南
python
years_GG2 小时前
【Git多人开发与协作之团队的环境搭建】
spring boot·深度学习·vue·github·团队开发·个人开发
不灭蚊香2 小时前
神经网络归一化方法总结
深度学习·神经网络·in·归一化·gn·ln·bn