Pytorch学习--神经网络--线性层及其他层

一、正则化层

torch.nn.BatchNorm2d

python 复制代码
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

正则化的意义:

  • 加速训练收敛:在每一层网络的输入上执行批量归一化可以保持数据的分布稳定,从而减小梯度的波动。这种稳定性让模型更快收敛,从而提高训练速度。

  • 减轻梯度消失和梯度爆炸问题:通过调整每一层的输入分布,Batch Normalization可以减轻深层网络中梯度消失和梯度爆炸的现象,使得更深的网络也能够得到有效的训练。

  • 减少对权重初始化的敏感性:Batch Normalization可以减小网络对权重初始化的依赖,使得模型可以在更宽的初始化范围内有效训练。这减少了在不同模型初始化方案间进行调试的时间和精力。

  • 提高模型的泛化能力:Batch Normalization在训练时引入了少量噪声(由于 mini-batch 的不同),这在一定程度上起到了正则化作用,有助于提高模型的泛化能力,降低过拟合的风险。

  • 降低学习率调整的难度:使用Batch Normalization可以让模型在较高的学习率下进行训练,从而进一步加速训练过程。

二、Dropout层

torch.nn.Dropout

python 复制代码
torch.nn.Dropout(p=0.5, inplace=False)

防止过拟合

三、线性层

torch.nn.Linear

python 复制代码
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

代码实现:

CIFAR 中的图片 转换为 一维的数据(1,m),再转换成 (1,n) 的维度

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root="datasets",train=False,transform=torchvision.transforms.ToTensor(),download=True)

dataloader = DataLoader(dataset,batch_size=64)

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.linear1 = Linear(196608,10)
    def forward(self,x):
        x = self.linear1(x)
        return x
Yorelee = Mary()

for data in dataloader:
    img,targets = data
    img = torch.flatten(img)
    print(img.shape)
    output = Yorelee(img)
    print(output.shape)

输出:

python 复制代码
torch.Size([196608])
torch.Size([10])
相关推荐
JicasdC123asd几秒前
基于YOLO11-seg的MultiSEAMHead驾驶员疲劳检测系统_计算机视觉实时监测_眼睛嘴巴状态识别
人工智能·计算机视觉
极新1 分钟前
飞书产品营销&增长经理王屹煊:飞书多维表格驱动业务新增长|2025极新AIGC峰会演讲实录
人工智能
小当家.1051 分钟前
JVM八股详解(上部):核心原理与内存管理
java·jvm·学习·面试
寻星探路2 分钟前
【Python 全栈测开之路】Python 基础语法精讲(三):函数、容器类型与文件处理
java·开发语言·c++·人工智能·python·ai·c#
且去填词2 分钟前
构建基于 DeepEval 的 LLM 自动化评估流水线
运维·人工智能·python·自动化·llm·deepseek·deepeval
dagouaofei5 分钟前
不同 AI 生成 2026 年工作计划 PPT 的使用门槛对比
人工智能·python·powerpoint
IRevers7 分钟前
【目标检测】深度学习目标检测损失函数总结
人工智能·pytorch·深度学习·目标检测·计算机视觉
知识分享小能手9 分钟前
Ubuntu入门学习教程,从入门到精通,Ubuntu 22.04 中的大数据 —— 知识点详解 (24)
大数据·学习·ubuntu
RockHopper202511 分钟前
为什么具身智能系统需要能“自我闭环”的认知机制
人工智能·具身智能·具身认知
itwangyang52011 分钟前
Windows + Conda + OpenMM GPU(CUDA)完整安装教程-50显卡系列
人工智能·windows·python·conda