基于卷积神经网络(CNN)和ResNet50的水果与蔬菜图像分类系统

前言

在现代智能生活中,计算机视觉技术已经成为不可或缺的工具,特别是在食物识别领域。想象一下,您只需拍摄一张水果或蔬菜的照片,系统就能自动识别其种类并为您提供丰富的食谱建议。这项技术不仅在日常生活中极具实用性,在农业、食品配送及健康监测等多个行业中也有着广泛的应用。

本文展示了一个基于深度学习的水果与蔬菜分类系统,采用了强大的卷积神经网络(CNN)和先进的数据增强技术,能够在各种复杂环境下准确识别出不同的水果和蔬菜种类。通过使用预训练的ResNet50模型和混合精度训练,系统优化了训练过程的效率和准确度,并且引入了OneCycleLR学习率调度策略,以确保最佳的学习速度。

无论是在个人项目、商业应用,还是在未来的食品识别系统中,本项目都能为您提供强有力的技术支持。通过本代码,您将能够实现从数据加载、模型训练到最终预测的完整流程,轻松将深度学习应用到食品识别的各个方面。

让我们一起探索这个强大的工具,如何帮助我们实现更智能的生活!

概述

本项目实现了一个基于深度学习的水果和蔬菜识别系统,旨在通过计算机视觉技术对图像中的食品进行分类。系统的核心基于卷积神经网络(CNN)架构,结合了数据增强技术、预训练模型、混合精度训练和学习率调度等先进策略,以提高训练效率和分类准确度。

主要功能:

  1. 数据预处理与增强:使用图像预处理技术(如调整大小、随机旋转、颜色调整等)对输入数据进行增强,提高模型的鲁棒性和泛化能力。
  2. 自定义数据集 :通过FruitVegDataset类构建自定义数据集,支持从指定路径加载和标记图像,并能够方便地应用图像转换。
  3. 深度学习模型:利用卷积神经网络(CNN)进行特征提取,并通过ResNet50预训练模型提升识别能力。该模型经过优化,具有较强的表现力,能够识别多达36类水果和蔬菜。
  4. 训练与验证:通过使用AdamW优化器、交叉熵损失函数以及OneCycleLR学习率调度器,优化了训练过程。采用了混合精度训练(Mixed Precision Training)以加速训练过程,同时减少显存使用。
  5. 预测与应用:训练好的模型可用于实时图像预测,用户只需上传一张水果或蔬菜的图片,系统即可返回预测结果,并展示分类的概率信息。

系统特点:

  • 高效训练:通过学习率调度和优化器调整,训练过程不仅更加高效,还能提升模型在验证集上的准确度。
  • 增强现实应用:该模型能够应用于餐厅菜单识别、农业监测、食品配送、健康管理等实际场景,具有较高的商业和应用价值。
  • 简易部署:训练后的模型可以轻松部署到各类应用中,包括移动端应用或web端服务,使得实时食品识别变得更加便捷。

本项目展示了如何通过深度学习技术实现水果和蔬菜的自动分类,推动了食品识别领域的进一步发展,同时为智能农业、健康饮食等领域提供了有力的技术支持。

ResNet50模型介绍

ResNet50 是一种深度残差网络(Residual Network),由微软研究院的何恺明等人于2015年提出。它是ResNet系列中的一个重要变体,具有50层深度,广泛用于计算机视觉任务,如图像分类、目标检测和语义分割。ResNet50的核心思想是引入残差连接(Residual Connections),即通过跳跃连接(skip connections)直接将输入添加到输出,从而解决深层网络中的梯度消失和梯度爆炸问题,促进更深层次网络的训练。

