PyTorch 中的激活函数

深入理解 PyTorch 中的激活函数

激活函数是神经网络中不可或缺的组成部分,它们引入了非线性,使得神经网络能够学习复杂的模式和特征。在 PyTorch 中,torch.nn 模块提供了多种常用的激活函数,本文将对这些激活函数进行总结和介绍。


1. 什么是激活函数?

激活函数的主要作用是将输入信号映射到输出信号,并引入非线性特性。没有激活函数的神经网络本质上是线性变换的堆叠,无法处理复杂的非线性问题。

激活函数的主要特点包括:

  • 非线性:允许网络学习复杂的模式。
  • 可微性:支持反向传播算法。
  • 数值稳定性:避免梯度爆炸或梯度消失问题。

2. PyTorch 中的激活函数分类

PyTorch 提供了多种激活函数,以下是常见的分类:

2.1 基本激活函数

  • ReLU (Rectified Linear Unit) :

    • 定义:ReLU(x) = max(0, x)

    • 特点:简单高效,解决了梯度消失问题。

    • 示例:

python 复制代码
    import torch.nn as nn

    relu = nn.ReLU()
  • Sigmoid:

    • 定义:Sigmoid(x) = 1 / (1 + exp(-x))

    • 特点:将输出映射到 (0, 1),适合二分类问题。

    • 示例:

      sigmoid = nn.Sigmoid()

  • Tanh (Hyperbolic Tangent) :

    • 定义:Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))

    • 特点:将输出映射到 (-1, 1),比 Sigmoid 更适合处理零均值数据。

    • 示例:

    tanh = nn.Tanh()

2.2 改进的 ReLU 变体

  • LeakyReLU:

    • 定义:LeakyReLU(x) = x if x > 0 else alpha * x

    • 特点:允许负值通过,缓解 ReLU 的"死亡神经元"问题。

    • 示例:

      leaky_relu = nn.LeakyReLU(negative_slope=0.01)

  • PReLU (Parametric ReLU) :

    • 定义:类似 LeakyReLU,但负斜率是可学习的参数。

    • 示例:

    prelu = nn.PReLU()

  • ReLU6:

    • 定义:ReLU6(x) = min(max(0, x), 6)

    • 特点:限制输出范围,适合移动设备上的量化网络。

    • 示例:

    relu6 = nn.ReLU6()

2.3 平滑激活函数

  • Softplus:

    • 定义:Softplus(x) = log(1 + exp(x))

    • 特点:平滑版的 ReLU。

    • 示例:

softplus = nn.Softplus()

  • SiLU (Swish) :

    • 定义:SiLU(x) = x * Sigmoid(x)

    • 特点:在深度学习中表现优异。

    • 示例:

silu = nn.SiLU()

  • Mish:

    • 定义:Mish(x) = x * Tanh(Softplus(x))

    • 特点:自正则化激活函数,适合深层网络。

    • 示例:

mish = nn.Mish()

2.4 归一化激活函数

  • Softmax:

    • 定义:Softmax(x_i) = exp(x_i) / sum(exp(x_j))

    • 特点:将输出归一化为概率分布,常用于多分类问题。

    • 示例:

      softmax = nn.Softmax(dim=1)

  • LogSoftmax:

    • 定义:LogSoftmax(x) = log(Softmax(x))

    • 特点:结合负对数似然损失,数值更稳定。

    • 示例:

      log_softmax = nn.LogSoftmax(dim=1)


3. 激活函数的选择

选择激活函数时需要考虑以下因素:


4. 示例代码

以下是一个简单的示例,展示如何在 PyTorch 中使用激活函数:

python 复制代码
import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

# 创建模型并测试
model = SimpleNN()
input_data = torch.randn(5, 10)
output = model(input_data)
print(output)

5. 总结

激活函数是神经网络的核心组件,它们赋予网络非线性能力,使其能够处理复杂的任务。在 PyTorch 中,提供了多种激活函数以满足不同的需求。选择合适的激活函数可以显著提升模型的性能。

希望本文能帮助你更好地理解和使用 PyTorch 中的激活函数

6.补充用法

