1 nn简介
torch.nn是Neural Networks(神经网络)的缩写。神经网络是一种受生物神经系统启发的计算模型,是深度学习的基础。它由大量相互连接的节点(称为"神经元")组成,能够学习复杂的模式和关系。
1.1 神经网络概述
1 核心结构
可以用图来表示神经网络,把最左边的一列称为输入层,最右边的一列称为输出层,中间的一列称为中间层。中间层有时也称为隐藏层,"隐藏"一词的意思是,隐藏层的神经元(和输入层、输出层不同)不可见。

图中神经网络一共由3层神经元构成,但实质上只有2层神经元有权重,所以将其称为"2层网络"。请注意,有的地方也会根据构成网络的层数,把图中网络称为"3层网络"。当然隐藏层不一定只有一层,把只有一层隐藏层的神经网络称为浅层神经网络,如上图所示;把大于或等于2层隐藏层的网络称为深度神经网络(DNN),可见,深度神经网络里的"深度",就体现在:有多层隐藏层(中间层),层数越多,网络越能学习更复杂、更抽象的特征。
图中的各个节点被称作神经网络的基本单元------人工神经元,隐藏层和输出层的神经元被称作是标准人工神经元,内部固定由4部分构成。
(1)输入信号:x1,x2,...,xn,来自上一层神经元的输出。
(2)权重(Weight):w1,w2,...,wn,每条输入对应一个权重,表示这条连接的重要程度。
(3)偏置(Bias):b,一个常数,用来整理平移激活阈值,让模型更灵活。
(4)激活函数(Activation Function):f(),引用非线性,决定神经元是否"firing",让多层网络有意义。
神经元的工作由两步完成:
第一步:加权求和,z=w1x1+w2x2+...+wnxn+b,z体现工作的线性部分;
第二步:激活,a=f(z),z是上一步的加权求和结果,经由激活函数f输出结果a,这里激活函数一般是非线性变换,正是因为这里的非线性部分,才使得之前的线性空间产生弯曲,才能让模型能够拟合各种复杂的曲线。
直观来看各权重w控制神经元"听谁多一些",偏置b控制神经元"多容易被激活",激活函数f控制神经元"要不要输出,输出多大"。一句话总结,一个神经元就是一个可调参数的非线性函数,后续会根据神经网络的工作原理进一步分析神经元本质。
在torch中可以通过如下一句话创建包含神经元的中间层:
nn.Linear(784, 32)
上述代码中,表示一个有784个输入,32个神经元的隐藏层,其中每个神经元有784个权重w及1个偏置b。
2 工作原理
先给出神经网络各个层的核心功能,输入层用于喂数据,隐藏层用于提取特征(层数代表深度,越深越能获取复杂的规律),输出层用于给出结果。所以神经网络可以看作是一个可以学习的函数逼近器,它的目标是学习一个函数f(x)≈y,它通过喂入数据进行学习,学习的本质就是迭代调整各层权重和偏置,神经网络层数越多越能包含更多的权重和偏置,就越能表示更复杂的函数,随着学习的进行会最终给出一个复合函数f(x)=fL(fL-1(...f1(x))),复合函数的输出能不断逼近真实值y,这里符合函数的每一层fi(x)=σ(wx+b),下面结合神经网络的几个核心概念继续进行讲解。
(1)前向传播(预测过程)
所谓的前向传播就是各层各神经元执行各自的工作,数学上每个神经元的工作可以用公式a=f(wx+b)表示,假设L层神经网络的输入层神经元个数是m,第1层隐藏层的神经元个数是n,则对于隐藏层的每个神经元其输出可以表示为:

则第一层隐藏层的输出可由以下公式表示:

其中A1、X1、B1是n维列向量,W1是nxm矩阵,X1对应隐藏层的m个输入,W1第i行对应隐藏层第i个神经元的m个权重,B1对应隐藏层n个神经元的偏置,A1对应隐藏层n个神经与的输出。依次类推,之后的L-1个层输出可表示为:

可见最后会得到输出层的输出Y,传播完成。
(2)损失函数(评价过程)
损失函数用于获取预测值y_pred和真实值y_true的偏差:Loss = L(y_pred, y_true),Loss越小,说明获得的复合函数f(由所有权重和偏置确定)输出越准确,即损失函数能评价正向传播结果的"好坏"。
(3)反向传播(学习过程)
反向传播的目标是调整w,b使得Loss最小,如何调整w,b呢?这里需要先回忆一下梯度的概念。对于一个多元函数f(x1,x2,...,xn),它的梯度是一个向量,由函数对每个变量的偏导数组成:

这里∇f又称作梯度算子,它是一个向量函数,但是在反向传播中使用的是某一点上的梯度,这时梯度是梯度算子在该点上计算出来的结果,是纯数值向量。在反向传播过程中,会计算Loss对每一层输出的梯度,根据链式法则一路往回传,算出Loss对每一个w、b的梯度,之后沿梯度反方向更新权重:

