【无标题】

导入必要库

import torch # PyTorch核心库

import torch.nn as nn # 神经网络层

from torch.utils.data import DataLoader # 数据加载器(批量处理数据)

from torchvision import datasets, transforms, models # 数据集/数据增强/预训练模型

import os # 路径操作

---------------------- 1. 核心配置参数(新手可改这里) ----------------------

DEVICE = torch.device("cpu") # 无GPU则用CPU;有GPU可改为"cuda"

BATCH_SIZE = 8 # 每次处理8张图片(小批量,避免内存不足)

EPOCHS = 20 # 训练20轮(轮数越多,模型越准,但易过拟合)

DATA_DIR = "dataset" # 数据集路径(和项目结构对应)

SAVE_PATH = "plant_model.pth" # 训练好的模型保存路径

---------------------- 2. 数据预处理(关键!统一输入格式) ----------------------

定义训练集/验证集的预处理规则

---------------------- 2. 数据预处理(新增AI图像优化) ----------------------

data_transforms = {

"train": transforms.Compose([ # 训练集:AI图像增强

transforms.Resize((224, 224)),

transforms.RandomHorizontalFlip(),

新增AI级图像优化:随机调整对比度/亮度(增强叶片特征)

transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

]),

"val": transforms.Compose([ # 验证集:基础优化

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

])

}

加载数据集(自动读取文件夹结构,生成标签)

image_datasets = {

x: datasets.ImageFolder(os.path.join(DATA_DIR, x), data_transforms[x])

for x in ["train", "val"]

}

创建数据加载器(批量读取+打乱数据)

dataloaders = {

x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True)

for x in ["train", "val"]

}

获取类别名称(自动从文件夹名生成:['healthy', 'water_shortage', 'pest_disease'])

class_names = image_datasets["train"].classes

print("检测类别:", class_names) # 打印类别,确认数据集加载正确

---------------------- 3. 加载预训练模型(迁移学习,新手友好) ----------------------

加载MobileNetV2预训练模型(轻量级,适合电脑/嵌入式设备)

model = models.mobilenet_v2(pretrained=True)

冻结特征提取层(仅训练分类层,加快训练速度,避免过拟合)

for param in model.features.parameters():

param.requires_grad = False

修改最后一层:原模型输出1000类(ImageNet),改为3类(我们的数据集)

num_ftrs = model.classifier[1].in_features # 获取最后一层输入维度

model.classifier[1] = nn.Linear(num_ftrs, len(class_names)) # 这里len(class_names)是7

model = model.to(DEVICE) # 将模型移到CPU/GPU

---------------------- 4. 定义损失函数&优化器(模型学习的核心) ----------------------

criterion = nn.CrossEntropyLoss() # 交叉熵损失(适合分类任务)

优化器:Adam(自适应学习率,比SGD更优),只优化分类层参数

optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0001)

---------------------- 5. 训练函数(核心逻辑) ----------------------

def train_model(model, criterion, optimizer, num_epochs=20):

遍历每一轮训练

for epoch in range(num_epochs):

print(f"\nEpoch {epoch+1}/{num_epochs}")

print("-" * 20)

复制代码
    # 分训练和验证两个阶段
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()  # 训练模式:启用梯度下降
        else:
            model.eval()   # 验证模式:禁用梯度下降(只评估,不学习)

        running_loss = 0.0  # 累计损失
        running_corrects = 0  # 累计正确数

        # 遍历数据加载器中的每一批数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(DEVICE)  # 图片移到CPU/GPU
            labels = labels.to(DEVICE)  # 标签移到CPU/GPU

            # 梯度清零(避免上一轮梯度累积)
            optimizer.zero_grad()

            # 前向传播(预测)
            with torch.set_grad_enabled(phase == "train"):  # 训练阶段才计算梯度
                outputs = model(inputs)  # 模型输出(3个类的概率)
                _, preds = torch.max(outputs, 1)  # 取概率最大的类作为预测结果
                loss = criterion(outputs, labels)  # 计算损失(预测值vs真实值)

                # 训练阶段:反向传播+更新参数
                if phase == "train":
                    loss.backward()  # 反向传播(计算梯度)
                    optimizer.step()  # 更新模型参数

            # 统计损失和准确率
            running_loss += loss.item() * inputs.size(0)  # 累计损失
            running_corrects += torch.sum(preds == labels.data)  # 累计正确数

        # 计算本轮损失和准确率
        epoch_loss = running_loss / len(image_datasets[phase])
        epoch_acc = running_corrects.double() / len(image_datasets[phase])

        # 打印结果(直观看到训练效果)
        print(f"{phase} - 损失: {epoch_loss:.4f} | 准确率: {epoch_acc:.4f}")

# 训练完成,保存模型
torch.save(model.state_dict(), SAVE_PATH)
print(f"\n模型训练完成!已保存到:{SAVE_PATH}")
return model

启动训练(调用上面的函数)

if name == "main ":

train_model(model, criterion, optimizer, num_epochs=EPOCHS)

相关推荐
咚咚王者2 小时前
人工智能之核心基础 机器学习 第一章 基础概述
人工智能·机器学习
Secede.3 小时前
Windows + WSL2 + Docker + CudaToolkit:深度学习环境配置
windows·深度学习·docker
江上鹤.1484 小时前
Day 50 CBAM 注意力机制
人工智能·深度学习
人工智能培训4 小时前
深度学习—卷积神经网络(1)
人工智能·深度学习·神经网络·机器学习·cnn·知识图谱·dnn
云天徽上4 小时前
【机器学习】Kaggle案例之Rossmann连锁药店销售额预测:时间序列与机器学习完美融合的实战指南
机器学习·数据挖掘·kaggle
啊巴矲5 小时前
小白从零开始勇闯人工智能:机器学习初级篇(贝叶斯算法与SVM算法)
人工智能·机器学习·支持向量机
CoovallyAIHub5 小时前
纯视觉的终结?顶会趋势:不会联觉(多模态)的CV不是好AI
深度学习·算法·计算机视觉
懷淰メ5 小时前
python3GUI--基于深度学习的人脸识别管理系统(详细图文介绍)
人工智能·深度学习·人脸识别·pyqt·人脸·识别系统·人脸管理
CoovallyAIHub5 小时前
一文读懂大语言模型家族:LLM、MLLM、LMM、VLM核心概念全解析
深度学习·算法·计算机视觉