学习pytorch11 神经网络-非线性激活

神经网络-非线性激活

B站小土堆学习pytorch视频 非常棒的up主,讲的很详细明白

官网文档

https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

常用1 ReLU

对输入做截断非线性处理,使模型泛化

py 复制代码
>>> m = nn.ReLU()
>>> input = torch.randn(2)
>>> output = m(input)
An implementation of CReLU - https://arxiv.org/abs/1603.05201
>>> m = nn.ReLU()
>>> input = torch.randn(2).unsqueeze(0)
>>> output = torch.cat((m(input), m(-input)))

inplace

inplace=True 原位操作 改变变量本身的值

inplace=False 重新定义一个变量output 承接input-relu后的值,一般默认为False,保留输入数据

常用2 Sigmoid

py 复制代码
>>> m = nn.Sigmoid()
>>> input = torch.randn(2)
>>> output = m(input)

弹幕:

激活层的作用是放大不同类别的得分差异

二分类输出层用sigmoid 隐藏层用relu

负值的来源:输入数据;卷积核;归一化;反向梯度下降导致负值;【不确定】

reshape(input, (-1,1,2,2))是将input这个22的张量转化为-1 12 2的张量,其中-1表示张量元素个数除以其他维度大小的乘积,即"-1" == 22/(12*2) = 1

非线性变化主要目的:为我们的网络引入非线性特征 非线性越多才能训练不同的非线性曲线或者说特征,模型泛化能力才好。

代码

py 复制代码
import torch
import torchvision.transforms
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

test_set = datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(test_set, batch_size=64, drop_last=True)

class Activation(nn.Module):
    def __init__(self):
        super(Activation, self).__init__()
        self.relu1 = ReLU(inplace=False)
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        # output1 = self.relu1(input)
        output2 = self.sigmoid1(input)
        # return output1
        return output2

writer = SummaryWriter('logs')
step = 0
activate = Activation()
for data in dataloader:
    imgs, target = data
    writer.add_images("input", imgs, global_step=step)
    output = activate(imgs)
    # writer.add_images("output1", output, global_step=step)
    writer.add_images("output2", output, global_step=step)
    step += 1
writer.close()

logs

相关推荐
蒋星熠19 小时前
WebSocket网络编程深度实践:从协议原理到生产级应用
网络·数据库·redis·python·websocket·网络协议·微服务
带娃的IT创业者19 小时前
实战:用 Python 搭建 MCP 服务 —— 模型上下文协议(Model Context Protocol)应用指南
开发语言·python·mcp
Hello123网站19 小时前
探迹SalesGPT
人工智能·ai工具
摘星星的屋顶20 小时前
论文阅读记录之《VelocityGPT 》
论文阅读·人工智能·深度学习·学习
万粉变现经纪人20 小时前
如何解决pip安装报错ModuleNotFoundError: No module named ‘python-dateutil’问题
开发语言·ide·python·pycharm·pandas·pip·httpx
格林威20 小时前
工业相机如何通过光度立体成像技术实现高效精准的2.5D缺陷检测
人工智能·深度学习·数码相机·yolo·计算机视觉
跟橙姐学代码20 小时前
Python 类的正确打开方式:从新手到进阶的第一步
前端·python·ipython
c8i20 小时前
关于python中的变量中使用的下划线_总结
python
MarkHD20 小时前
大语言模型入门指南:从原理到实践应用
人工智能·语言模型·自然语言处理
A尘埃20 小时前
NLP(自然语言处理, Natural Language Processing)
人工智能·自然语言处理·nlp