公式中η为学习率(learning rate),它必须是正数,决定参数更新步长,过大不稳定,过小收敛慢,最佳值通常在"稳定下降的最大范围"。公式中的负号体现了沿梯度反方向更新的含义,即当梯度值为正时,说明该权重对损失的影响是"正向"的,权重增大损失变大,要想Loss下降必须减小w,等价于公式中w减去一个正值;相反当梯度值为负时,说明该权重对损失的影响是"反向"的,权重增大损失变小,要想Loss下降增大w就行,等价于公式中w减去一个负值;最后当梯度为零时,说明该权重对损失无影响,保持该权重不变就行。
1.2 主要类型
主流神经网络按结构与计算模式可以分为几大类。
1 前馈神经网络(FNN,Feedforward Neural Network)
这是最基础的神经网络类型,如多层感知机(MLP)就是典型的前馈神经网络。
核心机制:由输入层、多个隐藏层和输出层组成,神经元之间全连接,信息单向向前流动。
特点:每个神经元与前一层的所有神经元相连。
应用:分类任务、回归分析、简单的模式识别,对应数据为表格、简单特征。
2 卷积神经网络(CNN,Convolutional Neural Network)
CNN 是计算机视觉(CV)领域的霸主。
核心机制:通过"卷积核(Filter)"提取局部特征,具有"权值共享"和"空间局部性"的特性。
典型能力:卷积层用于提取边缘、纹理等特征;池化层(Pooling)用于压缩特征图,减少计算量,增强平移不变性。
应用:图像分类(ResNet)、目标检测(YOLO)、人脸识别、医学影像分析,对应的数据为图像、视频。
3 循环神经网络(RNN,Recurrent Neural Network)
RNN专门用于处理序列数据。
核心机制:引入了"记忆"概念,神经元的输出会反馈给自身作为下一时刻的输入。
典型能力:LSTM (长短期记忆网络)解决了标准 RNN 的梯度消失问题,擅长处理长序列;GRU (门控循环单元)LSTM 的精简版,计算开销更小。
应用:语音识别、机器翻译(旧版)、时间序列预测(股价、天气),对应的数据为语音、简单文本。
4 注意力机制网络(Transformer)
这是目前AI领域的"版本答案",也是 GPT、BERT等大模型的核心。
核心机制:自注意力机制(Self-Attention),它不像RNN那样按顺序读取,而是能够同时观察序列中的所有元素,并计算它们之间的关联权重。
典型能力:并行计算能力极强,能捕捉超长距离的上下文关系。
应用:大语言模型(LLM)、多模态学习、视觉Transformer(ViT),对应的数据为文本、多模态、大模型。
5 生成对抗网络(GAN)
GAN 由两个网络相互"博弈"组成。
生成器 (Generator):负责制造假数据,试图欺骗判别器。
判别器 (Discriminator):负责分辨数据是真的还是生成的。
结果:两者在对抗中共同进步,直到生成器能制造出真假难辨的数据。
应用:Deepfake(换脸)、图像生成(StyleGAN)、图像修复,对应的数据为图像。
2 troch.nn主要组件
2.1 模型(Module)
这是整个torch.nn的根基,无论是一个简单的线性层(nn.Linear),还是拥有数千亿参数的GPT模型,其本质都是一个继承自nn.Module的类。可以把nn.Module理解为一个高级容器,它不仅装着网络层,还负责管理参数、处理设备转移(CPU/GPU)以及控制数据的流向。
1 基本结构
当定义一个模型时,必须完成两件事:
(1)init (初始化):定义网络里有哪些"组件"(积木)。
(2)forward (前向传播):定义数据进来后,如何穿过这些组件(拼接逻辑)。
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
2 四大核心能力
(1)自动参数追踪 (Parameter Tracking)
这是nn.Module最强大的地方,当把nn.Linear赋值给self.fc时,nn.Module会自动把这个层里的weight和bias注册到模型的参数列表中。可以直接调用model.parameters()把所有参数一次性丢给优化器(Optimizer),而不需要手动去写[w1, b1, w2, b2...]。
(2)递归式嵌套 (Recursive Nesting)
nn.Module既可以是"层",也可以是"完整的模型",一个nn.Module内部可以包含其他的nn.Module。比如一个ResNet模型是由多个ResBlock组成的,而每个ResBlock又是nn.Module的子类。
(3)状态管理 (Buffer&Parameter)
它区分了两种数据:
Buffers:模型运行需要但不需要更新的数据(如Batch Normalization中的均值和方差);
Parameters:需要通过梯度下降更新的权重(如线性层的W);
转换命令可以一键把整个模型从CPU搬到GPU上,如model.to('cuda')。
(4)训练/评估模式切换
神经网络在"训练"和"推理"时的行为可能不同(例如Dropout和BatchNorm)。
model.train():激活训练行为。
model.eval():关闭Dropout,固定BatchNorm,用于测试。
2.2 层(Layers)
1 线性层
在torch.nn模块中,nn.Linear是最基础、也是使用频率最高的层,它实现了神经网络中核心的线性变换(也称为全连接层或仿射层)。
(1)数学原理
nn.Linear执行的是经典的矩阵运算公式:

其中:
x:输入张量(Input Tensor)。
A:权重矩阵(Weight),在nn.Linear中存储为weight。
b:偏置向量(Bias),在 nn.Linear 中存储为bias。
y:输出张量(Output Tensor)。
(2)定义及参数详解
在torch中可以使用如下语句定义线性层:
nn.Linear(in_features, out_features, bias=True)
in_features:输入神经元的数量,即该层输入张量的维度;
out_features:输出神经元的数量,即该层输出张量的维度;
bias:是否添加偏置,默认为True,如果不加偏置,直线将经过原点。
一旦实例化一个Linear层,torch会自动随机初始化两个张量(Tensors):weight,形状为(out_features, in_features),注意它是反过来存的,目的是为了在计算xAT时矩阵维度能对上;bias,形状为(out_features,)。
(3)数据流转与维度变化
nn.Linear不仅能处理二维数据(Batch, Features),还能处理高维数据,但它只作用于最后一个维度。
二维输入:[Batch_Size, in_features]→[Batch_Size, out_features];
三维输入(常见于 Transformer):[Batch_Size, Seq_Len, in_features]→[Batch_Size, Seq_Len, out_features];
这里Batch_Size代表一次性送入神经网络训练的样本数量,Seq_Len代表一个样本内部包含的连续特征点数量,这在处理有先后顺序的数据时至关重要,例如处理文本时,一句话有20个单词,那么Seq_Len = 20。作用于最后一维指的是,如果数据形状是[32, 10, 128],表示nn.Linear(128, 64)会被复用32*10 = 320次,Batch_Size和Seq_Len的维度在经过线性层后保持不变,结果形状变为[32, 10, 64]。
(4)权重初始化
权重会使用Xavier(适合激活函数Sigmoid/Tanh)或He(适合激活函数ReLU/LeakyReLU)进行初始化,前者权重初始化后方差=1/输入神经元数,后者方差=2/输入神经元数,现代网络几乎全是ReLU,所以基本都用He。
(5)全连接
之所以叫Fully Connected(FC),是因为输出空间的每一个维度都与输入空间的每一个维度相连。如果输入有M个神经元,输出有N个神经元,则总共会有M*N个权重连接。全连接让模型能学习到特征之间最复杂的组合关系,但也正是因为这种全连接,当维度很高时(比如处理高清图片),参数量会变得巨大。
2 卷积层
nn.Conv2d是处理图像数据的核心组件,与全连接层(nn.Linear)观察全局不同,nn.Conv2d采用局部连接和权值共享的策略,能够高效提取图像的边缘、纹理和形状等空间特征。nn.Conv2d通过一个可学习的"卷积核(Kernel/Filter)"在输入图像上滑动,在每一个位置,卷积核与其覆盖的局部区域进行点积(Dot Product)运算,并将结果相加得到输出特征图(Feature Map)上的一个点。
Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
in_channels:int,输入通道数,如对于彩色RGB图片通道数是3,黑白图通道数是1;
out_channels:int,输出通道数,卷积核的数量,表示想提取多少种不同的特征;
kernel_size:int或tuple,卷积核大小,扫描窗口的大小,如3表示3*3;
stride:int或tuple,步长,扫描窗口每次滑动的跨度,默认为 1,步长越大,输出图像越小;
padding:int或tuple,填充,输入的每一条边补充0的层数,防止图像越卷越小,并保留边缘信息;
dilation:int或tuple,卷积核元素之间的间距;
groups:int,从输入通道到输出通道的阻塞连接数;
bias:bool,True表示添加偏置;
假如输入维度为(Hin, Win),输出维度为(Hout, Wout),则有以下计算关系:

