【无标题】

导入必要库

import torch # PyTorch核心库

import torch.nn as nn # 神经网络层

from torch.utils.data import DataLoader # 数据加载器(批量处理数据)

from torchvision import datasets, transforms, models # 数据集/数据增强/预训练模型

import os # 路径操作

---------------------- 1. 核心配置参数(新手可改这里) ----------------------

DEVICE = torch.device("cpu") # 无GPU则用CPU;有GPU可改为"cuda"

BATCH_SIZE = 8 # 每次处理8张图片(小批量,避免内存不足)

EPOCHS = 20 # 训练20轮(轮数越多,模型越准,但易过拟合)

DATA_DIR = "dataset" # 数据集路径(和项目结构对应)

SAVE_PATH = "plant_model.pth" # 训练好的模型保存路径

---------------------- 2. 数据预处理(关键!统一输入格式) ----------------------

定义训练集/验证集的预处理规则

---------------------- 2. 数据预处理(新增AI图像优化) ----------------------

data_transforms = {

"train": transforms.Compose([ # 训练集:AI图像增强

transforms.Resize((224, 224)),

transforms.RandomHorizontalFlip(),

新增AI级图像优化:随机调整对比度/亮度(增强叶片特征)

transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

]),

"val": transforms.Compose([ # 验证集:基础优化

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

])

}

加载数据集(自动读取文件夹结构,生成标签)

image_datasets = {

x: datasets.ImageFolder(os.path.join(DATA_DIR, x), data_transforms[x])

for x in ["train", "val"]

}

创建数据加载器(批量读取+打乱数据)

dataloaders = {

x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True)

for x in ["train", "val"]

}

获取类别名称(自动从文件夹名生成:['healthy', 'water_shortage', 'pest_disease'])

class_names = image_datasets["train"].classes

print("检测类别:", class_names) # 打印类别,确认数据集加载正确

---------------------- 3. 加载预训练模型(迁移学习,新手友好) ----------------------

加载MobileNetV2预训练模型(轻量级,适合电脑/嵌入式设备)

model = models.mobilenet_v2(pretrained=True)

冻结特征提取层(仅训练分类层,加快训练速度,避免过拟合)

for param in model.features.parameters():

param.requires_grad = False

修改最后一层:原模型输出1000类(ImageNet),改为3类(我们的数据集)

num_ftrs = model.classifier[1].in_features # 获取最后一层输入维度

model.classifier[1] = nn.Linear(num_ftrs, len(class_names)) # 这里len(class_names)是7

model = model.to(DEVICE) # 将模型移到CPU/GPU

---------------------- 4. 定义损失函数&优化器(模型学习的核心) ----------------------

criterion = nn.CrossEntropyLoss() # 交叉熵损失(适合分类任务)

优化器:Adam(自适应学习率,比SGD更优),只优化分类层参数

optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0001)

---------------------- 5. 训练函数(核心逻辑) ----------------------

def train_model(model, criterion, optimizer, num_epochs=20):

遍历每一轮训练

for epoch in range(num_epochs):

print(f"\nEpoch {epoch+1}/{num_epochs}")

print("-" * 20)

复制代码
    # 分训练和验证两个阶段
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()  # 训练模式:启用梯度下降
        else:
            model.eval()   # 验证模式:禁用梯度下降(只评估,不学习)

        running_loss = 0.0  # 累计损失
        running_corrects = 0  # 累计正确数

        # 遍历数据加载器中的每一批数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(DEVICE)  # 图片移到CPU/GPU
            labels = labels.to(DEVICE)  # 标签移到CPU/GPU

            # 梯度清零(避免上一轮梯度累积)
            optimizer.zero_grad()

            # 前向传播(预测)
            with torch.set_grad_enabled(phase == "train"):  # 训练阶段才计算梯度
                outputs = model(inputs)  # 模型输出(3个类的概率)
                _, preds = torch.max(outputs, 1)  # 取概率最大的类作为预测结果
                loss = criterion(outputs, labels)  # 计算损失(预测值vs真实值)

                # 训练阶段:反向传播+更新参数
                if phase == "train":
                    loss.backward()  # 反向传播(计算梯度)
                    optimizer.step()  # 更新模型参数

            # 统计损失和准确率
            running_loss += loss.item() * inputs.size(0)  # 累计损失
            running_corrects += torch.sum(preds == labels.data)  # 累计正确数

        # 计算本轮损失和准确率
        epoch_loss = running_loss / len(image_datasets[phase])
        epoch_acc = running_corrects.double() / len(image_datasets[phase])

        # 打印结果(直观看到训练效果)
        print(f"{phase} - 损失: {epoch_loss:.4f} | 准确率: {epoch_acc:.4f}")

# 训练完成,保存模型
torch.save(model.state_dict(), SAVE_PATH)
print(f"\n模型训练完成!已保存到:{SAVE_PATH}")
return model

启动训练(调用上面的函数)

if name == "main ":

train_model(model, criterion, optimizer, num_epochs=EPOCHS)

相关推荐
冰西瓜60033 分钟前
深度学习的数学原理(三十三)—— Transformer编码器完整实现
人工智能·深度学习·transformer
我是大聪明.2 小时前
CUDA矩阵乘法优化:共享内存分块与Warp级执行机制深度解析
人工智能·深度学习·线性代数·机器学习·矩阵
码云数智-大飞2 小时前
大模型幻觉:成因解析与有效避免策略
人工智能·深度学习
Mr数据杨2 小时前
四子棋智能体构建与在线对抗决策应用
机器学习·数据分析·kaggle
木枷3 小时前
rl/swe/sft相关论文列表
人工智能·深度学习
A7bert7773 小时前
【YOLOv8pose部署至RDK X5】模型训练→转换bin→Sunrise 5部署
c++·python·深度学习·yolo·目标检测
爱学习的张大3 小时前
具身智能论文精度(八):Pi0.6
人工智能·深度学习
AI科技星4 小时前
科幻艺术书本封面:《全域数学》第一部·数术本源 第三卷 代数原本(P95-141)完整五级目录【乖乖数学】
算法·机器学习·数学建模·数据挖掘·量子计算
墨北小七4 小时前
从目标检测到行为识别:YOLO 模型微调实战
人工智能·深度学习·神经网络
Mr数据杨4 小时前
灾害推文识别与应急信息筛选优化
机器学习·数据分析·kaggle