基于 PyTroch 的卷积神经网络在图像分类中应用与实践

文章目录


卷积神经网络在图像分类中的应用与实践

一、卷积神经网络基础架构

卷积神经网络(CNN)是深度学习领域中处理图像数据的核心架构。它通过卷积层自动提取图像的空间特征,配合池化层降低特征维度,最终通过全连接层完成分类任务。

1.1 网络结构设计

在图像分类任务中,构建了一个包含三个卷积块的CNN模型:

python 复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 128, 5, 1, 2),
            nn.ReLU(),
        )
        self.out = nn.Linear(128 * 64 * 64, 20)

参数分析:

  • in_channels=3:输入图像的通道数(RGB三通道)
  • out_channels=16/32/128:卷积核数量,决定特征图的深度
  • kernel_size=5:卷积核尺寸,决定感受野大小
  • stride=1:卷积步长,控制特征图尺寸缩减速度
  • padding=2:边缘填充,保持特征图尺寸不变

二、数据处理与增强策略

2.1 数据预处理流程

图像数据在输入网络前需要经过标准化处理,使用ImageNet数据集的均值和标准差进行归一化:

python 复制代码
data_transforms = {
    'train':
        transforms.Compose([
        transforms.Resize([300, 300]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(256),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.1,
            saturation=0.1,
            hue=0.1
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ]),
}

数据增强方法:

  • RandomRotation(45):随机旋转±45度,增强旋转不变性
  • RandomHorizontalFlip(p=0.5):50%概率水平翻转
  • ColorJitter:调整亮度、对比度、饱和度和色调,增强色彩鲁棒性

2.2 自定义数据集类

通过继承Dataset类实现数据加载接口:

python 复制代码
class FoodDataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform

        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(int(label))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        label = torch.tensor(label, dtype=torch.long)
        return image, label

三、模型训练与优化

3.1 训练循环实现

训练过程中采用小批量梯度下降,定期输出损失值监控训练进度:

python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model.forward(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

3.2 测试与模型保存

测试阶段计算准确率,并保存性能最佳的模型:

python 复制代码
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss = loss_fn(pred, y)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size

    print(f"Test result:\n Accuracy:{(100 * correct):.2f}%, Avg loss: {test_loss}")

    if correct > best_acc:
        best_acc = correct
        script_model = torch.jit.script(model)
        torch.jit.save(script_model,"best12.pth")

四、优化器与损失函数配置

4.1 Adam优化器

相比传统SGD,Adam优化器结合了动量法和自适应学习率:

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

参数特点:

  • 学习率lr=0.001:较小的初始学习率保证训练稳定性
  • 自适应调整:为每个参数计算独立的学习率
  • 动量机制:加速收敛并减少震荡

4.2 交叉熵损失函数

多分类任务使用交叉熵损失,衡量预测概率分布与真实分布的差异:

python 复制代码
loss_fn = nn.CrossEntropyLoss()

五、完整训练流程

5.1 主训练循环

设置训练轮数,交替进行训练和验证:

python 复制代码
epochs = 10
acc_s = []
loss_s = []
for t in range(epochs):
    print(f"Epoch {t + 1}\n----------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

5.2 设备配置

自动检测可用计算设备,优先使用GPU加速:

python 复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

通过上述完整流程,卷积神经网络能够从原始图像数据中自动学习层次化特征,实现高效的图像分类任务。数据增强策略提升了模型的泛化能力,优化器配置确保了训练过程的稳定性,模型保存机制保留了最佳性能状态。

相关推荐
一招定胜负2 小时前
项目案例:卷积神经网络实现食物图片分类代码详细解析
人工智能·分类·cnn
心态特好4 小时前
pytorch和tenserflow详解
人工智能·pytorch·python
重生之我要成为代码大佬5 小时前
深度学习2-在2024pycharm版本中导入pytorch
人工智能·pytorch·深度学习
Easonmax6 小时前
从0到1:Qwen-1.8B-Chat 在昇腾Atlas 800T A2上的部署与实战指南前言
人工智能·pytorch·深度学习
BOF_dcb17 小时前
【无标题】
pytorch·深度学习·机器学习
人工智能培训20 小时前
深度学习—卷积神经网络(1)
人工智能·深度学习·神经网络·机器学习·cnn·知识图谱·dnn
Keep_Trying_Go21 小时前
文生图算法C4Synth: Cross-Caption Cycle-Consistent Text-to-Image Synthesis详解
人工智能·pytorch·深度学习·计算机视觉·文生图
Niuguangshuo1 天前
# PyTorch 中 `nn.ModuleList` 详解
人工智能·pytorch·python
ASD123asfadxv1 天前
基于改进Faster R-CNN的鸭蛋质量检测与分类系统_x101-32x8d_fpn_ms-3x_coco模型详解
人工智能·分类·cnn