ResNet50的特点
  1. 残差连接(Residual Connections)

    • 传统的深层网络容易出现梯度消失或梯度爆炸的问题,使得训练变得困难。ResNet通过引入残差连接,将输入数据直接跳跃到输出端,形成"捷径"(shortcut)。这使得网络能够更容易地学习到残差(输入和输出的差值),而非直接学习整个映射函数。
    • 这种设计可以有效避免深层网络中的退化问题,提升网络的训练效率和性能。
  2. 深度网络结构

    • ResNet50 的深度为50层,采用了多个卷积层(Convolutional Layers)批量归一化层(Batch Normalization),通过堆叠的方式构成深层的神经网络。每一层的输出与输入之间通过跳跃连接直接相加,简化了网络的训练过程。
    • ResNet50相比于其它较浅的网络(如ResNet18、ResNet34)提供了更多的学习能力,能够学习到更复杂的特征。
  3. 残差模块(Residual Block)

    • 在ResNet50中,残差模块是由多个卷积层和残差连接组成的。通常,一个残差模块包括两到三层卷积,每层后跟一个批量归一化层和ReLU激活函数。
    • 每个模块通过1x1卷积(通常用于减少或恢复通道数)与输入建立直接的跳跃连接,最终将输入和输出相加。
    • 通过残差模块,ResNet能够在避免过拟合的情况下训练非常深的网络,并保持较高的准确率。
  4. 瓶颈结构(Bottleneck Architecture)

    • ResNet50采用了瓶颈结构,即每个残差块包含三个卷积层:一个1x1卷积层(用于降低维度),一个3x3卷积层(用于特征提取),以及一个1x1卷积层(用于恢复维度)。
    • 这种结构有效减少了计算量,并且提高了网络的效率。相比于普通的卷积层,瓶颈结构大大减少了参数数量和计算量,使得网络能够在有限的硬件资源上运行得更加高效。
  5. 跳跃连接的应用

    • ResNet50的最大创新之一就是其跳跃连接,它允许信号在网络中传递得更远。每个跳跃连接将前一层的输出与当前层的输出相加,生成最终的输出,这样有助于更容易地训练更深的网络,减少了网络中的退化问题。
    • 通过这种方式,网络不仅可以学习到更复杂的特征,还能够避免梯度在反向传播中的衰减。
  6. 预训练和迁移学习

    • ResNet50常常用作预训练模型,尤其在迁移学习中非常流行。通过在大规模数据集(如ImageNet)上进行预训练,ResNet50能够学习到通用的图像特征,这些特征可以迁移到其他特定的任务上,从而提高目标任务的性能。
    • 由于其出色的特征提取能力,ResNet50作为特征提取器在许多计算机视觉任务中表现出色,并且能够显著减少训练时间。
  7. 较低的计算成本

    • ResNet50相较于更深的网络(如ResNet101、ResNet152)在保持高性能的同时,计算成本相对较低。50层深度的网络结构相较于更深的变体,参数和计算量适中,适合于资源受限的环境。
ResNet50的应用
  • 图像分类:ResNet50被广泛用于图像分类任务,特别是在ImageNet等大规模数据集上训练后,能够为图像提供强大的特征表示。它在ImageNet挑战赛中表现出色,取得了很高的准确率。
  • 目标检测与语义分割:通过结合其它架构(如Faster R-CNN、Mask R-CNN),ResNet50也常用于目标检测和语义分割任务,提取高质量的特征来帮助检测和分割任务。
  • 迁移学习 :由于其优异的特征提取能力,ResNet50常作为迁移学习模型的基础,能够应用于医疗图像分析、面部识别、视频分析等领域。

模型的核心逻辑

本项目采用了基于深度学习的卷积神经网络(CNN)来进行水果与蔬菜分类任务。具体的核心逻辑包括以下几个部分:

1. 使用预训练模型作为特征提取器

核心的模型结构基于ResNet50,该模型在ImageNet上预训练过,已经学到了有效的图像特征。因此,在我们的任务中,ResNet50能够有效地提取水果和蔬菜图像中的低层次和高层次特征。

  • 冻结部分层:为了减少计算量,并且避免在较少的数据集上过拟合,我们选择冻结ResNet50模型的前30层(即不更新这些层的权重)。这使得模型能够专注于学习更高层次的特征,而不需要重新学习基础的图像特征。
