[Pytorch案例实践009]基于卷积神经网络和通道注意力机制的草莓生长阶段分类实战

一、项目介绍

项目概述

该项目是一个使用PyTorch构建的深度学习分类模型,用于图像分类任务。模型采用了一种简单的卷积神经网络(CNN),并可以根据配置启用通道注意力机制以提高性能。项目的主要组成部分包括数据预处理、模型定义、训练流程、评估指标可视化以及最终的图像分类应用。

数据集

  • 训练集 :位于 I:\code\pytorch\cnn_channel_attention\datasets\train
  • 验证集 :位于 I:\code\pytorch\cnn_channel_attention\datasets\val。(训练集和验证集记得改为自己电脑的路径,建议也使用绝对路径)
  • 预处理:使用了标准化和缩放操作,将图像大小统一为224x224像素。

模型架构

  • 卷积层:两个卷积层,每个后面跟着批量归一化(Batch Normalization)和ReLU激活函数。
  • 注意力机制(可选):使用了一个简单的通道注意力机制,通过全局平均池化和1x1卷积来生成注意力权重。
  • 全连接层:两个全连接层,用于从卷积特征映射到类别标签。

训练过程

  • 损失函数:交叉熵损失(Cross-Entropy Loss)。
  • 优化器:随机梯度下降(SGD)。
  • 训练策略:使用了批量训练(batch training),并记录了训练损失和验证损失,以及训练准确率和验证准确率。
  • 可视化:训练结束后,绘制了损失和准确率的学习曲线,并保存了这些曲线。

推理应用

  • 模型加载:从训练好的模型文件中加载模型。
  • 图像处理:使用与训练时相同的预处理步骤处理输入图像。
  • 预测:使用加载的模型对图像进行分类,并获取预测类别和置信度。
  • 结果展示:在图像上绘制类别和置信度,并保存结果图像到指定文件夹。

关键点

  • 模型灵活性 :模型可以通过设置 use_attention 参数来选择是否使用注意力机制。
  • 性能监控:通过绘制学习曲线来监控训练过程中的损失和准确率。
  • 应用扩展性:推理脚本可以处理单个图像或整个文件夹中的多个图像。

二、数据集介绍

数据集来源于2021 科大讯飞开发者大赛 (农作物生长情况识别挑战赛)开源数据集,链接目前是百度飞桨官网,数据集下载链接中国农业大学_农作物生长情况识别挑战赛_数据集_数据集-飞桨AI Studio星河社区 (baidu.com)

通过作物不同生长时期的特点可以对作物的生长情况进行识别,给出合理的作物生长阶段。本次大赛提供了大量植株在营养生长阶段的生长情况图片作为样本,参赛选手需基于提供的样本构建模型,对样本生长态势进行检测,判断其生长情况,并将生长情况在csv文件中对应标定出来,给出图像对应的生长阶段。

本次实验对数据集进行了处理,已经划分为训练集验证集和测试集,训练集和验证集下面四个文件夹,草莓幼苗期,开花期,挂果期和成熟期,简单粗暴,直接用拼音拼写的,没用对应英文。

三、通道注意力机制

这里介绍这篇博文,通俗易懂,建议学习通俗易懂理解通道注意力机制(CAM)与空间注意力机制(SAM)-CSDN博客

通道注意力机制是一种用于卷积神经网络(CNNs)的注意力机制,它通过调整不同特征图的重要性来增强网络的表现力。这种机制可以显著提高模型的性能,尤其是在处理视觉任务如图像分类、目标检测等时。

通道注意力机制原理

通道注意力机制的基本思想是让网络学会关注某些特征图而忽略其他特征图。这通常通过全局上下文信息来实现,即网络通过计算所有位置的特征来获得一个全局的特征表示,然后基于该表示生成每个通道的权重。

具体步骤:
  1. 全局平均池化:对每个特征图进行全局平均池化操作,得到每个特征图的一个标量值。
  2. 全连接层:将这些标量值传递给一个或多个全连接层,产生通道权重。
  3. 激活函数:通常使用非线性激活函数(如ReLU、Sigmoid等)来生成最终的通道权重。
  4. 权重乘法:将生成的权重与原始特征图相乘,以调整每个特征图的重要性。

数学公式:

四、代码

训练代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


# 导入之前定义的 CNN 网络
from cnn_model import CNN

# 定义超参数
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001
USE_ATTENTION = True  # 设置是否使用通道注意力机制

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, train_loader, val_loader, num_epochs, learning_rate, use_attention):
    model.to(device)  # 将模型移动到 GPU 或 CPU
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    best_val_accuracy = 0.0
    best_model_state = None

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0

        with tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)  # 移动数据到 GPU 或 CPU

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

                pbar.set_postfix({'Train Loss': train_loss / len(train_loader), 'Train Acc': correct_train / total_train})

        train_losses.append(train_loss / len(train_loader))
        train_accuracies.append(correct_train / total_train)

        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_losses.append(val_loss / len(val_loader))
        val_accuracy = correct_val / total_val
        val_accuracies.append(val_accuracy)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()

    # 绘制损失和准确率曲线
    plot_curves(train_losses, val_losses, train_accuracies, val_accuracies)


    # 保存最好的模型
    save_path = 'best_model.pth'
    torch.save(best_model_state, save_path)
    print(f"Best model saved at {save_path}")




