和鲸社区深度学习基础训练营2025年关卡4

使用 pytorch 构建一个简单的卷积神经网络(CNN)模型,完成对 CIFAR-10 数据集的图像分类任务。 直接使用 CNN 进行分类的模型性能。 提示: 数据集:CIFAR-10 网络结构:可以使用 2-3 层卷积层,ReLU 激活,MaxPooling 层,最后连接全连接层。

复制代码
#1. 数据预处理与加载
import torch
import torchvision
import torchvision.transforms as transforms

# 数据增强与归一化(使用CIFAR-10官方均值和标准差)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),       # 随机裁剪增强泛化性
    transforms.RandomHorizontalFlip(),          # 随机水平翻转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# 数据加载器
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

#2. CNN模型架构
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # 输入通道3(RGB),输出32通道
        self.bn1 = nn.BatchNorm2d(32)                 # 批量归一化
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)                # 池化层(尺寸减半)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)       # 全连接层(输入尺寸计算:32x32 → 16x16 → 8x8 → 4x4)
        self.fc2 = nn.Linear(256, 10)                 # 输出10类

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32x32 → 16x16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16x16 → 8x8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8x8 → 4x4
        x = x.view(-1, 128 * 4 * 4)                    # 展平
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型并移至GPU(若可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = SimpleCNN().to(device)

#3. 训练与优化
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 每5轮学习率×0.1

# 训练循环(10个epoch)
for epoch in range(10):
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:  # 每100批次打印一次
            print(f'Epoch [{epoch+1}/10], Step [{i+1}/{len(trainloader)}], Loss: {running_loss/100:.3f}')
            running_loss = 0.0
    
    scheduler.step()  # 更新学习率
    print(f"Epoch {epoch+1} completed, learning rate: {scheduler.get_last_lr()[0]:.6f}")

#4. 模型评估与可视化
net.eval()
correct, total = 0, 0
with torch.no_grad():
    for (images, labels) in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

运行结果:

相关推荐
zhaoshuzhaoshu4 小时前
人工智能(AI)发展史:详细里程碑
人工智能·职场和发展
Luke~4 小时前
阿里云计算巢已上架!3分钟部署 Loki AI 事故分析引擎,SRE 复盘时间直接砍掉 80%
人工智能·阿里云·云计算·loki·devops·aiops·sre
weixin_156241575764 小时前
基于YOLOv8深度学习花卉识别系统摄像头实时图片文件夹多图片等另有其他的识别系统可二开
大数据·人工智能·python·深度学习·yolo
QQ676580084 小时前
AI赋能轨道交通智能巡检 轨道交通故障检测 轨道缺陷断裂检测 轨道裂纹识别 鱼尾板故障识别 轨道巡检缺陷数据集深度学习yolo第10303期
人工智能·深度学习·yolo·智能巡检·轨道交通故障检测·鱼尾板故障识别·轨道缺陷断裂检测
小陈工4 小时前
2026年4月7日技术资讯洞察:下一代数据库融合、AI基础设施竞赛与异步编程实战
开发语言·前端·数据库·人工智能·python
tq10864 小时前
组织的本质:从科层制到伴星系统的决断理论
人工智能
科技与数码4 小时前
互联网保险迎来新篇章,元保方锐分享行业发展前沿洞察
大数据·人工智能
云程笔记4 小时前
002.计算机视觉与目标检测发展简史:从传统方法到深度学习
深度学习·yolo·目标检测·计算机视觉
汽车仪器仪表相关领域4 小时前
NHFID-1000型非甲烷总烃分析仪:技术破局,重构固定污染源监测新体验
java·大数据·网络·人工智能·单元测试·可用性测试·安全性测试