对于二维卷积,有以下三大特性:
局部感知(Local Connectivity):每个神经元只看图像的一个局部区域,而不是整张图;
权值共享(Weight Sharing):同一个卷积核在整张图像上滑动。这意味着如果它学到了如何识别"垂直边缘",那么无论这个边缘在左上角还是右下角,它都能识别出来;
平移不变性(Translation Invariance):由于权值共享,目标在图中的位置改变,输出特征也会随之移动,但特征本身依然能被提取;
类似2维卷积,还有1维卷积和3维卷积。1D卷积的卷积核只在一个维度(通常是时间轴或序列轴)上滑动,它的处理对象是序列数据,如振幅随时间变化的音频信号、单词按顺序排列的自然语言、取值按时间变化的传感器数值等,它在寻找"局部时序模式",比如在一段音频中寻找某个特定的频率特征。3D卷积的卷积核在三个维度(长、宽、高/深)上滑动,处理对象是具有空间深度或时间连续性的数据,如视频数据[长,宽, 帧数(时间)]、医学影像(人体器官切片堆叠形成的3D实体)等,它在寻找"体积特征",比如在视频中识别"挥手"这个动作,或者在CT中识别肿瘤的立体形状。
class torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
class torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
3 池化层
在卷积神经网络(CNN)中,池化层(Pooling Layer)紧跟在卷积层之后,如果说卷积层是在"找特征",那么池化层就是在"做摘要"。它的核心作用是下采样(Downsampling):在保留主要特征的同时,减小数据的空间尺寸(宽度和高度),从而减少计算量并防止过拟合。
(1)MaxPool2d
nn.MaxPool2d 是目前最常用的池化方式,它在窗口内只保留数值最大的一个像素,抛弃其他所有值,代表"最强特征的提取"。比如一个窗口里有一个像素点代表"猫的耳朵边缘",最大池化就会把这个最重要的信号留下来,而忽略掉周围模糊的背景,它的特点是平移不变性和特征突出。
class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
(2)MaxUnpool2d
如果说MaxPool2d是把图像"压缩"了,那么MaxUnpool2d就是试图把这个过程"还原"回去。但这里有一个关键:信息一旦丢失是无法凭空恢复的。 所以,MaxUnpool2d必须配合MaxPool2d留下的"小纸条"才能工作。MaxPool2d在取最大值时,可以偷偷记住那个最大值原本在窗口里的坐标,保存相应位置索引,即调用nn.MaxPool2d时,须设置return_indices=True,在之后还原位置时,nn.MaxUnpool2d接收压缩后的特征图,并根据保存的索引,把数值放回原来的位置,其余地方补零。MaxUnpool2d能精准地还原边缘和纹理的结构。
class torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0)
(3)AvgPool2d
nn.AvgPool2d会计算窗口内所有像素的平均值作为输出,代表"整体特征的平滑"。它不追求最极端的信号,而是保留区域内的背景信息,它的特点是信息保留和平衡噪声。
class torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True)
和卷积层类似,池化层也有相应的1维和3维形式。
4 归一化层
在深度学习中,nn.BatchNorm2d(二维批归一化)被称为深层网络的"续命神器"。如果没有它,超过 10 层的网络极难收敛;有了它,即便上百层的网络也能稳健训练。它的核心作用是:将每一层神经元的输出强行拉回到均值为 0、方差为 1 的标准正态分布附近。在深层网络训练中,前面层参数的微小变化,经过层层累加,会导致后面层接收到的输入分布发生剧烈波动。这导致的后果是神经元的激活值会掉进激活函数(如 Sigmoid 或 ReLU)的饱和区(梯度为 0 的区域),导致梯度消失,模型停止学习。为了解决这一问题,在进入激活函数前,先把数据"归一化",让它们重新回到对梯度敏感的活跃区。
对于一个Batch中的数据,nn.BatchNorm2d执行以下四个步骤:
(1)计算均值(μB):计算当前Batch中所有像素的平均值;
(2)计算方差(σB2):计算当前Batch的离散程度;
(3)标准化:

其中ε是一个极小的数,防止除数为0;
(4)缩放与位移:

