非线性激活函数的作用是让神经网络能够理解更复杂的模式和规律。如果没有非线性激活函数,神经网络就只能进行简单的加法和乘法运算,没法处理复杂的问题。
非线性变化的目的就是给我们的网络当中引入一些非线性特征

Relu 激活函数
Relu处理图像
python
# 导入必要的库
from os import close
import torch
import torchvision.datasets
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 加载CIFAR-10测试数据集,将图像转换为Tensor格式
dataset = torchvision.datasets.CIFAR10("./data", train=False, download=True,
transform=torchvision.transforms.ToTensor())
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)
# 定义神经网络模型TY
class TY(nn.Module):
def __init__(self):
super(TY, self).__init__()
# 定义ReLU激活函数层
self.relu1 = ReLU()
# 定义Sigmoid激活函数层(当前未在forward中使用)
self.sigmod1 = Sigmoid()
def forward(self, input):
# 前向传播过程,对输入数据应用ReLU激活函数
output = self.relu1(input)
return output
# 实例化模型
ty = TY()
# 创建TensorBoard写入器,用于可视化数据
writer = SummaryWriter("./logs_relu")
# 初始化步数计数器
step = 0
# 遍历数据加载器中的每个批次
for data in dataloader:
# 获取图像数据和对应的标签
imgs, target = data
# 向TensorBoard添加原始输入图像
writer.add_images("input", imgs, step)
# 将图像数据输入模型,得到经过ReLU处理后的输出
output = ty(imgs)
# 向TensorBoard添加处理后的输出图像
writer.add_images("output", output, step)
# 步数计数器递增
step += 1
# 关闭TensorBoard写入器,释放资源
writer.close()

ReLU处理图像,效果不是很明显

Sigmoid激活函数

python
from os import close
import torch
import torchvision.datasets
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10("./data",train=False,download=True,
transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset,batch_size=64)
class TY(nn.Module):
def __init__(self):
super(TY,self).__init__()
self.relu1=ReLU()
self.sigmoid1 = Sigmoid()
def forward(self,input):
output = self.sigmoid1(input)
return output
ty = TY()
writer = SummaryWriter("./logs_relu")
step = 0
for data in dataloader:
imgs,target=data
writer.add_images("input",imgs,step)
output = ty(imgs)
writer.add_images("output",output,step)
step+=1
writer.close()
