卷积神经网络(CNN)前向传播手撕

题目

手写数字识别的卷积神经网络(CNN)代码,实现前向传播

解答

python 复制代码
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        # super(Net, self).__init__()
        super().__init__()
        self.model = nn.Sequential(
            # The size of the picture is 28x28
            nn.Conv2d(in_channels = 1,out_channels = 16,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            # The size of the picture is 14x14
            nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            # The size of the picture is 7x7
            nn.Conv2d(in_channels = 32,out_channels = 64,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            
            nn.Flatten(),
            nn.Linear(in_features = 7 * 7 * 64,out_features = 128),
            nn.ReLU(),
            nn.Linear(in_features = 128,out_features = 10),
            nn.Softmax(dim=1)
        )
        
    def forward(self,input):
        output = self.model(input)
        return output

net = Net()
# 将模型转换到device中,并将其结构显示出来
# print(net.to(device))

trainImgs = torch.Tensor(32, 1, 28, 28)  # [B, C, H, W]
outputs = net(trainImgs)
print(outputs.shape)  # torch.Size([32, 10])

注意

在 Python 中,super(Net, self).__init__()或``super().__init__() 的作用是调用父类的构造函数 ,确保子类 Net 继承自父类(如 torch.nn.Module)的属性和方法被正确初始化。

1. 代码含义

  • super():返回父类的代理对象,用于调用父类的方法。

  • Net:当前子类的名称。

  • self:当前子类的实例对象。

  • __init__():父类的构造函数方法。

组合起来

调用 Net 的父类(例如 torch.nn.Module)的 __init__() 方法,确保父类的初始化逻辑被执行。


2. 为什么需要这行代码?

  • 继承父类功能

    在 PyTorch 中,自定义神经网络模型必须继承 torch.nn.Module

    父类 Module 内部定义了模型的核心机制(如参数管理、GPU 转换等)。

    如果不调用父类的 __init__(),这些功能将无法正确初始化。

  • 避免潜在错误

    如果省略这行代码,子类 Net 将无法使用 Module 的功能,导致以下问题:

    • 模型参数(如 Conv2d 的权重)不会被识别和优化。

    • 无法将模型移动到 GPU(.to(device))。

    • 无法正确保存或加载模型(torch.save / torch.load)。

3.在 PyTorch 中的具体作用

在 PyTorch 模型中,父类 torch.nn.Module__init__() 会做以下关键操作:

  1. 注册参数(Parameters)和子模块(Submodules)

    self.conv1self.linear 等子层添加到模型的参数列表中,优化器(如 torch.optim.SGD)才能找到并更新这些参数。

  2. 设备管理

    跟踪模型所在的设备(CPU/GPU),确保输入数据和模型参数在同一设备上。

  3. 模型序列化

    支持模型的保存(torch.save)和加载(torch.load)。

相关推荐
Joern-Lee5 分钟前
初探机器学习与深度学习
人工智能·深度学习·机器学习
云卓SKYDROID20 分钟前
无人机数据处理与特征提取技术分析!
人工智能·科技·无人机·科普·云卓科技
R²AIN SUITE32 分钟前
金融合规革命:R²AIN SUITE 如何重塑银行业务智能
大数据·人工智能
Code_流苏41 分钟前
《Python星球日记》 第69天:生成式模型(GPT 系列)
python·gpt·深度学习·机器学习·自然语言处理·transformer·生成式模型
新知图书1 小时前
DeepSeek基于注意力模型的可控图像生成
人工智能·深度学习·计算机视觉
白熊1881 小时前
【计算机视觉】OpenCV实战项目: Fire-Smoke-Dataset:基于OpenCV的早期火灾检测项目深度解析
人工智能·opencv·计算机视觉
↣life♚1 小时前
从SAM看交互式分割与可提示分割的区别与联系:Interactive Segmentation & Promptable Segmentation
人工智能·深度学习·算法·sam·分割·交互式分割
zqh176736464691 小时前
2025年阿里云ACP人工智能高级工程师认证模拟试题(附答案解析)
人工智能·算法·阿里云·人工智能工程师·阿里云acp·阿里云认证·acp人工智能
程序员小杰@1 小时前
【MCP教程系列】SpringBoot 搭建基于 Spring AI 的 SSE 模式 MCP 服务
人工智能·spring boot·spring
上海锝秉工控2 小时前
智能视觉检测技术:制造业质量管控的“隐形守护者”
人工智能·计算机视觉·视觉检测