一 添加BN模块
BN模块应该添加 激活层前面
在模型实例化后,我们需要对BN层进行初始化。PyTorch中的BN层是通过nn.BatchNorm1d或nn.BatchNorm2d类来实现的。
bn = nn.BatchNorm1d(20) #
对于1D输入数据,使用nn.BatchNorm1d;对于2D输入数据,使用nn.BatchNorm2d
在模型的前向传播过程中,我们需要将BN层应用到适当的位置。以全连接层为例,我们需要在全连接层的输出之后调用BN层。
python
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.bn = nn.BatchNorm1d(20)
self.fc2 = nn.Linear(20, 30)
self.fc3 = nn.Linear(30, 2)
def forward(self, x):
x = self.fc1(x)
x = self.bn(x)
x = self.fc2(x)
x = self.fc3(x)
return x
二 添加残差连接
最主要的是需要注意输入参数的维度是否一致
python
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, input_size, hidden_size):
super(ResidualBlock, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, input_size)
self.relu = nn.ReLU()
def forward(self, x):
residual = x
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out += residual
out = self.relu(out)
return out
-----------------------------------
©著作权归作者所有:来自51CTO博客作者mob649e8166c3a5的原创作品,请联系作者获取转载授权,否则将追究法律责任
pytorch 全链接层设置残差模块
https://blog.51cto.com/u_16175510/6892589
2、