Pytorch 缓解过拟合和网络退化

一 添加BN模块

BN模块应该添加 激活层前面

在模型实例化后,我们需要对BN层进行初始化。PyTorch中的BN层是通过nn.BatchNorm1d或nn.BatchNorm2d类来实现的。

bn = nn.BatchNorm1d(20) #

对于1D输入数据,使用nn.BatchNorm1d;对于2D输入数据,使用nn.BatchNorm2d

在模型的前向传播过程中,我们需要将BN层应用到适当的位置。以全连接层为例,我们需要在全连接层的输出之后调用BN层。

python 复制代码
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.bn = nn.BatchNorm1d(20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

二 添加残差连接

最主要的是需要注意输入参数的维度是否一致

python 复制代码
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(ResidualBlock, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, input_size)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        residual = x
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out += residual
        out = self.relu(out)
        return out
-----------------------------------
©著作权归作者所有:来自51CTO博客作者mob649e8166c3a5的原创作品,请联系作者获取转载授权,否则将追究法律责任
pytorch 全链接层设置残差模块
https://blog.51cto.com/u_16175510/6892589

1、Pytorch搭建残差网络

2、

相关推荐
Pyeako3 分钟前
opencv计算机视觉--Harris角点检测&SIFT特征提取&图片抠图
人工智能·python·opencv·计算机视觉·harris角点检测·sift特征提取·图片抠图
前进的程序员5 分钟前
智能融合终端的技术革新与应用实践
大数据·人工智能
艾莉丝努力练剑6 分钟前
【AI时代的赋能与重构】当AI成为创作环境的一部分:机遇、挑战与应对路径
linux·c++·人工智能·python·ai·脉脉·ama
程序猫A建仔7 分钟前
【AI入门基础】AI核心知识点速查手册
人工智能
AI科技星9 分钟前
加速运动电荷产生引力场方程求导验证
服务器·人工智能·线性代数·算法·矩阵
Akamai中国9 分钟前
Akamai Cloud客户案例 | Multivrse 信赖 Akamai 为其业务增长提供动力,实现更快资源调配、成本节约与更低延迟
人工智能·云计算·云服务·云存储
嘉立创FPC苗工10 分钟前
气隙变压器铁芯:磁路中的“安全阀”与能量枢纽
大数据·人工智能·制造·fpc·电路板
郝学胜-神的一滴12 分钟前
B站:从二次元到AI创新孵化器的华丽转身 | Google Cloud峰会见闻
开发语言·人工智能·算法
果粒蹬i15 分钟前
从割裂到融合:MATLAB与Python混合编程实战指南
开发语言·汇编·python·matlab
千流出海15 分钟前
冬季风暴考验因AI数据中心而紧张的电网系统
人工智能