γ(weight)和β(Bias)是两个可学习的参数,如果网络发现"标准正态分布"不适合这一层,它可以通过学习把分布再挪回去,这赋予了模型极大的灵活性。
class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)
num_features:特征通道数,须和前一层卷积输出的out_channels一致;
eps:稳定系数,就是公式中的ε;
momentum:动量系数,用于更新"运行时均值/方差"的权重;
affine:是否学习γ和β,如果设为False,该层就不包含可学习参数;
该层的放置位置一般为:卷积层(Conv)→批归一化 (BN)→激活函数(ReLU)。该层在训练模式和评估模式的行为有所不同:
训练模式(model.train()):使用当前Batch的均值和方差进行归一化。同时,它会用"滑动平均"的方法偷偷记下全局的均值和方差,存为running_mean和running_var。
评估模式(model.eval()):不再计算Batch统计量,它直接使用训练时存下来的全局running_mean和running_var。
5 Dropout
在深度学习中,nn.Dropout是一种极其简单却暴力有效的正则化(Regularization)技术。如果说BatchNorm是为了让网络"跑得稳",那么Dropout就是为了让网络"不读死书",防止过拟合(Overfitting)。它的核心理念是:在训练过程中,随机"关掉"一部分神经元,强迫网络不依赖于某些特定的路径。
(1)核心工作原理
在训练阶段的每一轮迭代中,nn.Dropout 会以概率p将输入张量中的部分元素强制置为0。
直观理解:想象一个团队在完成任务。如果某个人(神经元)太强,其他人就会产生依赖,变得"懒惰"。Dropout相当于随机给成员放假,逼着剩下的每个人都必须学会处理任务。这样,整个团队(网络)的鲁棒性就提高了。
数学修正:为了保证在训练和测试时,下一层接收到的信号总强度(期望值)一致,被保留下来的元素会被放大1/(1-p)倍。
class torch.nn.Dropout(p=0.5, inplace=False)
p:将元素置0的概率,默认0.5,在全连接层常用0.5,在靠近输入的层常用0.2;
in-place:若设置为True,会在原地执行操作,默认为False,设为True可以稍微节省一点显存,但调试时可能看不清原始输入;
Dropout在不同模式下的行为截然不同:
训练模式(model.train()):激活随机丢弃逻辑,这是模型"磨炼"的过程;
评估模式 (model.eval()):完全关闭丢弃逻辑,所有的神经元全部参与计算,且不再进行缩放;
6 Recurrent
torch.nn模块中,Recurrent Layers(循环层)是专门为处理时序数据或序列数据设计的,如果说CNN擅长处理"空间"信息(图片里的上下左右),那么 Recurrent Layers 就擅长处理"时间"信息(句子里的前后文、股票的历史波动)。
(1)核心特点
普通的全连接层(nn.Linear)是无记忆的:输入A得到输出A',输入B得到输出B',它们之间互不干涉。循环层(RNN/LSTM/GRU)则不同,在处理当前时刻t的输入xt 时,会参考上一时刻t-1留下的隐藏状态(Hidden State, ht-1)。换句话说,它的输出不仅取决于"现在看到了什么",还取决于"过去记住了什么"。
(2)三种循环模型
nn.RNN:最基础的模型,但在实际工程中很少单独使用,因为它存在严重的梯度消失问题,如果你给它一个很长的句子,它处理到结尾时,早就把开头的词忘光了。
nn.LSTM:为了解决 RNN 的健忘问题,LSTM 引入了"门控机制"。它有一个专门的"细胞状态(Cell State)",就像一条传送带,信息可以在上面流传很久而不被稀释。它还引入门控机制,其中,遗忘门决定丢弃哪些旧信息,输入门决定存入哪些新信息,输出门决定下一时刻输出什么。LSTM适用于长文本理解、复杂的语音识别。
nn.GRU:LSTM 的"精简版",将遗忘门和输入门合并为一个"更新门",它参数更少,计算速度比LSTM快,但在很多任务上效果不相上下。
2.3 损失函数(Loss Functions)
在torch.nn中,Loss Function(损失函数)是模型的"教鞭"。它负责计算模型输出与真实标签之间的距离,并将这个距离转化为一个标量,这个标量越小,说明模型学得越好。根据任务类型(回归、分类、序列),最常用的损失函数可以分为以下几大类:
1 回归任务(Regression)
用于预测连续的数值(如房价、股票、坐标)。
(1)nn.MSELoss (均方误差)
公式:计算预测值与真实值差值的平方的平均数。
特点:对离群点(Outliers)非常敏感,因为误差会被平方放大。
场景:最基础的回归任务。
(2)nn.L1Loss (平均绝对误差)
公式:计算预测值与真实值差值的绝对值的平均数。
特点:比MSE更稳健(Robust),不容易受离群点干扰。
(3)nn.HuberLoss/nn.SmoothL1Loss
特点:结合了 MSE 和 L1 的优点。在误差小时像 MSE(收敛快),在误差大时像 L1(对离群点不敏感)。常用于物体检测(如 YOLO)的边界框回归。
2 分类任务(Classification)
用于判断类别(如猫、狗、数字)。
(1)nn.CrossEntropyLoss (交叉熵损失)------王者级别
特点:它是LogSoftmax和NLLLoss(负对数似然损失)的结合体。
重要细节:使用它时,你的网络最后一层不需要加Softmax激活函数,它内部已经帮你做好了概率归一化。
场景:多分类任务(如MNIST手写数字、ImageNet图像识别)。
(2)nn.BCELoss & nn.BCEWithLogitsLoss (二元交叉熵)
BCELoss:输入必须是经过Sigmoid后的概率值。
BCEWithLogitsLoss:输入是原始输出(Logits),内部集成了Sigmoid。
场景:二分类任务(是/否)或多标签分类(一张图里既有猫又有狗)。
3 序列与特殊任务
(1)nn.CTCLoss (连接时序分类)
场景:处理输入序列和输出序列长度不一致的情况,如语音识别(音频很长,文字很短)或 OCR(文字识别)。
(2)nn.KLDivLoss (KL 散度)
场景:衡量两个概率分布之间的相似度。常用于模型蒸馏(让小模型模仿大模型的输出分布)。
(3)nn.CosineEmbeddingLoss (余弦相似度损失)
场景:判断两个向量是否相似,常用于人脸识别或推荐系统。
2.4 激活函数(Activations)
在神经网络中,非线性激活函数(Non-linear Activations)是赋予模型"智能"的关键。正如之前讨论的,如果没有它们,无论网络堆叠多少层,本质上都只是一个巨大的线性方程,无法处理现实世界中复杂的曲线和模式。而激活函数的作用就像是一个过滤器或门控,决定了神经元接收到的信号中,哪些信息应该被传递到下一层,以及传递的强度是多少。数学上已经证明,只要有足够的隐藏层和非线性激活函数,神经网络可以逼近任意复杂的连续函数。
1 最常用的激活函数
(1)ReLU (Rectified Linear Unit) ------ 现代深层网络的标配
公式:f(x) = max(0, x);
优点:计算极快,只需要判断是否大于 0;缓解梯度消失,在x > 0区域,梯度恒为1,这让深层网络训练变得容易;
缺点:Dead ReLU问题,如果输入恒小于0,神经元会彻底"熄灭",不再更新权重。
(2)Sigmoid------经典的概率映射
公式:f(x) = 1/(1+e-x);
特点:将输入压缩到(0, 1)之间;
应用:常用于二分类任务的输出层;
缺点:容易导致梯度消失,当输入非常大或非常小时,梯度几乎为0,导致模型学不动;
(3)Tanh (双曲正切)------零中心化的改进
公式:f(x) = (ex-e-x)/(ex+e-x);
特点:将输入压缩到(-1, 1)之间,输出的均值为0;
优点:在循环神经网络(RNN)中表现通常比Sigmoid好;
(4)LeakyReLU
针对ReLU的改进
公式:f(x) = x (if x > 0) else σx;
初衷:给负区间一个微小的斜率(如α = 0.01),防止神经元彻底死掉;
2 高级与现代激活函数
(1)ELU (Exponential Linear Unit):在负区间使用指数曲线,使输出均值更接近 0,提高收敛速度。
(2)GELU (Gaussian Error Linear Unit):这是BERT、GPT等Transformer大模型中的标准配置,它通过随机正则化的思想,在输入较小时以更高概率"关掉"神经元。
(3)SiLU (Swish):由Google提出,f(x) = x·sigmoid(x),在YOLOv5/v8等目标检测模型中表现优异。
2.5 容器(Containers)
在torch.nn中,容器(Containers) 负责将各种独立的层(如 nn.Linear, nn.Conv2d)按照特定的逻辑组织在一起。如果把"层"看作积木,而"容器"就是如何把这些积木拼成房子的图纸和框架。
1 nn.Module
万物之母,虽然它经常被当作一个"类",但它本质上是所有神经网络的顶级容器。它可以包含其他nn.Module(子模块),它负责追踪所有参数,并提供to('cuda')等一键转移设备的方法。
2 nn.Sequential
使用频率最高的容器,适用于单线条、流水线式的网络结构。
逻辑:数据按照你定义的顺序,依次经过每一个层。前一层的输出自动作为后一层的输入。
优点:简洁,不需要手动写 forward 函数。
缺点:不支持多输入、多输出或残差连接(Skip Connections)。
model = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
3 nn.ModuleList
当需要批量管理许多层,但这些层并没有固定的执行顺序时,使用nn.ModuleList。使用nn.ModuleList时,它的每一层都会被正确地注册到模型中。
4 nn.ModuleDict
类似于Python的dict,它允许通过字符串索引来访问和存储模块。当需要根据配置动态选择不同的激活函数或子网络时非常有用,它支持 keys(), values(), items()等字典操作,且能正确注册参数。
5 nn.ParameterList & nn.ParameterDict
这两个容器不装"层",而是专门装自定义权重张量(Parameters)。如果不想使用预定义的层,而是想自己手写矩阵乘法逻辑,并管理一堆权重向量,就用它们。
3 手写数字识别
3.1 基础版本
1 DeepMNIST模型及训练
先给出模型及训练的完整代码:


1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 import torch.optim as optim
5 from torchvision import datasets, transforms
6 from torch.utils.data import DataLoader
7
8 class DeepMNIST(nn.Module):
9 def __init__(self):
10 super(DeepMNIST, self).__init__()
11
12 # 1. 卷积特征提取层 (使用 nn.Sequential 容器)
13 self.features = nn.Sequential(
14 # 第一层:卷积 -> 批归一化 -> 激活 -> 池化
15 nn.Conv2d(1, 32, kernel_size=3, padding=1), # in_channels=1 (灰度图)
16 nn.BatchNorm2d(32), # 稳定分布
17 nn.ReLU(inplace=True), # 非线性激活
18 nn.MaxPool2d(kernel_size=2, stride=2), # 28x28 -> 14x14
19
20 # 第二层:进一步提取深层特征
21 nn.Conv2d(32, 64, kernel_size=3, padding=1),
22 nn.BatchNorm2d(64),
23 nn.ReLU(inplace=True),
24 nn.MaxPool2d(kernel_size=2, stride=2), # 14x14 -> 7x7
25 nn.Dropout(0.25) # 随机丢弃,防止过拟合
26 )
27
28 # 2. 分类层 (全连接层)
29 self.classifier = nn.Sequential(
30 nn.Flatten(), # 将 64x7x7 压平为 1D 向量
31 nn.Linear(64 * 7 * 7, 128), # 特征线性组合
32 nn.ReLU(inplace=True),
33 nn.Dropout(0.5), # 全连接层常用的 Dropout
34 nn.Linear(128, 10) # 最终映射到 10 个类别
35 )
36
37 def forward(self, x):
38 x = self.features(x)
39 x = self.classifier(x)
40 return x
41
42 def train_model():
43 # 检测设备,优先使用显卡
44 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 print(f"Using device: {device}")
46
47 # 数据预处理:转换为 Tensor 并归一化
48 transform = transforms.Compose([
49 transforms.ToTensor(),
50 transforms.Normalize((0.1307,), (0.3081,))
51 ])
52
53 # 加载数据集
54 train_loader = DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform), batch_size=64, shuffle=True)
55 test_loader = DataLoader(datasets.MNIST('./data', train=False, transform=transform), batch_size=1000)
56
57 model = DeepMNIST().to(device)
58
59 # 定义损失函数 (内部含 LogSoftmax) 和 优化器
60 criterion = nn.CrossEntropyLoss()
61 optimizer = optim.Adam(model.parameters(), lr=0.001)
62
63 # 训练循环
64 model.train()
65 for epoch in range(1, 4): # 跑 3 个轮次
66 for batch_idx, (data, target) in enumerate(train_loader):
67 data, target = data.to(device), target.to(device)
68
69 optimizer.zero_grad() # 1. 梯度清零
70 output = model(data) # 2. 前向传播
71 loss = criterion(output, target) # 3. 计算损失
72 loss.backward() # 4. 反向传播 (计算梯度)
73 optimizer.step() # 5. 更新参数权重
74
75 if batch_idx % 100 == 0:
76 print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
77
78 # 测试逻辑
79 model.eval() # 切换到评估模式 (关闭 Dropout 和 BatchNorm 的训练行为)
80 correct = 0
81 with torch.no_grad(): # 测试时不计算梯度,节省显存
82 for data, target in test_loader:
83 data, target = data.to(device), target.to(device)
84 output = model(data)
85 pred = output.argmax(dim=1, keepdim=True)
86 correct += pred.eq(target.view_as(pred)).sum().item()
87
88 print(f'\nTest Set Accuracy: {100. * correct / len(test_loader.dataset)}%')
89 torch.save(model.state_dict(), "mnist_cnn.pth")
90
91 if __name__ == '__main__':
92 train_model()
模型整体结构
输入 (1×28×28)
↓
卷积特征提取(features)
↓
特征图 (64×7×7)
↓
展平 + 全连接(classifier)
↓
输出 (10类)
模型输入数据
x.shape = (N, 1, 28, 28)
使用的是经典MNIST数据,其中N是batch size,1表示单通道的灰度图,图像大小是28x28。
(1)卷积特征提取
对应代码第13~26行的self.features部分内容,该Sequential容器中又包含两块。第一层卷积块先通过第15行代码对1通道的灰度图做卷积,输出(32, 28, 28)特征图,具体到该训练实例,输入是(64, 1, 28, 28),即一次输入64张1x28x28的单通道灰度图,经过第一层卷积输出(64, 32, 28, 28),batch_size保持64不变,通道数变为32,图片尺寸不变,在该卷积中有32个卷积核,每个卷积核作用在所有输入通道上,然后求和,所以卷积输出包含32通道的图片;之后第16行代码对之前输出数据做批归一化,让数据分布稳定,加速训练、防止梯度消失,该步不会改变数据shape;接下来第17行使用ReLU激活函数引入非线性,inplace = True直接覆盖原数据,省显存,该步仍不会改变数据shape;再后来18行使用MaxPool最大池化把图像宽高缩小一半,该步之后输出数据shape为(64, 32, 14, 14)。第二层卷积块通过21行代码进一步对32通道数据提取深层特征,输出shape为(64, 64, 14, 14)的数据;之后22,23行做和之前类似的归一化和非线性激活;第24行对数据做进一步池化操作,此后输出数据shape为(64, 64, 7, 7);接下来第25行通过Dropout随机关闭25%的神经元,防止过拟合,输入数据经过该步骤后,输出数据形状不变仍为(64, 64, 7, 7),但是批次中数据各个通道内大约有25%的数据会变成0,而剩余的75%会放大至y=x/0.75。
(2)全连接分类
对应代码第29~35行的self.classifier部分内容,在该Sequential容器中,先通过第30行代码将原64x7x7的数据展平成一维数据,长度是64x7x7=3136,此时输出数据为(64, 3136);之后31行为全连接层,它会将长度为3136输入映射为长度是128的输出,包含batch_size时shape为(64, 128);接下来第32,33行进行ReLU非线性激活和参数为0.5的Dropout正则化;最后34行再次通过Linear全连接,将128的输入输出映射到最终的10个类别。
网络中的各层功能可以总结为:Conv2d提取边缘、纹理、形状,BN让训练更稳更快,ReLU让网络能学复杂规律,MaxPool缩小图片,保留关键特征,Dropout防止死记硬背(过拟合),Linear根据特征做最终判断。
模型中第37行的forward函数在把模型当作"函数"调用时,如model(x),会被触发调用,该函数中又会分别触发模型features和classifier的forward函数调用,进一步会触发两个容器内各层的forward调用,通过各层前向传播完成模型预测结果的计算。
源码中第42行的train_model函数用于完成模型的训练及测试。第44行当存在cuda显卡时优先使用设备进行训练;第48~51行中的transforms.Compose将多个变换操作组合成一个管道,按顺序依次执行,ToTensor用于将MNIST中的图像转换为PyTorch张量同时将像素值从[0, 255]范围缩放到[0.0, 1.0],而Normalize继续对数据进行归一化,其中mean=0.1307是MNIST数据集的像素均值,std = 0.3081是MNIST数据集的像素标准差,使用数据集的统计特性进行标准化,可以让模型训练更稳定;之后第54,55行分别加载训练集数据和测试集数据,train=True训练集数据是6万张,train=False测试集数据是1万张,训练时batch_size=64,并且通过shuffle=True打乱数据,防止模型记顺序;接下来57行创建模型并搬到GPU/CPU设备上,此时所有权重w、b都在这个设备上;之后60,61行定义损失函数和优化器,CrossEntropyLoss是分类任务专用损失,内部会对输入数据output:(N, 10)最终分类结果和target:(N, )正确的标签计算损失,Adam优化器负责更新权重,其中lr=0.001是学习率;64行切换为训练模式;65行for循环用于确定进行3轮次的训练,一个epoch会遍历完整数据集一次;第66行for循环会按batch_size=64,依次取出相应的训练数据data和正确标签target;第67行将两组数据集都搬到设备上;第69行用于清零梯度,因为在循环过程中,上一批次的梯度会残留在模型里,所以要每一批必须清零;第70行触发forward调用向前传播,通过模型各层作用输出10分类的预测值;第71行通过比较预测值和真实标签计算损失Loss;第72行通过反向传播从输出层往回计算每个w、b参数的梯度,param.grad = dLoss/dparam;第73行根据各自梯度自动更新所有w、b,完成学习训练;第75,76行用于输出中间过程信息。
第79行切换到评估模式,此时会关闭Dropout和BatchNorm的训练行为;第80行变量correct用于统计训练精度;第81行在进行评估时不计算梯度,可以节省内存;第82行for循环依次从测试集中取出测试数据data和正确标签target;第83行将相应数据搬到设备上;第84行继续触发forward调用,计算预测输出;第85行取10个输出中最大的值作为预测数字,测试时batch_size=1000,所以output的形状为[1000,10],1000表示1000张图片,10表示每张图片对应0~9个数字的得分,dim=1表示沿着10个类别方法取最大值的位置(如果dim=0则表示沿着图片数量方向);第86行中target和pred是形状为[1000]的真实标签和预测值,为了绝对安全,target.view_as(pred)把真实标签的形状改成和预测值完全一样,然后用pred.eq把预测值逐元素和真实标签比较是否相等,相等返回True否则False,.sum会把所以True作为1相加,.item会最终相加和(类型为张量)里的数字取出来,变为普通的python数字,表示本批次1000个数据中预测正确的数量,最终会将所有批次预测正确的数量都加到correct变量上;第88行将总的预测正确数量除于总的测试数据数转为百分比输出;第89行保存最终的模型到本地文件mnist_cnn.pth中。
2 Flask前端
为了验证训练出来的模型的真实预测效果,这里用Flask实现一个简单的Web程序,前端templates/index.html文件内容如下:


