完整的神经网络
以分类任务为例,神经网络一般包括backbone和head(计算机视觉领域)
下面的BasicBlock不是一个标准的backbone,标准的应该是复杂的CNNs构成的
Classfier是一个标准的head,其中output_dim表示分类类别,一般写作num_classes
python
import torch # 导入 torch 库
import torch.nn as nn # 导入 torch 的神经网络模块
import torch.nn.functional as F # 导入 torch 的函数式接口
# 定义一个基础的神经网络模块
class BasicBlock(nn.Module): # 继承自 torch 的 Module 类
def __init__(self, input_dim, output_dim):
super(BasicBlock, self).__init__() # 初始化父类
# 构建一个序列模块,包含一个线性层和一个 ReLU 激活函数
self.block = nn.Sequential(
# 线性层,输入维度为 input_dim,输出维度为 output_dim
nn.Linear(input_dim, output_dim),
nn.ReLU(), # ReLU 激活函数
)
def forward(self, x):
x = self.block(x) # 将输入数据 x 通过定义的序列模块
return x # 返回模块的输出
# 定义一个分类器神经网络
class Classifier(nn.Module): # 继承自 torch 的 Module 类
def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):
super(Classifier, self).__init__() # 初始化父类
# 构建一个序列模块,包含若干个 BasicBlock 和一个线性输出层
self.fc = nn.Sequential(
# 第一个 BasicBlock,将输入维度转换为隐藏层维度
BasicBlock(input_dim, hidden_dim),
# 根据 hidden_layers 数量添加多个 BasicBlock
*[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],
# 线性输出层,将隐藏层维度转换为输出维度
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
x = self.fc(x) # 将输入数据 x 通过定义的序列模块
return x # 返回模块的输出
对 *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)]的一个补充解释,"*"代表解压列表,例如A=[a,b,c],那么f(*A)=f(a,b,c)
在这里的具体意义是"便于控制隐藏层数量",而其中的"_"代表不希望在循环中使用变量,这是一种通用的惯例,表明循环的目的不是对每个元素进行操作,而是为了重复某个操作特定次数。如果hidden_layers=3,这里的等价含义就是BasicBlock(hidden_dim, hidden_dim),BasicBlock(hidden_dim, hidden_dim),BasicBlock(hidden_dim, hidden_dim),------连续出现三次
dropout
Dropout层在神经网络层当中是用来干什么的呢?它是一种可以用于减少神经网络过拟合的结构。
如上图我们定义的网络,一共有四个输入x_i,一个输出y。Dropout则是在每一个batch的训练当中随机减掉一些神经元,而作为编程者,我们可以设定每一层dropout(将神经元去除的的多少)的概率,在设定之后,就可以得到第一个batch进行训练的结果:
从上图我们可以看到一些神经元之间断开了连接,因此它们被dropout了!dropout顾名思义就是被拿掉的意思,正因为我们在神经网络当中拿掉了一些神经元,所以才叫做dropout层。
在进行第一个batch的训练时,有以下步骤:
- 设定每一个神经网络层进行dropout的概率
- 根据相应的概率拿掉一部分的神经元,然后开始训练,更新没有被拿掉神经元以及权重的参数,将其保留
- 参数全部更新之后,又重新根据相应的概率拿掉一部分神经元,然后开始训练,如果新用于训练的神经元已经在第一次当中训练过,那么我们继续更新它的参数。而第二次被剪掉的神经元,同时第一次已经更新过参数的,我们保留它的权重,不做修改,直到第n次batch进行dropout时没有将其删除。