【无标题】

导入必要库

import torch # PyTorch核心库

import torch.nn as nn # 神经网络层

from torch.utils.data import DataLoader # 数据加载器(批量处理数据)

from torchvision import datasets, transforms, models # 数据集/数据增强/预训练模型

import os # 路径操作

---------------------- 1. 核心配置参数(新手可改这里) ----------------------

DEVICE = torch.device("cpu") # 无GPU则用CPU;有GPU可改为"cuda"

BATCH_SIZE = 8 # 每次处理8张图片(小批量,避免内存不足)

EPOCHS = 20 # 训练20轮(轮数越多,模型越准,但易过拟合)

DATA_DIR = "dataset" # 数据集路径(和项目结构对应)

SAVE_PATH = "plant_model.pth" # 训练好的模型保存路径

---------------------- 2. 数据预处理(关键!统一输入格式) ----------------------

定义训练集/验证集的预处理规则

---------------------- 2. 数据预处理(新增AI图像优化) ----------------------

data_transforms = {

"train": transforms.Compose([ # 训练集:AI图像增强

transforms.Resize((224, 224)),

transforms.RandomHorizontalFlip(),

新增AI级图像优化:随机调整对比度/亮度(增强叶片特征)

transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

]),

"val": transforms.Compose([ # 验证集:基础优化

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

])

}

加载数据集(自动读取文件夹结构,生成标签)

image_datasets = {

x: datasets.ImageFolder(os.path.join(DATA_DIR, x), data_transforms[x])

for x in ["train", "val"]

}

创建数据加载器(批量读取+打乱数据)

dataloaders = {

x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True)

for x in ["train", "val"]

}

获取类别名称(自动从文件夹名生成:['healthy', 'water_shortage', 'pest_disease'])

class_names = image_datasets["train"].classes

print("检测类别:", class_names) # 打印类别,确认数据集加载正确

---------------------- 3. 加载预训练模型(迁移学习,新手友好) ----------------------

加载MobileNetV2预训练模型(轻量级,适合电脑/嵌入式设备)

model = models.mobilenet_v2(pretrained=True)

冻结特征提取层(仅训练分类层,加快训练速度,避免过拟合)

for param in model.features.parameters():

param.requires_grad = False

修改最后一层:原模型输出1000类(ImageNet),改为3类(我们的数据集)

num_ftrs = model.classifier[1].in_features # 获取最后一层输入维度

model.classifier[1] = nn.Linear(num_ftrs, len(class_names)) # 这里len(class_names)是7

model = model.to(DEVICE) # 将模型移到CPU/GPU

---------------------- 4. 定义损失函数&优化器(模型学习的核心) ----------------------

criterion = nn.CrossEntropyLoss() # 交叉熵损失(适合分类任务)

优化器:Adam(自适应学习率,比SGD更优),只优化分类层参数

optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0001)

---------------------- 5. 训练函数(核心逻辑) ----------------------

def train_model(model, criterion, optimizer, num_epochs=20):

遍历每一轮训练

for epoch in range(num_epochs):

print(f"\nEpoch {epoch+1}/{num_epochs}")

print("-" * 20)

复制代码
    # 分训练和验证两个阶段
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()  # 训练模式:启用梯度下降
        else:
            model.eval()   # 验证模式:禁用梯度下降(只评估,不学习)

        running_loss = 0.0  # 累计损失
        running_corrects = 0  # 累计正确数

        # 遍历数据加载器中的每一批数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(DEVICE)  # 图片移到CPU/GPU
            labels = labels.to(DEVICE)  # 标签移到CPU/GPU

            # 梯度清零(避免上一轮梯度累积)
            optimizer.zero_grad()

            # 前向传播(预测)
            with torch.set_grad_enabled(phase == "train"):  # 训练阶段才计算梯度
                outputs = model(inputs)  # 模型输出(3个类的概率)
                _, preds = torch.max(outputs, 1)  # 取概率最大的类作为预测结果
                loss = criterion(outputs, labels)  # 计算损失(预测值vs真实值)

                # 训练阶段:反向传播+更新参数
                if phase == "train":
                    loss.backward()  # 反向传播(计算梯度)
                    optimizer.step()  # 更新模型参数

            # 统计损失和准确率
            running_loss += loss.item() * inputs.size(0)  # 累计损失
            running_corrects += torch.sum(preds == labels.data)  # 累计正确数

        # 计算本轮损失和准确率
        epoch_loss = running_loss / len(image_datasets[phase])
        epoch_acc = running_corrects.double() / len(image_datasets[phase])

        # 打印结果(直观看到训练效果)
        print(f"{phase} - 损失: {epoch_loss:.4f} | 准确率: {epoch_acc:.4f}")

# 训练完成,保存模型
torch.save(model.state_dict(), SAVE_PATH)
print(f"\n模型训练完成!已保存到:{SAVE_PATH}")
return model

启动训练(调用上面的函数)

if name == "main ":

train_model(model, criterion, optimizer, num_epochs=EPOCHS)

相关推荐
AI街潜水的八角19 小时前
深度学习烟叶病害分割系统3:含训练测试代码、数据集和GUI交互界面
人工智能·深度学习
AI街潜水的八角19 小时前
深度学习烟叶病害分割系统1:数据集说明(含下载链接)
人工智能·深度学习
weixin_4469340319 小时前
统计学中“in sample test”与“out of sample”有何区别?
人工智能·python·深度学习·机器学习·计算机视觉
莫非王土也非王臣19 小时前
循环神经网络
人工智能·rnn·深度学习
Lips61119 小时前
第五章 神经网络(含反向传播计算)
人工智能·深度学习·神经网络
wubba lubba dub dub75020 小时前
第三十三周 学习周报
学习·算法·机器学习
猫天意20 小时前
【深度学习小课堂】| torch | 升维打击还是原位拼接?深度解码 PyTorch 中 stack 与 cat 的几何奥义
开发语言·人工智能·pytorch·深度学习·神经网络·yolo·机器学习
cyyt20 小时前
深度学习周报(1.12~1.18)
人工智能·算法·机器学习
wuk99821 小时前
基于遗传算法优化BP神经网络实现非线性函数拟合
人工智能·深度学习·神经网络
白日做梦Q1 天前
深度学习中的正则化技术全景:从Dropout到权重衰减的优化逻辑
人工智能·深度学习