pytorch利用简单CNN实现葡萄病虫害图片识别

1 前言

之前我开发了一个葡萄病虫害的可视化系统,最近就想给这个系统增加2个功能,一个是对接一个AI助手,可以进行葡萄病虫害的咨询,直接对接千问大模型,这个在之前的博文里已经介绍过对接方法了,第二个是做一个根据图片识别病虫害(分类)的功能。

2 实现思路

实现思路是想通过pytorch做一个CNN模型的训练,然后根据给出的图片进行类型的预测。

3 数据集

我没有数据集,仅有的一些图片是之前委托我做程序的bro给的,所以我们训练的时候图片并不多,不过这个没关系,数据集可以后期扩充,目前先实现功能部分

4 安装依赖

该功能由python语言实现,使用pip 安装如下依赖

复制代码
torch
torchvision
matplotlib

5 数据位置

数据类似这样去组织,一种类型建一个文件夹,然后同一类型的图片放一起。

6 训练模型

python 复制代码
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 调整图片大小
    transforms.ToTensor(),            # 转换为 Tensor
])

# 加载数据集
data_dir = 'dataset'
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# 获取类别标签
class_names = dataset.classes
num_classes = len(class_names)

# 构建简单的 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)  # 128 = (128/2)*(128/2)*(32/2)*(32/2)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 32 * 32)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model = SimpleCNN(num_classes)

# 训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(torch.cuda.is_available())
print(f'Using device: {device}')

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练循环
num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}")

print("Training finished.")

# 保存模型
torch.save(model.state_dict(), 'plant_disease_model.pth')

执行代码之后得到模型文件:

7 预测模型

然后我们随便去找些病虫害图片,来做预测

python 复制代码
import torch
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.nn.functional as F

# 定义简单的 CNN 模型结构
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 32 * 32)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 预测函数
def predict(image_path, model, class_names):
    # 定义图像预处理
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # 统一大小
        transforms.ToTensor(),
    ])

    # 加载和预处理图像
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0)  # 增加批次维度

    # 将图像输入模型进行预测
    model.eval()  # 设置模型为评估模式
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)

    # 返回预测的类别
    return class_names[predicted.item()]

if __name__ == "__main__":
    # 加载训练好的模型
    num_classes = 2  # 根据你的数据集类别数量修改
    model = SimpleCNN(num_classes)
    model.load_state_dict(torch.load('plant_disease_model.pth'))
    model.eval()

    # 类别名称(根据你的数据集修改)
    class_names = ['disease1', 'disease2']  # 替换为实际类别名称

    # 测试预测
    test_image_path = '1.jpg'  # 替换为测试图像的路径
    predicted_class = predict(test_image_path, model, class_names)
    print(f'Predicted class: {predicted_class}')

8 结果

给出的图片和图片预测结果如下:

相关推荐
肥猪猪爸2 分钟前
计算机视觉中的Mask是干啥的
图像处理·人工智能·深度学习·神经网络·目标检测·计算机视觉·视觉检测
yiersansiwu123d11 分钟前
工程化破局,2025年AI大模型的价值兑现之路
人工智能
_codemonster13 分钟前
自然语言处理容易混淆知识点(六)SentenceTransformer的训练参数
人工智能·自然语言处理
All The Way North-15 分钟前
PyTorch ExponentialLR:按指数学习率衰减原理、API、参数详解、实战
pytorch·深度学习·学习率优化算法·按指数学习率衰减
2501_9413297215 分钟前
基于DETR的血细胞显微图像检测与分类方法研究【原创】_1
人工智能·数据挖掘
人工智能训练19 分钟前
Docker Desktop WSL 集成配置宝典:选项拆解 + 精准设置指南
linux·运维·服务器·人工智能·docker·容器·ai编程
golang学习记28 分钟前
VS Code使用 GitHub Copilot 高效重构代码:10 大实战技巧 + 自定义指令封装指南
人工智能
阿杰学AI31 分钟前
AI核心知识62——大语言模型之PRM (简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·prm·过程奖励模型
TG:@yunlaoda360 云老大31 分钟前
华为云国际站代理商的CBR支持哪些云服务备份?
网络·人工智能·华为云
用户51914958484534 分钟前
基础设施模板CLI工具:Boilerplates
人工智能·aigc