python 复制代码
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)  
for param in list(self.backbone.parameters())[:-30]:  
    param.requires_grad = False  
  • 替换全连接层:ResNet50的原始全连接层被替换成自定义的全连接层,这一层是针对水果和蔬菜分类任务进行设计的。通过新的全连接层将提取到的特征映射到目标类别(水果与蔬菜类别)。
python 复制代码
self.backbone.fc = nn.Sequential(  
    nn.Linear(num_features, 1024),  
    nn.BatchNorm1d(1024),  
    nn.ReLU(inplace=True),  
    nn.Dropout(0.3),  
    nn.Linear(1024, 512),  
    nn.BatchNorm1d(512),  
    nn.ReLU(inplace=True),  
    nn.Dropout(0.3),  
    nn.Linear(512, num_classes)  
)  
2. 数据增强与预处理

为了增加训练数据的多样性,减少模型的过拟合,输入图像经过了一系列的数据增强操作。这些操作包括:

  • 缩放、裁剪:通过随机缩放、随机裁剪等操作确保模型能够应对不同尺度的图像。
  • 旋转与翻转:通过随机旋转、水平和垂直翻转等,增强模型的鲁棒性。
  • 颜色抖动:对图像的亮度、对比度、饱和度等进行随机变化,以增加模型对颜色变化的适应性。

这些数据增强方法提高了模型在未见数据上的泛化能力。

python 复制代码
train_transform = transforms.Compose([  
    transforms.Resize((256, 256)),  
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  
    transforms.RandomHorizontalFlip(),  
    transforms.RandomVerticalFlip(),  
    transforms.RandomRotation(20),  
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  
    transforms.ToTensor(),  
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
])  
3. 模型训练与优化

训练过程中,使用了以下几个重要的技术:

  • OneCycleLR学习率调度器 :为了加速训练过程并避免过拟合,使用了OneCycleLR学习率调度器,它帮助在训练初期增加学习率,然后逐渐减小,以使模型收敛得更快并且避免在训练结束时陷入局部最优解。
python 复制代码
scheduler = OneCycleLR(  
    optimizer,  
    max_lr=config.learning_rate,  
    epochs=config.epochs,  
    steps_per_epoch=len(train_loader),  
    pct_start=0.1,  
    anneal_strategy='cos'  
)  
  • 优化器 :使用了AdamW优化器,它是一种基于自适应估计的优化方法,适合深度学习任务。通过AdamW优化器,我们能够有效地更新模型参数。

  • 混合精度训练 :为了提高训练效率和减少显存占用,使用了PyTorch的混合精度训练(autocastGradScaler)。这使得在计算过程中部分操作使用半精度浮点数(FP16),以提高速度和节省内存,同时保持较高的精度。

python 复制代码
with autocast():  
    outputs = model(inputs)  
    loss = criterion(outputs, labels)  
4. 损失函数与评估
  • 损失函数 :使用了交叉熵损失(Cross-Entropy Loss)作为训练的目标函数,因为它适用于多类别分类任务。模型通过最小化交叉熵损失来优化其分类精度。
python 复制代码
criterion = nn.CrossEntropyLoss()  
  • 评估指标 :除了损失函数,训练过程中还监控了准确率(Accuracy),即模型在给定的测试集上的分类正确率。通过准确率来评估模型的性能,并在训练过程中选择最优的模型。
5. 模型预测与推断

训练完成后,模型可以用于对新的图像进行预测。输入图像首先经过相同的数据预处理和增强(例如调整大小、规范化等),然后输入到训练好的模型中,得到模型的预测输出。

模型输出的结果通过softmax函数转化为每个类别的概率值,最终返回最可能的类别及其对应的概率。

python 复制代码
def predict_image(url, model):  
    response = requests.get(url)  
    image = Image.open(BytesIO(response.content)).convert('RGB')  
    input_tensor = transform(image).unsqueeze(0)  
    with torch.no_grad():  
        output = model(input_tensor)  
        probabilities = torch.nn.functional.softmax(output[0], dim=0)  
        predicted_class = torch.argmax(probabilities).item()  
    return predicted_class, probabilities[predicted_class].item()  

