卷积神经网络(CNN)入门实践及Sequential 容器封装

学习链接:https://www.bilibili.com/video/BV1hE411t7RN?t=1.4&p=22

推荐网站:CS231n 用于计算机视觉的深度学习

一、CNN 核心层的作用与原理

在搭建模型前,先明确 CNN 中各核心层的功能:

层类型 作用 关键参数示例
卷积层(Conv2d) 提取图像局部特征(如边缘、纹理),通过卷积核实现特征映射 Conv2d(3, 32, 5, padding=2)(输入通道 3,输出通道 32,核大小 5×5,填充 2)
池化层(MaxPool2d) 下采样压缩特征图,减少计算量,同时保留关键特征 MaxPool2d(2)(2×2 窗口做最大池化)
展平层(Flatten) 将二维特征图转换为一维向量,为全连接层做准备 ------
全连接层(Linear) 对提取的特征做非线性变换,最终实现分类或回归任务 Linear(1024, 64)(输入维度 1024,输出维度 64)

二、两种 CNN 模型构建方式对比

我们可以用 "分步骤定义层""Sequential 容器封装" 两种方式构建完全等价的 CNN 模型。

方式 1:分步骤定义每一层

这种方式更直观,适合初学者理解每一层的执行顺序:

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Prayer(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        # 定义各层
        self.conv1 = Conv2d(3, 32, 5, padding=2)  # 第一层卷积
        self.maxpool1 = MaxPool2d(2)             # 第一层池化
        self.conv2 = Conv2d(32, 32, 5, padding=2) # 第二层卷积
        self.maxpool2 = MaxPool2d(2)             # 第二层池化
        self.conv3 = Conv2d(32, 64, 5, padding=2) # 第三层卷积
        self.maxpool3 = MaxPool2d(2)             # 第三层池化
        self.flatten = Flatten()                 # 展平层
        self.linear1 = Linear(1024, 64)          # 第一层全连接
        self.linear2 = Linear(64, 10)            # 第二层全连接(10类分类)

    def forward(self, x):
        # 按顺序执行各层
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

# 测试模型
prayer = Prayer()
print(prayer)
input = torch.ones((64, 3, 32, 32))  # 模拟输入:批量64、3通道、32×32图像
output = prayer(input)
print(output.shape)  # 输出应为torch.Size([64, 10])

方式 2:Sequential 容器封装(更简洁)

当层的执行顺序很明确时,用Sequential把层 "打包",代码更简洁:

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential

class Prayer(nn.Module):
    def __init__(self):
        super(Prayer, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# 测试模型
# prayer = Prayer()
# print(tudui)
# input = torch.ones((64, 3, 32, 32))  # 模拟输入:批量64、3通道、32×32图像
# output = prayer(input)
# print(output.shape)  # 输出应为torch.Size([64, 10])

# 测试模型
prayer = Prayer()
print(prayer)
input = torch.ones((64, 3, 32, 32))
output = prayer(input)
print(output.shape)  # 同样输出torch.Size([64, 10])

三、用 TensorBoard 可视化模型结构

python 复制代码
from torch.utils.tensorboard import SummaryWriter

# 初始化SummaryWriter,指定日志保存路径
writer = SummaryWriter("../logs_seq")
# 传入模型和测试输入,生成计算图
writer.add_graph(prayer, input)
writer.close()

运行代码后,在终端执行命令:

python 复制代码
tensorboard --logdir=../logs_seq

然后打开浏览器访问http://localhost:6006,就能看到模型的计算图了,每一层的连接关系一目了然~

四、模型应用场景

这个 CNN 模型的结构非常经典,适合作为图像分类任务的 "baseline(基准)",比如:

  • 对 CIFAR-10 数据集(10 类彩色小图像)做分类;
  • 自定义小型图像数据集的分类任务;
  • 作为更复杂 CNN 模型的 "基石",在此基础上添加残差连接、注意力机制等模块。
相关推荐
那个村的李富贵1 小时前
光影魔术师:CANN加速实时图像风格迁移,让每张照片秒变大师画作
人工智能·aigc·cann
腾讯云开发者2 小时前
“痛点”到“通点”!一份让 AI 真正落地产生真金白银的实战指南
人工智能
CareyWYR2 小时前
每周AI论文速递(260202-260206)
人工智能
hopsky3 小时前
大模型生成PPT的技术原理
人工智能
禁默4 小时前
打通 AI 与信号处理的“任督二脉”:Ascend SIP Boost 加速库深度实战
人工智能·信号处理·cann
心疼你的一切4 小时前
昇腾CANN实战落地:从智慧城市到AIGC,解锁五大行业AI应用的算力密码
数据仓库·人工智能·深度学习·aigc·智慧城市·cann
AI绘画哇哒哒4 小时前
【干货收藏】深度解析AI Agent框架:设计原理+主流选型+项目实操,一站式学习指南
人工智能·学习·ai·程序员·大模型·产品经理·转行
数据分析能量站4 小时前
Clawdbot(现名Moltbot)-现状分析
人工智能
那个村的李富贵4 小时前
CANN加速下的AIGC“即时翻译”:AI语音克隆与实时变声实战
人工智能·算法·aigc·cann
二十雨辰4 小时前
[python]-AI大模型
开发语言·人工智能·python