1. Threshold

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> y = { x , if x > threshold value , otherwise y = \begin{cases} x, & \text{if } x > \text{threshold} \ \text{value}, & \text{otherwise} \end{cases} </math>y={x,if x>threshold value,otherwise

代码示例

python 复制代码
import torch

import torch.nn as nn

m = nn.Threshold(0.1, 20)

input = torch.tensor([-0.5, 0.2, 0.8])

output = m(input)

print(output)  # 输出:[20.0, 0.2, 0.8]

2. ReLU

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ReLU ( x ) = max ⁡ ( 0 , x ) \text{ReLU}(x) = \max(0, x) </math>ReLU(x)=max(0,x)

代码示例

python 复制代码
relu = nn.ReLU()

input = torch.tensor([-1.0, 0.0, 1.0])

output = relu(input)

print(output)  # 输出:[0.0, 0.0, 1.0]

3. RReLU

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> RReLU ( x ) = { x , if x ≥ 0 a ⋅ x , otherwise \text{RReLU}(x) = \begin{cases} x, & \text{if } x \geq 0 \ a \cdot x, & \text{otherwise} \end{cases} </math>RReLU(x)={x,if x≥0 a⋅x,otherwise 其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> a a </math>a 是从均匀分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> U ( lower , upper ) \mathcal{U}(\text{lower}, \text{upper}) </math>U(lower,upper) 中随机采样的值。

代码示例

python 复制代码
rrelu = nn.RReLU(lower=0.1, upper=0.3)

input = torch.tensor([-1.0, 0.0, 1.0])

output = rrelu(input)

print(output)

4. Hardtanh

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> HardTanh ( x ) = { max_val , if x > max_val min_val , if x < min_val x , otherwise \text{HardTanh}(x) = \begin{cases} \text{max\_val}, & \text{if } x > \text{max\_val} \\ \text{min\_val}, & \text{if } x < \text{min\_val} \\ x, & \text{otherwise} \end{cases} </math>HardTanh(x)=⎩ ⎨ ⎧max_val,min_val,x,if x>max_valif x<min_valotherwise

代码示例

python 复制代码
hardtanh = nn.Hardtanh(min_val=-1.0, max_val=1.0)

input = torch.tensor([-2.0, 0.0, 2.0])

output = hardtanh(input)

print(output)  # 输出:[-1.0, 0.0, 1.0]

5. ReLU6

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ReLU6 ( x ) = min ⁡ ( max ⁡ ( 0 , x ) , 6 ) \text{ReLU6}(x) = \min(\max(0, x), 6) </math>ReLU6(x)=min(max(0,x),6)

代码示例

python 复制代码
relu6 = nn.ReLU6()

input = torch.tensor([-1.0, 3.0, 7.0])

output = relu6(input)

print(output)  # 输出:[0.0, 3.0, 6.0]

6. Sigmoid

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Sigmoid ( x ) = 1 1 + exp ⁡ ( − x ) \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)} </math>Sigmoid(x)=1+exp(−x)1

代码示例

python 复制代码
sigmoid = nn.Sigmoid()

input = torch.tensor([-1.0, 0.0, 1.0])

output = sigmoid(input)

print(output)  # 输出:[0.2689, 0.5, 0.7311]

7. Hardsigmoid

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Hardsigmoid ( x ) = { 0 , if x ≤ − 3 1 , if x ≥ 3 x 6 + 1 2 , otherwise \text{Hardsigmoid}(x) = \begin{cases} 0, & \text{if } x \leq -3 \ 1, & \text{if } x \geq 3 \ \frac{x}{6} + \frac{1}{2}, & \text{otherwise} \end{cases} </math>Hardsigmoid(x)={0,if x≤−3 1,if x≥3 6x+21,otherwise

代码示例

python 复制代码
hardsigmoid = nn.Hardsigmoid()

input = torch.tensor([-4.0, 0.0, 4.0])

output = hardsigmoid(input)

print(output)  # 输出:[0.0, 0.5, 1.0]

8. Tanh

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Tanh ( x ) = exp ⁡ ( x ) − exp ⁡ ( − x ) exp ⁡ ( x ) + exp ⁡ ( − x ) \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)} </math>Tanh(x)=exp(x)+exp(−x)exp(x)−exp(−x)

代码示例

python 复制代码
tanh = nn.Tanh()

input = torch.tensor([-1.0, 0.0, 1.0])

output = tanh(input)

print(output)  # 输出:[-0.7616, 0.0, 0.7616]

9. SiLU (Swish)

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> SiLU ( x ) = x ⋅ Sigmoid ( x ) \text{SiLU}(x) = x \cdot \text{Sigmoid}(x) </math>SiLU(x)=x⋅Sigmoid(x)

代码示例

python 复制代码
silu = nn.SiLU()

input = torch.tensor([-1.0, 0.0, 1.0])

output = silu(input)

print(output)

10. Mish

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Mish ( x ) = x ⋅ tanh ⁡ ( Softplus ( x ) ) \text{Mish}(x) = x \cdot \tanh(\text{Softplus}(x)) </math>Mish(x)=x⋅tanh(Softplus(x))

代码示例

python 复制代码
mish = nn.Mish()

input = torch.tensor([-1.0, 0.0, 1.0])

output = mish(input)

print(output)

11. Hardswish

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Hardswish ( x ) = { 0 , if x ≤ − 3 x , if x ≥ 3 x ⋅ x + 3 6 , otherwise \text{Hardswish}(x) = \begin{cases} 0, & \text{if } x \leq -3 \ x, & \text{if } x \geq 3 \ x \cdot \frac{x + 3}{6}, & \text{otherwise} \end{cases} </math>Hardswish(x)={0,if x≤−3 x,if x≥3 x⋅6x+3,otherwise

代码示例

python 复制代码
hardswish = nn.Hardswish()

input = torch.tensor([-4.0, 0.0, 4.0])

output = hardswish(input)

print(output)

12. ELU

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ELU ( x ) = { x , if x > 0 α ⋅ ( exp ⁡ ( x ) − 1 ) , if x ≤ 0 \text{ELU}(x) = \begin{cases} x, & \text{if } x > 0 \ \alpha \cdot (\exp(x) - 1), & \text{if } x \leq 0 \end{cases} </math>ELU(x)={x,if x>0 α⋅(exp(x)−1),if x≤0

代码示例

python 复制代码
elu = nn.ELU(alpha=1.0)

input = torch.tensor([-1.0, 0.0, 1.0])

output = elu(input)

print(output)

13. Softmax

公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Softmax ( x i ) = exp ⁡ ( x i ) ∑ j exp ⁡ ( x j ) \text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} </math>Softmax(xi)=∑jexp(xj)exp(xi)

代码示例

python 复制代码
softmax = nn.Softmax(dim=0)

input = torch.tensor([1.0, 2.0, 3.0])

output = softmax(input)

print(output)

相关推荐
视觉&物联智能4 分钟前
【杂谈】-2025年AI与网络安全六大趋势展望
人工智能·安全·web安全·网络安全·ai·agi·数字安全
梦想画家5 分钟前
PyTorch系列教程:基于LSTM构建情感分析模型
人工智能·pytorch·lstm
知舟不叙13 分钟前
机器学习——深入浅出理解朴素贝叶斯算法
人工智能·python·算法·机器学习
CodeJourney.28 分钟前
AI赋能办公:开启高效职场新时代
数据库·人工智能·算法
yscript35 分钟前
linux系统安装和激活conda
linux·运维·人工智能·python·深度学习·conda
szxinmai主板定制专家42 分钟前
基于FPGA的3U机箱轨道交通网络通讯板,对内和主控板、各类IO板通信,对外可进行RS485、CAN或MVB组网通信
大数据·人工智能·嵌入式硬件·fpga开发·边缘计算
海特伟业1 小时前
森林防火预警广播监控系统:以4G为纽带架构融合智能广播、远程监控、AI智能识别、告警提示、太阳能供电于一体的新一代森林防火预警系统
人工智能·架构
KangkangLoveNLP1 小时前
简单循环神经网络(RNN):原理、结构与应用
人工智能·pytorch·rnn·深度学习·神经网络·机器学习·transformer
未来之窗软件服务1 小时前
数字人本地部署之llama-本地推理模型
人工智能·llama·数字人
扶摇升1 小时前
大型语言模型(LLM)部署中的内存消耗计算
人工智能·语言模型·自然语言处理