代码实现

1. 设置随机种子和设备

为了保证结果的可重复性,我们设置了随机种子。然后确定是否使用GPU,如果GPU可用,则使用GPU,否则使用CPU。

python 复制代码
!pip install ultralytics -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install albumentations -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install timm -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install wandb -i  https://mirrors.aliyun.com/pypi/simple/ numpy
python 复制代码
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import random

# Set seeds for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create folder for saving results
os.makedirs('results', exist_ok=True)
2. 数据集展示

这一部分的代码用来展示数据集的结构,打印数据集的类和图像数量,并随机展示一些训练集的图像。

python 复制代码
def explore_data(data_path):
    """Explore and visualize the dataset"""
    print("\nExploring Dataset Structure:")
    print("-" * 50)
    
    splits = ['train', 'validation', 'test']
    for split in splits:
        split_path = os.path.join(data_path, split)
        if os.path.exists(split_path):
            classes = sorted(os.listdir(split_path))
            total_images = sum(len(os.listdir(os.path.join(split_path, cls))) 
                             for cls in classes)
            
            print(f"\n{split.capitalize()} Set:")
            print(f"Number of classes: {len(classes)}")
            print(f"Total images: {total_images}")
            print(f"Example classes: {', '.join(classes[:5])}...")
    
    # Visualize sample images
    print("\nVisualizing Sample Images...")
    train_path = os.path.join(data_path, 'train')
    classes = sorted(os.listdir(train_path))
    
    plt.figure(figsize=(15, 10))
    for i in range(9):
        class_name = random.choice(classes)
        class_path = os.path.join(train_path, class_name)
        img_name = random.choice(os.listdir(class_path))
        img_path = os.path.join(class_path, img_name)
        
        img = Image.open(img_path)
        plt.subplot(3, 3, i+1)
        plt.imshow(img)
        plt.title(f'Class: {class_name}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('results/sample_images.png')
    plt.show()

# Explore dataset
data_path = "/home/mw/input/Fruit1112533/Fruits and Vegetables Image Recognition Dataset"
explore_data(data_path)


3. 自定义数据集类

这部分代码定义了一个自定义的PyTorch Dataset 类,FruitVegDataset,用于加载数据集,并支持图像的转换(如缩放、裁剪等)。

python 复制代码
class FruitVegDataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        self.classes = sorted(os.listdir(self.root_dir))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        
        self.images = []
        self.labels = []
        
        for class_name in self.classes:
            class_path = os.path.join(self.root_dir, class_name)
            for img_name in os.listdir(class_path):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.images.append(os.path.join(class_path, img_name))
                    self.labels.append(self.class_to_idx[class_name])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        return image, label
4. 数据增强和预处理

这里定义了数据增强和预处理流程。使用了常见的数据增强方法,如随机水平翻转、随机旋转、颜色抖动等。并且对图像进行标准化处理。

python 复制代码
# Define transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Visualize augmentations
def show_augmentations(dataset, num_augments=5):
    """Show original image and its augmented versions"""
    idx = random.randint(0, len(dataset)-1)
    img_path = dataset.images[idx]
    original_img = Image.open(img_path).convert('RGB')
    
    plt.figure(figsize=(15, 5))
    
    # Show original
    plt.subplot(1, num_augments+1, 1)
    plt.imshow(original_img)
    plt.title('Original')
    plt.axis('off')
    
    # Show augmented versions
    for i in range(num_augments):
        augmented = train_transform(original_img)
        augmented = augmented.permute(1, 2, 0).numpy()
        augmented = (augmented * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])
        augmented = np.clip(augmented, 0, 1)
        
        plt.subplot(1, num_augments+1, i+2)
        plt.imshow(augmented)
        plt.title(f'Augmented {i+1}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('results/augmentations.png')
    plt.show()

# Create datasets and show augmentations
train_dataset = FruitVegDataset(data_path, 'train', train_transform)
show_augmentations(train_dataset)
5. 卷积块和网络结构

这一部分代码定义了一个卷积块(ConvBlock)和一个自定义的卷积神经网络(FruitVegCNN)用于图像分类。

python 复制代码
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        
    def forward(self, x):
        return self.conv(x)

class FruitVegCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        
        self.features = nn.Sequential(
            ConvBlock(3, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
            ConvBlock(256, 512),
            ConvBlock(512, 512)
        )
        
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Function to visualize feature maps
def visualize_feature_maps(model, sample_image):
    """Visualize feature maps after each conv block"""
    model.eval()
    
    # Get feature maps after each conv block
    feature_maps = []
    x = sample_image.unsqueeze(0).to(device)
    
    for block in model.features:
        x = block(x)
        feature_maps.append(x.detach().cpu())
    
    # Plot feature maps
    plt.figure(figsize=(15, 10))
    for i, fmap in enumerate(feature_maps):
        # Plot first 6 channels of each block
        fmap = fmap[0][:6].permute(1, 2, 0)
        fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min())
        
        for j in range(min(6, fmap.shape[-1])):
            plt.subplot(5, 6, i*6 + j + 1)
            plt.imshow(fmap[:, :, j], cmap='viridis')
            plt.title(f'Block {i+1}, Ch {j+1}')
            plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('results/feature_maps.png')
    plt.show()

# Initialize model and visualize feature maps
model = FruitVegCNN(num_classes=len(train_dataset.classes)).to(device)
sample_image, _ = train_dataset[0]
visualize_feature_maps(model, sample_image)
6. 训练和验证函数

定义了训练(train_one_epoch)和验证(validate)函数。这些函数在每个epoch中更新模型权重,并计算损失和准确率。

python 复制代码
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })
    
    return running_loss / len(train_loader), 100. * correct / total

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc='Validation'):
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(val_loader), 100. * correct / total

