学习pytorch11 神经网络-非线性激活

神经网络-非线性激活

B站小土堆学习pytorch视频 非常棒的up主,讲的很详细明白

官网文档

https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

常用1 ReLU

对输入做截断非线性处理,使模型泛化

py 复制代码
>>> m = nn.ReLU()
>>> input = torch.randn(2)
>>> output = m(input)
An implementation of CReLU - https://arxiv.org/abs/1603.05201
>>> m = nn.ReLU()
>>> input = torch.randn(2).unsqueeze(0)
>>> output = torch.cat((m(input), m(-input)))

inplace

inplace=True 原位操作 改变变量本身的值

inplace=False 重新定义一个变量output 承接input-relu后的值,一般默认为False,保留输入数据

常用2 Sigmoid

py 复制代码
>>> m = nn.Sigmoid()
>>> input = torch.randn(2)
>>> output = m(input)

弹幕:

激活层的作用是放大不同类别的得分差异

二分类输出层用sigmoid 隐藏层用relu

负值的来源:输入数据;卷积核;归一化;反向梯度下降导致负值;【不确定】

reshape(input, (-1,1,2,2))是将input这个22的张量转化为-1 12 2的张量,其中-1表示张量元素个数除以其他维度大小的乘积,即"-1" == 22/(12*2) = 1

非线性变化主要目的:为我们的网络引入非线性特征 非线性越多才能训练不同的非线性曲线或者说特征,模型泛化能力才好。

代码

py 复制代码
import torch
import torchvision.transforms
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

test_set = datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(test_set, batch_size=64, drop_last=True)

class Activation(nn.Module):
    def __init__(self):
        super(Activation, self).__init__()
        self.relu1 = ReLU(inplace=False)
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        # output1 = self.relu1(input)
        output2 = self.sigmoid1(input)
        # return output1
        return output2

writer = SummaryWriter('logs')
step = 0
activate = Activation()
for data in dataloader:
    imgs, target = data
    writer.add_images("input", imgs, global_step=step)
    output = activate(imgs)
    # writer.add_images("output1", output, global_step=step)
    writer.add_images("output2", output, global_step=step)
    step += 1
writer.close()

logs

相关推荐
l14372332672 分钟前
电影解说详细教程:从「一条视频」到「持续更新」
人工智能
MUTA️6 分钟前
BCEWithLogitsLoss
人工智能
deephub12 分钟前
使用 tsfresh 和 AutoML 进行时间序列特征工程
人工智能·python·机器学习·特征工程·时间序列
静听松涛13312 分钟前
从模式识别到逻辑推理的认知跨越
人工智能·机器学习
牛客企业服务13 分钟前
AI面试选型策略:2026年五大核心维度解析
人工智能
0思必得018 分钟前
[Web自动化] Selenium中Select元素操作方法
前端·python·selenium·自动化·html
啊阿狸不会拉杆21 分钟前
《机器学习》第四章-无监督学习
人工智能·学习·算法·机器学习·计算机视觉
Duang007_22 分钟前
【万字学习总结】API设计与接口开发实战指南
开发语言·javascript·人工智能·python·学习
图生生23 分钟前
AI溶图技术+光影适配:实现产品场景图的高质量合成
人工智能·ai
小北方城市网25 分钟前
JVM 调优实战指南:从问题排查到参数优化
java·spring boot·python·rabbitmq·java-rabbitmq·数据库架构