【机器学习】037_暂退法

一、实现原理

具有输入噪音的训练,等价于Tikhonov正则化

核心方法:在前向传播的过程中,计算每一内部层的同时注入噪声

· 从作用上来看,表面上来说是在训练过程中丢弃一些神经元

· 假设x是某一层神经网络层的输出,是下一层的输入,我们希望对x加入一些噪音,使得:

※x`的期望为x,也就是说平均上来说输出值还是x

· 暂退法对每个元素进行了如下扰动:

有p的概率下取值:

其它情况(1-p概率):

实践中使用暂退法:

· 通常将暂退法作用在全连接隐藏层的输出上

如图所示,在第一个隐藏层的输出上,有些神经元有p的概率使输出值置零。

非置零的输出值,即有1-p的概率被施加了一个较小的扰动值使其略微增大。

※暂退法只在训练中使用,dropout是正则项,在推理过程中不会使用,这样也会保证输出值确定

※每次执行暂退法的时候,实际上是每次随机采样了一些子神经网络

总结:

①暂退法将一些输出项随机置零来控制模型的复杂度

②暂退法的作用效果和正则化等价

③常应用在多层感知机的隐藏层输出上

④丢弃概率p是控制模型复杂度的超参数

二、代码实现

从零实现代码:

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

def dropout_layer(X, dropout):
    # assert用于选择dropout符合范围的情况,不符合则报错
    assert 0 <= dropout <= 1, "不符合范围!"
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    # 在这一步操作中,首先定义一个和X张量形状相同但元素值均为随机数的张量
    # 将这个张量里每个元素与dropout比较,如果大于就置为True,小于等于就置为False
    # 再调用float将True和False转化为1和0
    # 这样,mask就是一个仅含1与0的张量了
    # 最后将mask里的每个元素与X里的每个元素做数乘
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

# 生成X来测试暂退法
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

# 定义模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 定义模型
dropout1, dropout2 = 0.2, 0.5
# is_training用来表示当前是在测试还是在训练
class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
            # 输出不需要dropout作用
        out = self.lin3(H2)
        return out

net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

简洁实现代码:

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

# 定义概率参数
dropout1, dropout2 = 0.2, 0.5

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
相关推荐
Tianyanxiao33 分钟前
如何利用探商宝精准营销,抓住行业机遇——以AI技术与大数据推动企业信息精准筛选
大数据·人工智能·科技·数据分析·深度优先·零售
撞南墙者40 分钟前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone42141 分钟前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr
进击的六角龙1 小时前
Python中处理Excel的基本概念(如工作簿、工作表等)
开发语言·python·excel
王哈哈^_^1 小时前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
一者仁心1 小时前
【AI技术】PaddleSpeech
人工智能
是瑶瑶子啦1 小时前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
一只爱好编程的程序猿1 小时前
Java后台生成指定路径下创建指定名称的文件
java·python·数据下载
EasyCVR1 小时前
萤石设备视频接入平台EasyCVR多品牌摄像机视频平台海康ehome平台(ISUP)接入EasyCVR不在线如何排查?
运维·服务器·网络·人工智能·ffmpeg·音视频
Aniay_ivy1 小时前
深入探索 Java 8 Stream 流:高效操作与应用场景
java·开发语言·python