def plot_training_progress(history):
    """Plot and save training progress"""
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title('Accuracy History')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('results/training_progress.png')
    plt.show()
7. 训练与验证过程

在此部分代码中,我们定义了训练和验证的数据加载器,并设置了模型训练的相关配置。使用CrossEntropyLoss作为损失函数,AdamW优化器来优化模型,同时设置了学习率调度器ReduceLROnPlateau以自动调整学习率。训练过程包括多轮的训练与验证,并在每个周期结束时记录和打印训练与验证的损失和准确率。此外,还会保存每个周期的模型权重并在验证准确率提高时保存最佳模型。

python 复制代码
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataset = FruitVegDataset(data_path, 'validation', val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True)

# Training loop
num_epochs = 30
best_val_acc = 0
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

print("\nStarting training...")
for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device)
    
    val_loss, val_acc = validate(
        model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
    
    # Plot progress
    if (epoch + 1) % 5 == 0:
        plot_training_progress(history)
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f'New best validation accuracy: {best_val_acc:.2f}%')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_val_acc,
        }, 'results/best_model.pth')

# Final training visualization
plot_training_progress(history)


8. 绘制训练与验证的准确率与损失曲线

此部分代码用于可视化训练过程中模型的准确率和损失变化情况。通过绘制训练和验证集上的准确率与损失曲线,帮助我们直观地观察模型在不同训练周期中的表现。同时,代码会输出训练和验证过程中达到的最佳准确率,以便进一步分析模型的性能。

python 复制代码
import matplotlib.pyplot as plt

def plot_accuracy_loss(history):
    """Plot training and validation accuracy/loss curves"""
    plt.figure(figsize=(12, 4))
    
    # Plot Accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history['train_acc'], label='Training', marker='o')
    plt.plot(history['val_acc'], label='Validation', marker='o')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    
    # Plot Loss
    plt.subplot(1, 2, 2)
    plt.plot(history['train_loss'], label='Training', marker='o')
    plt.plot(history['val_loss'], label='Validation', marker='o')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig('results/accuracy_loss_curves.png')
    plt.show()

    # Print best accuracy values
    best_train_acc = max(history['train_acc'])
    best_val_acc = max(history['val_acc'])
    print(f"\nBest Training Accuracy: {best_train_acc:.2f}%")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

