深度学习---pytorch搭建深度学习模型(附带图片五分类实例)

一、PyTorch搭建深度学习模型流程

1. 环境准备

安装PyTorch及相关库:

bash 复制代码
pip install torch torchvision numpy matplotlib
2. 数据准备
  • 数据集加载:使用内置数据集(如CIFAR-10)或自定义数据集。
  • 数据预处理:包括归一化、数据增强(随机翻转、旋转等)。
  • 数据划分:将数据集分为训练集、验证集和测试集。
  • 数据加载器 :使用DataLoader实现批量加载。
3. 模型构建
  • 网络结构 :通过继承nn.Module定义模型,使用卷积层、池化层、全连接层等。
  • 激活函数:如ReLU、Sigmoid。
  • 正则化:Dropout层、BatchNorm层。
  • 损失函数 :如交叉熵损失(CrossEntropyLoss)。
  • 优化器:如Adam、SGD。
4. 训练流程
  • 前向传播:计算模型输出。
  • 损失计算:根据预测和标签计算损失。
  • 反向传播 :通过loss.backward()计算梯度。
  • 参数更新 :优化器通过optimizer.step()更新权重。
  • 训练循环:多轮迭代训练,记录训练损失和准确率。
5. 验证与测试
  • 模型评估模式 :使用model.eval()关闭Dropout和BatchNorm。
  • 禁用梯度计算 :使用torch.no_grad()加速推理。
  • 指标计算:如准确率、F1分数。
6. 模型保存与加载
  • 保存模型参数:torch.save(model.state_dict(), "model.pth")
  • 加载模型参数:model.load_state_dict(torch.load("model.pth"))
7. 高级功能
  • 学习率调度 :动态调整学习率(如StepLR)。
  • 自定义层/损失函数 :继承nn.Moduleautograd.Function
  • 多GPU训练 :使用DataParallelDistributedDataParallel

二、图片五分类实例(基于CIFAR-10子集)

1. 数据准备
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据预处理与增强
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# 加载CIFAR-10并过滤前5类
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

# 过滤出前5类(类别0-4)
train_idx = np.isin(train_dataset.targets, [0,1,2,3,4])
test_idx = np.isin(test_dataset.targets, [0,1,2,3,4])

train_dataset = Subset(train_dataset, np.where(train_idx)[0])
test_dataset = Subset(test_dataset, np.where(test_idx)[0])

# 应用Transform
train_dataset.dataset.transform = transform_train
test_dataset.dataset.transform = transform_test

# 划分训练集和验证集
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
2. 模型定义
python 复制代码
class CNN(nn.Module):
    def __init__(self, num_classes=5):
        super(CNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

model = CNN(num_classes=5).to(device)
3. 训练配置
python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
4. 训练循环
python 复制代码
num_epochs = 10
best_val_acc = 0.0

for epoch in range(num_epochs):
    # 训练阶段
    model.train()
    train_loss, train_correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()

    # 验证阶段
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    # 打印结果
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss/total:.4f}, Acc: {train_correct/total:.4f}")
    print(f"Val Loss: {val_loss/val_total:.4f}, Acc: {val_correct/val_total:.4f}\n")

    # 保存最佳模型
    if val_correct/val_total > best_val_acc:
        best_val_acc = val_correct/val_total
        torch.save(model.state_dict(), "best_model.pth")
5. 测试模型
python 复制代码
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

print(f"Test Accuracy: {test_correct/test_total:.4f}")

三、关键点总结

  1. 数据增强:通过随机翻转、旋转提升模型泛化能力。
  2. 模型结构:使用卷积层提取特征,全连接层进行分类。
  3. 设备管理:自动检测GPU加速训练。
  4. 训练技巧:Dropout防止过拟合,Adam优化器自适应学习率。
  5. 模型保存:保存验证集上表现最好的模型。
相关推荐
道传科技上位机1 小时前
深度学习环境搭建(pycharm+yolov5)
深度学习·yolo·pycharm
大模型铲屎官2 小时前
【深度学习-Day 29】PyTorch模型持久化指南:从保存到部署的第一步
人工智能·pytorch·python·深度学习·机器学习·大模型·llm
一点.点3 小时前
李沐动手深度学习(pycharm中运行笔记)——11.模型选择+过拟合欠拟合
pytorch·深度学习
西猫雷婶4 小时前
python学智能算法(十二)|机器学习朴素贝叶斯方法初步-拉普拉斯平滑计算条件概率
开发语言·人工智能·python·深度学习·机器学习·矩阵
半桔4 小时前
【Linux手册】进程的状态:从创建到消亡的“生命百态”
linux·运维·服务器·汇编·深度学习·面试
知舟不叙5 小时前
深度学习——迁移学习(Transfer Learning)
人工智能·深度学习·迁移学习
强盛小灵通专卖员5 小时前
多智能体强化学习与图神经网络-无人机基站
人工智能·深度学习·神经网络·机器学习·无人机·核心期刊·中文核心
有Li6 小时前
基于集体智能长尾识别的超声乳腺病变亚型分类|文献速递-深度学习医疗AI最新文献
论文阅读·人工智能·深度学习·医学生
嗷嗷哦润橘_6 小时前
如何用一台服务器用dify私有部署通用的大模型应用?
运维·服务器·人工智能·python·深度学习·计算机视觉
白熊1887 小时前
【深度学习】卷积神经网络(CNN):计算机视觉的革命性引擎
深度学习·计算机视觉·cnn