def plot_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Val Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.tight_layout()  # 调整布局以适应标题
    plt.savefig('loss_and_accuracy_curves.png')  # 保存图表到文件
    plt.show()  # 显示图表




if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((224, 224))
    ])

    train_dataset = datasets.ImageFolder(root=r'I:\code\pytorch\cnn_channel_attention\datasets\train', transform=transform)
    val_dataset = datasets.ImageFolder(root=r'I:\code\pytorch\cnn_channel_attention\datasets\val', transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    num_classes = len(train_dataset.classes)
    model = CNN(num_classes, USE_ATTENTION)

    train(model, train_loader, val_loader, NUM_EPOCHS, LEARNING_RATE, USE_ATTENTION)

测试代码:

python 复制代码
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import os
from cnn_model import CNN

def load_model(model_path, num_classes, device):
    model = CNN(num_classes=4, use_attention=True)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict(model, image, device, transform):
    image = transform(image).unsqueeze(0).to(device)
    outputs = model(image)
    _, preds = torch.max(outputs, 1)
    confidence = nn.functional.softmax(outputs, dim=1)[0][preds].item()
    return preds.item(), confidence


def draw_label(image, label, confidence):
    draw = ImageDraw.Draw(image)

    # 使用更大的字体
    try:
        font = ImageFont.truetype("arial", 30)  # 使用 Arial 字体,字号为 36
    except IOError:
        font = ImageFont.load_default()  # 如果 Arial 字体不可用,使用默认字体

    text = f"{label}: {confidence:.2f}"

    # 使用 textbbox 获取文本边界框
    text_bbox = draw.textbbox((20, 20), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # 矩形背景框
    position = (20, 20)
    draw.rectangle([position, (position[0] + text_width, position[1] + text_height)], fill="black")
    draw.text(position, text, fill="white", font=font)
    return image


def process_image(model, image_path, output_dir, device, transform, class_names):
    image = Image.open(image_path).convert("RGB")
    pred, confidence = predict(model, image, device, transform)
    label = class_names[pred]
    image_with_label = draw_label(image, label, confidence)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    image_with_label.save(output_path)
    image_with_label.show()  # 显示处理后的图像

def process_folder(model, folder_path, output_dir, device, transform, class_names):
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            image_path = os.path.join(folder_path, filename)
            process_image(model, image_path, output_dir, device, transform, class_names)

def main():
    # 硬编码参数
    model_path = r'I:\code\pytorch\cnn_channel_attention\best_model.pth'  # 模型权重文件路径
    input_path = r'I:\code\pytorch\cnn_channel_attention\datasets\test\testA\test_13.jpg'  # 输入图片或文件夹路径
    output_dir = r'I:\code\pytorch\cnn_channel_attention\result'  # 输出保存路径
    num_classes = 4  # 分类任务的类别数
    class_names = ['chengshu', 'guaguo', 'kaihua', 'youmiao']  # 类别名称

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, num_classes, device)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((224, 224))
    ])

    if os.path.isfile(input_path):
        process_image(model, input_path, output_dir, device, transform, class_names)
    elif os.path.isdir(input_path):
        process_folder(model, input_path, output_dir, device, transform, class_names)
    else:
        print("Invalid input path. Must be a file or directory.")

if __name__ == "__main__":
    main()

五、结果展示

模型迭代50轮之后,在训练集和验证集上损失函数以及准确率曲线如下:

测试集上的测试结果:

六、总结

此次我们更换了新的数据集,构建了基于通道注意力机制的卷积神经网络用于草莓生长阶段的识别,效果较好。希望能通过这些实际的例子,逐步掌握python和pytorch的使用,在实践中学习和进步,养成独立思考的能力,求知若渴,虚心若愚,一起共勉。

相关推荐
农民小飞侠1 分钟前
python AutoGen接入开源模型xLAM-7b-fc-r,测试function calling的功能
开发语言·python
战神刘玉栋4 分钟前
《程序猿之设计模式实战 · 观察者模式》
python·观察者模式·设计模式
敲代码不忘补水6 分钟前
Python 项目实践:简单的计算器
开发语言·python·json·项目实践
WPG大大通10 分钟前
有奖直播 | onsemi IPM 助力汽车电气革命及电子化时代冷热管理
大数据·人工智能·汽车·方案·电气·大大通·研讨会
百锦再12 分钟前
AI对汽车行业的冲击和比亚迪新能源汽车市场占比
人工智能·汽车
ws20190715 分钟前
抓机遇,促发展——2025第十二届广州国际汽车零部件加工技术及汽车模具展览会
大数据·人工智能·汽车
Zhangci]19 分钟前
Opencv图像预处理(三)
人工智能·opencv·计算机视觉
新加坡内哥谈技术36 分钟前
口哨声、歌声、boing声和biotwang声:用AI识别鲸鱼叫声
人工智能·自然语言处理
wx7408513261 小时前
小琳AI课堂:机器学习
人工智能·机器学习
FL16238631291 小时前
[数据集][目标检测]车油口挡板开关闭合检测数据集VOC+YOLO格式138张2类别
人工智能·yolo·目标检测