# Plot the curves
plot_accuracy_loss(history)
9. 优化的训练配置与增强数据增强

此部分代码实现了一个优化的训练流程,主要包括改进的超参数配置、增强的数据预处理以及混合精度训练技术。通过使用 ResNet50 作为骨干网络,添加了逐层冻结策略、增强的分类器结构(带有Dropout和Batch Normalization)以及One Cycle Learning Rate调度器等技术,可以提升模型的训练效果和泛化能力。此外,训练过程中应用了混合精度训练来加速计算并减少显存占用,进一步优化了训练过程。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import autocast, GradScaler

# Improved training configurations
class OptimizedConfig:
    def __init__(self):
        self.image_size = 256  # Increased from 224
        self.batch_size = 16   # Smaller batch size for better generalization
        self.learning_rate = 3e-4
        self.weight_decay = 0.01
        self.epochs = 50
        self.dropout = 0.3
        
# Enhanced data augmentation
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Optimized model architecture
class OptimizedCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        
        # Use pretrained ResNet50 as backbone
        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 
                                     'resnet50', pretrained=True)
        
        # Freeze early layers
        for param in list(self.backbone.parameters())[:-30]:
            param.requires_grad = False
            
        # Modified classifier
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
    def forward(self, x):
        return self.backbone(x)

# Optimized training function
def train_with_optimization(model, train_loader, val_loader, config):
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), 
                           lr=config.learning_rate, 
                           weight_decay=config.weight_decay)
    
    # One Cycle Learning Rate Scheduler
    scheduler = OneCycleLR(
        optimizer,
        max_lr=config.learning_rate,
        epochs=config.epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy='cos'
    )
    
    # Gradient Scaler for mixed precision training
    scaler = GradScaler()
    
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    best_val_acc = 0
    
    for epoch in range(config.epochs):
        # Training
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.epochs}')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # Mixed precision training
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%',
                'lr': f'{scheduler.get_last_lr()[0]:.6f}'
            })
            
        train_acc = 100. * correct / total
        train_loss = train_loss / len(train_loader)
        
        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc='Validation'):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        val_acc = 100. * correct / total
        val_loss = val_loss / len(val_loader)
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f'\nEpoch {epoch+1}/{config.epochs}:')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_val_acc,
            }, 'optimized_model.pth')
            print(f'New best validation accuracy: {best_val_acc:.2f}%')
    
    return history

# Create dataloaders with optimized configuration
config = OptimizedConfig()
train_dataset = FruitVegDataset(data_path, 'train', train_transform)
val_dataset = FruitVegDataset(data_path, 'validation', val_transform)

train_loader = DataLoader(train_dataset, 
                         batch_size=config.batch_size,
                         shuffle=True,
                         num_workers=4,
                         pin_memory=True)
val_loader = DataLoader(val_dataset,
                       batch_size=config.batch_size,
                       shuffle=False,
                       num_workers=4,
                       pin_memory=True)

# Initialize and train optimized model
model = OptimizedCNN(num_classes=len(train_dataset.classes)).to(device)
history = train_with_optimization(model, train_loader, val_loader, config)
10. 优化结果的可视化

此部分代码负责可视化优化后的训练和验证过程中的准确率与损失值。通过图表展示模型在训练和验证集上的表现,帮助评估优化策略的有效性。代码还输出了最佳的训练和验证准确率,便于进一步分析模型的性能。