1 <!DOCTYPE html>
2 <html lang="zh-CN">
3 <head>
4 <meta charset="UTF-8">
5 <title>MNIST 手写数字识别</title>
6 <style>
7 body {
8 text-align: center;
9 font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
10 background: #f4f7f6;
11 margin-top: 50px;
12 }
13 h1 { color: #333; }
14 .canvas-container {
15 display: inline-block;
16 box-shadow: 0 4px 15px rgba(0,0,0,0.2);
17 border-radius: 8px;
18 overflow: hidden;
19 background: white;
20 }
21 canvas {
22 display: block;
23 cursor: crosshair;
24 touch-action: none; /* 防止移动端滚动干扰 */
25 }
26 .controls { margin: 20px; }
27 button {
28 padding: 10px 25px;
29 font-size: 16px;
30 margin: 0 10px;
31 cursor: pointer;
32 border: none;
33 border-radius: 5px;
34 transition: 0.3s;
35 }
36 .clear-btn { background: #e74c3c; color: white; }
37 .predict-btn { background: #2ecc71; color: white; }
38 button:hover { opacity: 0.8; }
39 #result {
40 margin-top: 20px;
41 font-size: 24px;
42 color: #2c3e50;
43 font-weight: bold;
44 min-height: 30px;
45 }
46 .info { color: #7f8c8d; font-size: 14px; margin-top: 10px; }
47 </style>
48 </head>
49 <body>
50 <h1>MNIST 手写数字识别</h1>
51
52 <div class="canvas-container">
53 <canvas id="canvas" width="280" height="280"></canvas>
54 </div>
55
56 <div class="controls">
57 <button class="clear-btn" onclick="clearCanvas()">清除画布</button>
58 <button class="predict-btn" onclick="predict()">开始识别</button>
59 </div>
60
61 <div id="result">请在上方区域书写数字 (0-9)</div>
62 <div class="info">提示:请尽量在中心区域书写,并保持笔画清晰</div>
63
64 <script>
65 const canvas = document.getElementById('canvas');
66 const ctx = canvas.getContext('2d');
67 let isDrawing = false;
68
69 // 1. 核心修复:初始化白色背景(物理填充)
70 function initCanvas() {
71 ctx.fillStyle = "white";
72 ctx.fillRect(0, 0, canvas.width, canvas.height);
73
74 // 设置笔触样式
75 ctx.lineWidth = 18; // 增加粗度以匹配 MNIST 特征
76 ctx.lineCap = 'round'; // 圆形笔头
77 ctx.lineJoin = 'round'; // 圆形转角
78 ctx.strokeStyle = 'black'; // 黑色字
79 }
80
81 // 页面加载完成后立即执行
82 window.onload = initCanvas;
83
84 // 2. 绘图逻辑控制
85 function getMousePos(e) {
86 const rect = canvas.getBoundingClientRect();
87 // 考虑缩放和位移,精确计算坐标
88 return {
89 x: e.clientX - rect.left,
90 y: e.clientY - rect.top
91 };
92 }
93
94 canvas.addEventListener('mousedown', (e) => {
95 isDrawing = true;
96 const pos = getMousePos(e);
97 ctx.beginPath();
98 ctx.moveTo(pos.x, pos.y);
99 });
100
101 canvas.addEventListener('mousemove', (e) => {
102 if (!isDrawing) return;
103 const pos = getMousePos(e);
104 ctx.lineTo(pos.x, pos.y);
105 ctx.stroke();
106 });
107
108 window.addEventListener('mouseup', () => {
109 isDrawing = false;
110 });
111
112 // 3. 清除逻辑修复
113 function clearCanvas() {
114 // 关键:不只是 clearRect,还要重新填充白色
115 ctx.clearRect(0, 0, canvas.width, canvas.height);
116 initCanvas();
117 document.getElementById('result').innerText = "请在上方区域书写数字 (0-9)";
118 document.getElementById('result').style.color = "#2c3e50";
119 }
120
121 // 4. 发送数据到 Flask 后端
122 async function predict() {
123 const resultDiv = document.getElementById('result');
124 resultDiv.innerText = "识别中...";
125
126 try {
127 const dataURL = canvas.toDataURL('image/png');
128 const response = await fetch('/predict', {
129 method: 'POST',
130 headers: { 'Content-Type': 'application/json' },
131 body: JSON.stringify({ image: dataURL })
132 });
133
134 if (!response.ok) throw new Error("网络响应不正常");
135
136 const res = await response.json();
137 resultDiv.innerHTML = `预测结果: <span style="color: #e67e22; font-size: 40px;">${res.prediction}</span><br>
138 <small style="font-size: 14px; color: #95a5a6;">置信度: ${res.confidence}</small>`;
139 } catch (error) {
140 console.error("Error:", error);
141 resultDiv.innerText = "识别失败,请检查后端服务";
142 resultDiv.style.color = "red";
143 }
144 }
145 </script>
146 </body>
147 </html>
index.html
该html文件在128行会调用后端的predict对画布上的手写图片进行识别预测,其他内容相对简单不再进行详细介绍。
3 Flask前端
后端app.py文件内容如下:


1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 from torchvision import transforms
5 from flask import Flask, request, jsonify, render_template
6 from PIL import Image
7 import io
8 import base64
9 import numpy as np
10
11 # 导入之前定义的模型类 (确保结构一致)
12 class DeepMNIST(nn.Module):
13 def __init__(self):
14 super(DeepMNIST, self).__init__()
15 self.features = nn.Sequential(
16 nn.Conv2d(1, 32, kernel_size=3, padding=1),
17 nn.BatchNorm2d(32),
18 nn.ReLU(inplace=True),
19 nn.MaxPool2d(kernel_size=2, stride=2),
20 nn.Conv2d(32, 64, kernel_size=3, padding=1),
21 nn.BatchNorm2d(64),
22 nn.ReLU(inplace=True),
23 nn.MaxPool2d(kernel_size=2, stride=2),
24 nn.Dropout(0.25)
25 )
26 self.classifier = nn.Sequential(
27 nn.Flatten(),
28 nn.Linear(64 * 7 * 7, 128),
29 nn.ReLU(inplace=True),
30 nn.Dropout(0.5),
31 nn.Linear(128, 10)
32 )
33
34 def forward(self, x):
35 x = self.features(x)
36 x = self.classifier(x)
37 return x
38
39 app = Flask(__name__)
40
41 # 加载模型
42 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43 model = DeepMNIST().to(device)
44 model.load_state_dict(torch.load("mnist_cnn.pth", map_location=device))
45 model.eval()
46
47 # 图像预处理流水线
48 transform = transforms.Compose([
49 transforms.Resize((28, 28)),
50 transforms.Grayscale(),
51 transforms.ToTensor(),
52 transforms.Normalize((0.1307,), (0.3081,))
53 ])
54
55 @app.route('/')
56 def index():
57 return render_template('index.html')
58
59 @app.route('/predict', methods=['POST'])
60 def predict():
61 data = request.get_json()
62 image_data = data['image'].split(',')[1] # 去掉 base64 头部
63 image_bytes = io.BytesIO(base64.b64decode(image_data))
64
65 # 打开并反色(前端通常是白底黑字,MNIST 是黑底白字)
66 img = Image.open(image_bytes).convert('L')
67 img = Image.eval(img, lambda x: 255 - x)
68 img.resize((28, 28)).save("debug_input.png")
69
70 img_tensor = transform(img).unsqueeze(0).to(device)
71
72 with torch.no_grad():
73 output = model(img_tensor)
74 prob = F.softmax(output, dim=1)
75 prediction = torch.argmax(prob, dim=1).item()
76 confidence = torch.max(prob).item()
77
78 return jsonify({
79 'prediction': prediction,
80 'confidence': f"{confidence:.2%}"
81 })
82
83 if __name__ == '__main__':
84 app.run(host='0.0.0.0', port=5000, debug=True)
在后端文件中仍然使用和模型训练时完全相同的DeepMNIST类,在第39行创建Flask实例app;在第44,45行使用DeepMNIST实例model加载之前训练获得的模型mnist_cnn.pth,并切换到评估模式;第48~53行创建图像预处理流水线对前端传过来的图片进行预处理;在60行的predict函数中,将前端传过来的白底黑字转换为和MNIST数据集一样的黑底白字的图片,然后用transform对图片数据进行预处理,之后使用模型对数据进行识别预测,并将结果返回给前端进行显示,通过python app.py启动后端,并使用浏览器访问localhost:5000前端,运行效果如下:

但是通过测试发现模型对于0经常会识别成9,如下图:

3.2 增强模型版本
由于模型在训练时对形状的拓扑结构(比如圆圈)提取不够鲁棒,或者受到了"书写习惯"的影响,模型可能会把"0"识别成"9",接下来通过四个维度来深度优化训练过程。
1 引入数据增强 (Data Augmentation)
现在的模型可能只见过正儿八经、垂直居中的数字。当书写得稍微歪一点或大一点,它就"懵"了。在训练代码中,修改transform,增加随机旋转和平移:
train_transform = transforms.Compose([
# 随机旋转(正负10度),随机平移(宽高的10%),随机缩放(0.9~1.1倍)
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
这会让模型在训练时见过各种"奇形怪状"的9和0,强迫它去学习数字的结构(比如9的圆头和长尾巴),而不是具体的像素位置。
2 增加模型深度和宽度
之前的模型可能比较轻量,为了让它能识别更复杂的特征(比如9和0的细微差别),可以增加卷积核的数量或层数。
# 在 features 序列中尝试增加一层卷积
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
更多的通道数(如 128)意味着模型可以同时提取更多种类的几何特征。
3 标签平滑 (Label Smoothing)
MNIST数据集里有些样本本身就写得很模糊(有的9确实写得像1),如果强制要求模型100%确定那是9,会导致过拟合。
# 修改损失函数
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
以上代码告诉模型:"这大概率是 9,但也有一点点可能是别的",这能显著提高模型在面对模糊输入时的泛化能力。
4 增大训练轮数
在深度学习训练中,增大训练轮数(Epochs)的本质是给模型更多"阅读教材"和"总结规律"的时间。如果Epoch过低,则模型可能还没走到损失函数的"谷底"就停止了,这是可以提高Epoch轮数,给模型足够的时间跨越"平原"和"小坑",进入更深、更稳定的低损耗区域。经过更多次迭代,卷积核会变得非常锐利,能够区分出"9的圆圈"和 "1的长杆"之间极细微的统计学差异。需要注意的是增大轮数也有可能带来过拟合(Overfitting)这一副作用,这时模型不再学习数字的"通用特征",而是开始"死记硬背"训练集里的每一个像素噪声(比如某张图片里特有的污点),具体表现为:训练Loss持续下降,但验证/测试Loss开始回升,即识别时准确率反而下降了。在增强模型中把训练Epoch提高到15轮。
以下是完整的训练和后端源码:


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 1. 增强版模型结构
class ModernMNIST(nn.Module):
def __init__(self):
super(ModernMNIST, self).__init__()
self.features = nn.Sequential(
# 第一层:32通道
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1), # 增加一层卷积增强提取
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2), # 14x14
# 第二层:64通道
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2), # 7x7
nn.Dropout2d(0.25) # 针对空间特征的Dropout
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 512), # 增大隐藏层维度
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
nn.Dropout(0.5),
nn.Linear(512, 10)
)
def forward(self, x):
return self.classifier(self.features(x))
def train():
device = torch.device("cuda") # 这里为了保障训练速度使用了显卡,如没有显卡请使用CPU
# 2. 核心改进:数据增强
# 这将解决你画的 9 稍微偏一点就识别成 1 的问题
train_transform = transforms.Compose([
transforms.RandomAffine(degrees=15, translate=(0.15, 0.15), scale=(0.8, 1.2)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_loader = DataLoader(datasets.MNIST('./data', train=True, download=True, transform=train_transform), batch_size=128, shuffle=True)
test_loader = DataLoader(datasets.MNIST('./data', train=False, transform=test_transform), batch_size=1000)
model = ModernMNIST().to(device)
# 3. 标签平滑,增强泛化能力
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 4. 学习率调度器:每 5 轮衰减为原来的 0.5
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
for epoch in range(1, 16): # 增加到 15 轮
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
scheduler.step() # 更新学习率
# 测试
model.eval()
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
pred = model(data).argmax(dim=1)
correct += pred.eq(target).sum().item()
print(f"Epoch {epoch}: Accuracy {100.*correct/10000}% | LR: {scheduler.get_last_lr()[0]:.6f}")
torch.save(model.state_dict(), "mnist_cnn_pro.pth")
print("Saved Pro model!")
if __name__ == "__main__":
train()
train_upgrade.py


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from flask import Flask, request, jsonify, render_template
from PIL import Image
import io
import base64
import numpy as np
class ModernMNIST(nn.Module):
def __init__(self):
super(ModernMNIST, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Dropout2d(0.25)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 512),
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
nn.Dropout(0.5),
nn.Linear(512, 10)
)
def forward(self, x):
return self.classifier(self.features(x))
app = Flask(__name__)
# 加载模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModernMNIST().to(device) # 使用新的类名
model.load_state_dict(torch.load("mnist_cnn_pro.pth", map_location=device))
model.eval()
# 图像预处理流水线
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
@app.route('/')
def index():
return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
data = request.get_json()
image_data = data['image'].split(',')[1] # 去掉 base64 头部
image_bytes = io.BytesIO(base64.b64decode(image_data))
# 打开并反色(前端通常是白底黑字,MNIST 是黑底白字)
img = Image.open(image_bytes).convert('L')
img = Image.eval(img, lambda x: 255 - x)
img.resize((28, 28)).save("debug_input.png")
img_tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
output = model(img_tensor)
prob = F.softmax(output, dim=1)
prediction = torch.argmax(prob, dim=1).item()
confidence = torch.max(prob).item()
return jsonify({
'prediction': prediction,
'confidence': f"{confidence:.2%}"
})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
app_pro.py
为了训练速度,train_upgrade.py中使用cuda显卡,如果没有显卡请修改相应行代码,训练输出如下:

从输出结果会发现,12轮到15轮,精度基本不再提升,甚至有微小向下波动。这说明:
收敛完成:对于MNIST这种简单任务,15轮左右已经把该学的规律都学完了;
边际效应递减:再往后增加到100轮,精度也不会变成100%,反而会增加过拟合风险。
训练完成后,运行python app_pro.py启动后端,并通过浏览器进行手写识别测试,发现对于0的识别效果有明显改善,但是对于书写不在画布中央时仍会出现错误识别:

3.3 终极版本
同一个数字在中间识别好,在边缘识别差------其本质原因在于:CNN并不是天生就具有"平移不变性"(Translation Invariance),或者说,这种不变性是有条件的,它依然受到了以下三个技术层面的制约:
1 MNIST 训练集的"偏置" (Dataset Bias)
MNIST数据集并不是让数字随意散落在28x28的画布上,所有样本都是经过"重心对齐(Center of Mass)"处理的。
模型的经验:模型在训练时,99.9%的"2"的核心特征(横和竖)都出现在图像的中央位置。
写到边缘时:当把"2"写到画布边缘,缩放后的28x28图片中,大部分区域是空白的,只有边缘有一点点像素,这对于模型来说,是一个从来没见过的图像分布,它自然就认不出来了。
2 形态学差异:缩放引起的"变形"
居中缩放:如果数字在中间,缩放后它会变成一个标准的、占满28x28区域的灰度数字。
偏心缩放:如果数字在边缘,缩放后由于周围空白过多,数字会被挤得很小且"靠边"。线条可能会因为下采样(Downsampling)而变得破碎。模型看到的不是一个"2",而是一小团"噪声特征"。
3 CNN的层级结构与Padding效应
卷积操作依靠滑动窗口,为了让窗口覆盖边缘像素,我们使用了 Padding(在周围补 0)。
浅层特征:第一层卷积可能提取到了"2"的局部特征。
深层特征:随着层数变深,感受野(Receptive Field)增大,后面的全连接层(Classifier)会结合特征的空间位置来做最终判断。如果模型学到的是"中间有特征=2",那么"边缘有特征"就可能被错误地判定为其他数字(如图中的1,置信度仅为29.83%)。
为了彻底解决"位置敏感"问题,需要在预处理(InferenceTime)或训练(TrainingTime)阶段加入重心居中处理。找到数字的核心区域(Region of Interest),将其提取出来,然后人为地放置在28x28的正中央。
以下是后端app_pro_last.py最终版本:


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from flask import Flask, request, jsonify, render_template
from PIL import Image, ImageOps, ImageChops
import io
import base64
import numpy as np
# 1. 必须匹配你训练 mnist_cnn_pro.pth 时使用的 ModernMNIST 类
class ModernMNIST(nn.Module):
def __init__(self):
super(ModernMNIST, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Dropout2d(0.25)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 512),
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
nn.Dropout(0.5),
nn.Linear(512, 10)
)
def forward(self, x):
return self.classifier(self.features(x))
app = Flask(__name__)
# 加载模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModernMNIST().to(device)
model.load_state_dict(torch.load("mnist_cnn_pro.pth", map_location=device))
model.eval()
# 核心预处理函数:提取数字并居中
def process_image(pil_img):
# A. 处理透明度并转为黑底白字
bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
combined = Image.alpha_composite(bg, pil_img).convert('L')
inv_img = ImageOps.invert(combined) # 反色:变成黑底白字
# B. 找到数字的边界框 (Bounding Box)
bbox = inv_img.getbbox()
if not bbox:
return None
# C. 裁剪出数字区域并等比缩放到 20x20 内部
roi = inv_img.crop(bbox)
width, height = roi.size
max_dim = max(width, height)
# 按照长边缩放到 20 像素,留出边缘 Padding
scale = 20.0 / max_dim
new_size = (int(width * scale), int(height * scale))
roi = roi.resize(new_size, Image.LANCZOS)
# D. 将缩放后的数字粘贴到 28x28 的纯黑画布正中央
final_canvas = Image.new('L', (28, 28), 0)
upper_left = ((28 - new_size[0]) // 2, (28 - new_size[1]) // 2)
final_canvas.paste(roi, upper_left)
return final_canvas
@app.route('/')
def index():
return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
try:
data = request.get_json()
image_data = data['image'].split(',')[1]
image_bytes = io.BytesIO(base64.b64decode(image_data))
raw_img = Image.open(image_bytes).convert('RGBA')
# 执行居中预处理
processed_img = process_image(raw_img)
if processed_img is None:
return jsonify({'prediction': 'N/A', 'confidence': '0%'})
# 调试保存:查看模型真正"看"到的样子
processed_img.save("debug_centered.png")
# 转换为 Tensor (注意:这里不再需要复杂的 Normalize 增强,简单的标准化即可)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
img_tensor = transform(processed_img).unsqueeze(0).to(device)
with torch.no_grad():
output = model(img_tensor)
prob = F.softmax(output, dim=1)
pred = torch.argmax(prob, dim=1).item()
conf = torch.max(prob).item()
return jsonify({
'prediction': pred,
'confidence': f"{conf:.2%}"
})
except Exception as e:
return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
app_pro_last.py
重新运行后端python app_pro_last.py,可见对于同样的图片已经正确完成识别:
