卷积神经网络(CNN)前向传播手撕

题目

手写数字识别的卷积神经网络(CNN)代码,实现前向传播

解答

python 复制代码
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        # super(Net, self).__init__()
        super().__init__()
        self.model = nn.Sequential(
            # The size of the picture is 28x28
            nn.Conv2d(in_channels = 1,out_channels = 16,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            # The size of the picture is 14x14
            nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            # The size of the picture is 7x7
            nn.Conv2d(in_channels = 32,out_channels = 64,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            
            nn.Flatten(),
            nn.Linear(in_features = 7 * 7 * 64,out_features = 128),
            nn.ReLU(),
            nn.Linear(in_features = 128,out_features = 10),
            nn.Softmax(dim=1)
        )
        
    def forward(self,input):
        output = self.model(input)
        return output

net = Net()
# 将模型转换到device中,并将其结构显示出来
# print(net.to(device))

trainImgs = torch.Tensor(32, 1, 28, 28)  # [B, C, H, W]
outputs = net(trainImgs)
print(outputs.shape)  # torch.Size([32, 10])

注意

在 Python 中,super(Net, self).__init__()或``super().__init__() 的作用是调用父类的构造函数 ,确保子类 Net 继承自父类(如 torch.nn.Module)的属性和方法被正确初始化。

1. 代码含义

  • super():返回父类的代理对象,用于调用父类的方法。

  • Net:当前子类的名称。

  • self:当前子类的实例对象。

  • __init__():父类的构造函数方法。

组合起来

调用 Net 的父类(例如 torch.nn.Module)的 __init__() 方法,确保父类的初始化逻辑被执行。


2. 为什么需要这行代码?

  • 继承父类功能

    在 PyTorch 中,自定义神经网络模型必须继承 torch.nn.Module

    父类 Module 内部定义了模型的核心机制(如参数管理、GPU 转换等)。

    如果不调用父类的 __init__(),这些功能将无法正确初始化。

  • 避免潜在错误

    如果省略这行代码,子类 Net 将无法使用 Module 的功能,导致以下问题:

    • 模型参数(如 Conv2d 的权重)不会被识别和优化。

    • 无法将模型移动到 GPU(.to(device))。

    • 无法正确保存或加载模型(torch.save / torch.load)。

3.在 PyTorch 中的具体作用

在 PyTorch 模型中,父类 torch.nn.Module__init__() 会做以下关键操作:

  1. 注册参数(Parameters)和子模块(Submodules)

    self.conv1self.linear 等子层添加到模型的参数列表中,优化器(如 torch.optim.SGD)才能找到并更新这些参数。

  2. 设备管理

    跟踪模型所在的设备(CPU/GPU),确保输入数据和模型参数在同一设备上。

  3. 模型序列化

    支持模型的保存(torch.save)和加载(torch.load)。

相关推荐
互联网江湖5 分钟前
携程当学胖东来
人工智能
陌殇殇19 分钟前
001 Spring AI Alibaba框架整合百炼大模型平台 — 快速入门
人工智能·spring boot·ai
Proxy_ZZ027 分钟前
用Matlab绘制BER曲线对比SPA与Min-Sum性能
人工智能·算法·机器学习
黎阳之光28 分钟前
黎阳之光:以视频孪生领跑全球,赋能数字孪生水利智能监测新征程
大数据·人工智能·算法·安全·数字孪生
宇擎智脑科技38 分钟前
基于 SAM3 + FastAPI 搭建智能图像标注工具实战
人工智能·计算机视觉
F_U_N_1 小时前
效率提升80%:AI全流程研发真实项目落地复盘
人工智能·ai编程
月诸清酒1 小时前
24-260409 AI 科技日报 (Gemma 4发布一周下载破千万,开源模型生态加速演进)
人工智能·开源
2501_933329551 小时前
技术架构深度解析:Infoseek舆情监测系统的全链路设计与GEO时代的技术实践
开发语言·人工智能·分布式·架构
X journey1 小时前
机器学习进阶(16):如何防止过拟合
人工智能·机器学习
AI_Claude_code1 小时前
ZLibrary访问困境方案四:利用Cloudflare Workers等边缘计算实现访问
javascript·人工智能·爬虫·python·网络爬虫·边缘计算·爬山算法