python 复制代码
def plot_optimized_results(history):
    plt.style.use('seaborn-v0_8')
    plt.figure(figsize=(15, 5))
    
    # Plot Accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history['train_acc'], label='Training', marker='o')
    plt.plot(history['val_acc'], label='Validation', marker='o')
    plt.title('Model Accuracy with Optimizations')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    
    # Plot Loss
    plt.subplot(1, 2, 2)
    plt.plot(history['train_loss'], label='Training', marker='o')
    plt.plot(history['val_loss'], label='Validation', marker='o')
    plt.title('Model Loss with Optimizations')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig('optimized_results.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print best metrics
    best_train_acc = max(history['train_acc'])
    best_val_acc = max(history['val_acc'])
    print(f"\nBest Training Accuracy: {best_train_acc:.2f}%")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

# Plot results
plot_optimized_results(history)
11. 模型加载与图像预测

这段代码提供了一个从URL加载图像并用训练好的模型进行预测的流程。首先,加载已保存的模型,并通过预处理步骤对图像进行转换,然后进行推理并展示前5个预测结果。

python 复制代码
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
from io import BytesIO

# Load the saved model
def load_model():
    # Check if model file exists
    try:
        # Load model checkpoint
        checkpoint = torch.load('optimized_model.pth')
        model = OptimizedCNN(num_classes=36)  # Same as training
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print("Model loaded successfully!")
        return model
    except FileNotFoundError:
        print("Model file 'optimized_model.pth' not found!")
        return None

# Prediction function
def predict_image(url, model):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Load image from URL
    response = requests.get(url)
    image = Image.open(BytesIO(response.content)).convert('RGB')
    
    # Transform image
    input_tensor = transform(image).unsqueeze(0)
    
    # Make prediction
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        
        # Get top 5 predictions
        top_probs, top_indices = torch.topk(probabilities, 5)
    
    # Show results
    plt.figure(figsize=(12, 4))
    
    # Show image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Input Image')
    plt.axis('off')
    
    # Show predictions
    plt.subplot(1, 2, 2)
    classes = sorted(os.listdir("/home/mw/input/Fruit1112533/Fruits and Vegetables Image Recognition Dataset/train"))
    y_pos = range(5)
    plt.barh(y_pos, [prob.item() * 100 for prob in top_probs])
    plt.yticks(y_pos, [classes[idx] for idx in top_indices])
    plt.xlabel('Probability (%)')
    plt.title('Top 5 Predictions')
    
    plt.tight_layout()
    plt.show()
    
    # Print predictions
    print("\nPredictions:")
    print("-" * 30)
    for i in range(5):
        print(f"{classes[top_indices[i]]:20s}: {top_probs[i]*100:.2f}%")

# Load model
model = load_model()

# Now you can use it like this:
predict_image('https://pngimg.com/uploads/watermelon/watermelon_PNG2640.png', model)

注意

python 复制代码
# 需要完整代码以及数据集请点击以下链接:
https://mbd.pub/o/bread/mbd-Z5yclpZu
相关推荐
不如语冰32 分钟前
深度学习Python基础(2)
人工智能·python·深度学习·语言模型
Gauss松鼠会1 小时前
GaussDB 企业版轻量化部署探索(二)
数据库·人工智能·docker·华为云·gaussdb
程序猿阿伟1 小时前
《C++与 Armadillo:线性代数助力人工智能算法简化之路》
c++·人工智能·线性代数
沐欣工作室_lvyiyi1 小时前
基于单片机的无线水塔监控系统设计(论文+源码)
人工智能·stm32·单片机·嵌入式硬件·单片机毕业设计
野蛮的大西瓜1 小时前
BigBlueButton视频会议 vs 华为云会议的详细对比
人工智能·自动化·音视频·实时音视频·信息与通信·视频编解码
野蛮的大西瓜2 小时前
文心一言对接FreeSWITCH实现大模型呼叫中心
人工智能·机器人·自动化·音视频·实时音视频·文心一言·信息与通信
lover_putter3 小时前
ai学习报告:训练
人工智能·学习
Srlua3 小时前
基于预测反馈的情感分析情境学习
人工智能·python
这个男人是小帅3 小时前
【AutoDL】通过【SSH远程连接】【vscode】
运维·人工智能·pytorch·vscode·深度学习·ssh
野蛮的大西瓜3 小时前
BigBlueButton视频会议 vs 钉钉视频会议系统的详细对比
人工智能·自然语言处理·自动化·音视频·实时音视频·信息与通信·视频编解码