卷积神经网络(CNN)前向传播手撕

题目

手写数字识别的卷积神经网络(CNN)代码,实现前向传播

解答

python 复制代码
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        # super(Net, self).__init__()
        super().__init__()
        self.model = nn.Sequential(
            # The size of the picture is 28x28
            nn.Conv2d(in_channels = 1,out_channels = 16,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            # The size of the picture is 14x14
            nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            # The size of the picture is 7x7
            nn.Conv2d(in_channels = 32,out_channels = 64,kernel_size = 3,stride = 1,padding = 1),
            nn.ReLU(),
            
            nn.Flatten(),
            nn.Linear(in_features = 7 * 7 * 64,out_features = 128),
            nn.ReLU(),
            nn.Linear(in_features = 128,out_features = 10),
            nn.Softmax(dim=1)
        )
        
    def forward(self,input):
        output = self.model(input)
        return output

net = Net()
# 将模型转换到device中,并将其结构显示出来
# print(net.to(device))

trainImgs = torch.Tensor(32, 1, 28, 28)  # [B, C, H, W]
outputs = net(trainImgs)
print(outputs.shape)  # torch.Size([32, 10])

注意

在 Python 中,super(Net, self).__init__()或``super().__init__() 的作用是调用父类的构造函数 ,确保子类 Net 继承自父类(如 torch.nn.Module)的属性和方法被正确初始化。

1. 代码含义

  • super():返回父类的代理对象,用于调用父类的方法。

  • Net:当前子类的名称。

  • self:当前子类的实例对象。

  • __init__():父类的构造函数方法。

组合起来

调用 Net 的父类(例如 torch.nn.Module)的 __init__() 方法,确保父类的初始化逻辑被执行。


2. 为什么需要这行代码?

  • 继承父类功能

    在 PyTorch 中,自定义神经网络模型必须继承 torch.nn.Module

    父类 Module 内部定义了模型的核心机制(如参数管理、GPU 转换等)。

    如果不调用父类的 __init__(),这些功能将无法正确初始化。

  • 避免潜在错误

    如果省略这行代码,子类 Net 将无法使用 Module 的功能,导致以下问题:

    • 模型参数(如 Conv2d 的权重)不会被识别和优化。

    • 无法将模型移动到 GPU(.to(device))。

    • 无法正确保存或加载模型(torch.save / torch.load)。

3.在 PyTorch 中的具体作用

在 PyTorch 模型中,父类 torch.nn.Module__init__() 会做以下关键操作:

  1. 注册参数(Parameters)和子模块(Submodules)

    self.conv1self.linear 等子层添加到模型的参数列表中,优化器(如 torch.optim.SGD)才能找到并更新这些参数。

  2. 设备管理

    跟踪模型所在的设备(CPU/GPU),确保输入数据和模型参数在同一设备上。

  3. 模型序列化

    支持模型的保存(torch.save)和加载(torch.load)。

相关推荐
小陈phd10 分钟前
TensorRT 入门完全指南(一)——从核心定义到生态工具全解析
人工智能·笔记
CeshirenTester28 分钟前
从0到1学自动化测试该怎么规划?
人工智能
:mnong31 分钟前
以知识驱动 AIAD 行业进化
人工智能·cad
ZhengEnCi39 分钟前
03-注意力机制基础 📚
人工智能
我是大聪明.1 小时前
CUDA矩阵乘法优化:共享内存分块与Warp级执行机制深度解析
人工智能·深度学习·线性代数·机器学习·矩阵
郑寿昌1 小时前
文化差异如何重塑AI语言理解能力
人工智能
lizhihai_991 小时前
股市学习心得-六张分时保命图
大数据·人工智能·学习
码云数智-大飞1 小时前
大模型幻觉:成因解析与有效避免策略
人工智能·深度学习
我星期八休息1 小时前
IT疑难杂症诊疗室:AI时代工程师Superpowers进化论
linux·开发语言·数据结构·人工智能·python·散列表