DAY41简单 CNN

一、知识回顾要点
  1. 数据增强:通过对训练图像做翻转、裁剪、旋转等变换,扩充数据集,提升模型泛化能力,避免过拟合。
  2. 卷积神经网络定义写法:按层堆叠的方式搭建模型,核心包含卷积层、池化层、全连接层等组件,需明确输入输出维度与层间连接逻辑。
  3. Batch 归一化:对一个批次数据的分布进行标准化调整,加速训练收敛、稳定梯度,在图像分类任务中尤为常用。
  4. 特征图:特指卷积操作输出的二维 / 三维数据,承载了输入图像的局部特征信息,是卷积层的核心输出形式。
  5. 学习率调度器:动态修改基础学习率,在训练后期降低学习率,让模型更稳定地收敛到最优解。

二、卷积操作标准流程
  1. 特征提取阶段 :输入 → 卷积层 → Batch 归一化层(可选) → 池化层 → 激活函数 → 下一层
    • 卷积层:提取局部特征
    • Batch 归一化:优化数据分布,加速训练
    • 池化层:降维并保留关键特征
    • 激活函数:引入非线性,增强模型表达能力
  2. 分类输出阶段:Flatten(展平特征图) → Dense(全连接层,可搭配 Dropout 防过拟合) → Dense(输出层,对应分类类别数)

💡 补充说明
  • Batch 归一化通常放在卷积层之后、激活函数之前,也有部分架构将其放在激活函数后,需根据具体任务选择。
  • 特征图仅由卷积层产生,池化层输出是特征图的降采样版本,全连接层输出为一维向量,不再称为特征图。
  • 学习率调度器需配合优化器使用,常见策略有 StepLR、CosineAnnealing 等。

🧩 简单 CNN PyTorch 代码模板(含数据增强、BatchNorm、学习率调度器)

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# ---------------------- 1. 数据增强与加载 ----------------------
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.RandomHorizontalFlip(),     # 随机水平翻转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # 归一化
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 以CIFAR10为例
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# ---------------------- 2. CNN模型定义 ----------------------
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # 特征提取部分:卷积 -> BatchNorm -> 池化 -> 激活
        self.features = nn.Sequential(
            # 第一层
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 输入通道3,输出32
            nn.BatchNorm2d(32),                          # Batch归一化
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),       # 池化降维
            
            # 第二层
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # 第三层
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        # 分类部分:Flatten -> Dense -> Dropout -> Output
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),  # 32x32经3次池化后为4x4
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),              # Dropout防过拟合
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# ---------------------- 3. 训练配置 ----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # 学习率调度器:每30轮学习率×0.1

# ---------------------- 4. 训练循环 ----------------------
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)\n')

# 运行训练
for epoch in range(1, 61):
    train(epoch)
    test()
    scheduler.step()  # 更新学习率

@浙大疏锦行

相关推荐
云烟成雨TD1 小时前
Spring AI Alibaba 1.x 系列【69】Token 用量统计
java·人工智能·spring
十三画者1 小时前
【AI学习笔记】:DeepSeek 大模型本地部署与调用实战指南
人工智能
丁常彦-自媒体-常言道1 小时前
从首发4nm智驾芯片到兜底城市领航安全,比亚迪开启AI新征程
人工智能
小杨在厦门2 小时前
从AI验布到智能质检:纺织企业智能化升级的三个台阶
人工智能·服装·服装厂·服装机械·铺布机
达之云*驭影2 小时前
解锁流量密码:详解抖音AI智能推荐封面功能
人工智能
火山引擎开发者社区3 小时前
ArkClaw 投研助理 —— 零门槛做投研,从一句话开始产出你的第一份深度研报
人工智能
码农小白AI3 小时前
AI报告审核加速融入自动化实验室:IACheck破解智能设备时代报告管理新挑战
运维·人工智能·自动化
xingyuzhisuan3 小时前
自建聚合网关VS第三方聚合平台,适配场景与数据实测
人工智能·ai·云计算·oneapi
tedcloud1233 小时前
DeepSeek-TUI部署教程:打造CLI AI助手环境
服务器·人工智能·word·excel·dreamweaver