深入理解 PyTorch 中的激活函数
激活函数是神经网络中不可或缺的组成部分,它们引入了非线性,使得神经网络能够学习复杂的模式和特征。在 PyTorch 中,torch.nn 模块提供了多种常用的激活函数,本文将对这些激活函数进行总结和介绍。
1. 什么是激活函数?
激活函数的主要作用是将输入信号映射到输出信号,并引入非线性特性。没有激活函数的神经网络本质上是线性变换的堆叠,无法处理复杂的非线性问题。
激活函数的主要特点包括:
- 非线性:允许网络学习复杂的模式。
- 可微性:支持反向传播算法。
- 数值稳定性:避免梯度爆炸或梯度消失问题。
2. PyTorch 中的激活函数分类
PyTorch 提供了多种激活函数,以下是常见的分类:
2.1 基本激活函数
-
ReLU (Rectified Linear Unit) :
-
定义:
ReLU(x) = max(0, x)
-
特点:简单高效,解决了梯度消失问题。
-
示例:
-
python
import torch.nn as nn
relu = nn.ReLU()
-
Sigmoid:
-
定义:
Sigmoid(x) = 1 / (1 + exp(-x))
-
特点:将输出映射到 (0, 1),适合二分类问题。
-
示例:
sigmoid = nn.Sigmoid()
-
-
Tanh (Hyperbolic Tangent) :
-
定义:
Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
-
特点:将输出映射到 (-1, 1),比 Sigmoid 更适合处理零均值数据。
-
示例:
tanh = nn.Tanh()
-
2.2 改进的 ReLU 变体
-
LeakyReLU:
-
定义:
LeakyReLU(x) = x if x > 0 else alpha * x
-
特点:允许负值通过,缓解 ReLU 的"死亡神经元"问题。
-
示例:
leaky_relu = nn.LeakyReLU(negative_slope=0.01)
-
-
PReLU (Parametric ReLU) :
-
定义:类似 LeakyReLU,但负斜率是可学习的参数。
-
示例:
prelu = nn.PReLU()
-
-
ReLU6:
-
定义:
ReLU6(x) = min(max(0, x), 6)
-
特点:限制输出范围,适合移动设备上的量化网络。
-
示例:
relu6 = nn.ReLU6()
-
2.3 平滑激活函数
-
Softplus:
-
定义:
Softplus(x) = log(1 + exp(x))
-
特点:平滑版的 ReLU。
-
示例:
-
softplus = nn.Softplus()
-
SiLU (Swish) :
-
定义:
SiLU(x) = x * Sigmoid(x)
-
特点:在深度学习中表现优异。
-
示例:
-
silu = nn.SiLU()
-
Mish:
-
定义:
Mish(x) = x * Tanh(Softplus(x))
-
特点:自正则化激活函数,适合深层网络。
-
示例:
-
mish = nn.Mish()
2.4 归一化激活函数
-
Softmax:
-
定义:
Softmax(x_i) = exp(x_i) / sum(exp(x_j))
-
特点:将输出归一化为概率分布,常用于多分类问题。
-
示例:
softmax = nn.Softmax(dim=1)
-
-
LogSoftmax:
-
定义:
LogSoftmax(x) = log(Softmax(x))
-
特点:结合负对数似然损失,数值更稳定。
-
示例:
log_softmax = nn.LogSoftmax(dim=1)
-
3. 激活函数的选择
选择激活函数时需要考虑以下因素:
-
任务类型:
-
网络深度:
-
数值稳定性:
4. 示例代码
以下是一个简单的示例,展示如何在 PyTorch 中使用激活函数:
python
import torch
import torch.nn as nn
# 定义一个简单的神经网络
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(20, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
# 创建模型并测试
model = SimpleNN()
input_data = torch.randn(5, 10)
output = model(input_data)
print(output)
5. 总结
激活函数是神经网络的核心组件,它们赋予网络非线性能力,使其能够处理复杂的任务。在 PyTorch 中,提供了多种激活函数以满足不同的需求。选择合适的激活函数可以显著提升模型的性能。
希望本文能帮助你更好地理解和使用 PyTorch 中的激活函数
6.补充用法
1. Threshold
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> y = { x , if x > threshold value , otherwise y = \begin{cases} x, & \text{if } x > \text{threshold} \ \text{value}, & \text{otherwise} \end{cases} </math>y={x,if x>threshold value,otherwise
代码示例:
python
import torch
import torch.nn as nn
m = nn.Threshold(0.1, 20)
input = torch.tensor([-0.5, 0.2, 0.8])
output = m(input)
print(output) # 输出:[20.0, 0.2, 0.8]
2. ReLU
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ReLU ( x ) = max ( 0 , x ) \text{ReLU}(x) = \max(0, x) </math>ReLU(x)=max(0,x)
代码示例:
python
relu = nn.ReLU()
input = torch.tensor([-1.0, 0.0, 1.0])
output = relu(input)
print(output) # 输出:[0.0, 0.0, 1.0]
3. RReLU
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> RReLU ( x ) = { x , if x ≥ 0 a ⋅ x , otherwise \text{RReLU}(x) = \begin{cases} x, & \text{if } x \geq 0 \ a \cdot x, & \text{otherwise} \end{cases} </math>RReLU(x)={x,if x≥0 a⋅x,otherwise 其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> a a </math>a 是从均匀分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> U ( lower , upper ) \mathcal{U}(\text{lower}, \text{upper}) </math>U(lower,upper) 中随机采样的值。
代码示例:
python
rrelu = nn.RReLU(lower=0.1, upper=0.3)
input = torch.tensor([-1.0, 0.0, 1.0])
output = rrelu(input)
print(output)
4. Hardtanh
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> HardTanh ( x ) = { max_val , if x > max_val min_val , if x < min_val x , otherwise \text{HardTanh}(x) = \begin{cases} \text{max\_val}, & \text{if } x > \text{max\_val} \\ \text{min\_val}, & \text{if } x < \text{min\_val} \\ x, & \text{otherwise} \end{cases} </math>HardTanh(x)=⎩ ⎨ ⎧max_val,min_val,x,if x>max_valif x<min_valotherwise
代码示例:
python
hardtanh = nn.Hardtanh(min_val=-1.0, max_val=1.0)
input = torch.tensor([-2.0, 0.0, 2.0])
output = hardtanh(input)
print(output) # 输出:[-1.0, 0.0, 1.0]
5. ReLU6
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ReLU6 ( x ) = min ( max ( 0 , x ) , 6 ) \text{ReLU6}(x) = \min(\max(0, x), 6) </math>ReLU6(x)=min(max(0,x),6)
代码示例:
python
relu6 = nn.ReLU6()
input = torch.tensor([-1.0, 3.0, 7.0])
output = relu6(input)
print(output) # 输出:[0.0, 3.0, 6.0]
6. Sigmoid
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Sigmoid ( x ) = 1 1 + exp ( − x ) \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)} </math>Sigmoid(x)=1+exp(−x)1
代码示例:
python
sigmoid = nn.Sigmoid()
input = torch.tensor([-1.0, 0.0, 1.0])
output = sigmoid(input)
print(output) # 输出:[0.2689, 0.5, 0.7311]
7. Hardsigmoid
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Hardsigmoid ( x ) = { 0 , if x ≤ − 3 1 , if x ≥ 3 x 6 + 1 2 , otherwise \text{Hardsigmoid}(x) = \begin{cases} 0, & \text{if } x \leq -3 \ 1, & \text{if } x \geq 3 \ \frac{x}{6} + \frac{1}{2}, & \text{otherwise} \end{cases} </math>Hardsigmoid(x)={0,if x≤−3 1,if x≥3 6x+21,otherwise
代码示例:
python
hardsigmoid = nn.Hardsigmoid()
input = torch.tensor([-4.0, 0.0, 4.0])
output = hardsigmoid(input)
print(output) # 输出:[0.0, 0.5, 1.0]
8. Tanh
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Tanh ( x ) = exp ( x ) − exp ( − x ) exp ( x ) + exp ( − x ) \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)} </math>Tanh(x)=exp(x)+exp(−x)exp(x)−exp(−x)
代码示例:
python
tanh = nn.Tanh()
input = torch.tensor([-1.0, 0.0, 1.0])
output = tanh(input)
print(output) # 输出:[-0.7616, 0.0, 0.7616]
9. SiLU (Swish)
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> SiLU ( x ) = x ⋅ Sigmoid ( x ) \text{SiLU}(x) = x \cdot \text{Sigmoid}(x) </math>SiLU(x)=x⋅Sigmoid(x)
代码示例:
python
silu = nn.SiLU()
input = torch.tensor([-1.0, 0.0, 1.0])
output = silu(input)
print(output)
10. Mish
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Mish ( x ) = x ⋅ tanh ( Softplus ( x ) ) \text{Mish}(x) = x \cdot \tanh(\text{Softplus}(x)) </math>Mish(x)=x⋅tanh(Softplus(x))
代码示例:
python
mish = nn.Mish()
input = torch.tensor([-1.0, 0.0, 1.0])
output = mish(input)
print(output)
11. Hardswish
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Hardswish ( x ) = { 0 , if x ≤ − 3 x , if x ≥ 3 x ⋅ x + 3 6 , otherwise \text{Hardswish}(x) = \begin{cases} 0, & \text{if } x \leq -3 \ x, & \text{if } x \geq 3 \ x \cdot \frac{x + 3}{6}, & \text{otherwise} \end{cases} </math>Hardswish(x)={0,if x≤−3 x,if x≥3 x⋅6x+3,otherwise
代码示例:
python
hardswish = nn.Hardswish()
input = torch.tensor([-4.0, 0.0, 4.0])
output = hardswish(input)
print(output)
12. ELU
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ELU ( x ) = { x , if x > 0 α ⋅ ( exp ( x ) − 1 ) , if x ≤ 0 \text{ELU}(x) = \begin{cases} x, & \text{if } x > 0 \ \alpha \cdot (\exp(x) - 1), & \text{if } x \leq 0 \end{cases} </math>ELU(x)={x,if x>0 α⋅(exp(x)−1),if x≤0
代码示例:
python
elu = nn.ELU(alpha=1.0)
input = torch.tensor([-1.0, 0.0, 1.0])
output = elu(input)
print(output)
13. Softmax
公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Softmax ( x i ) = exp ( x i ) ∑ j exp ( x j ) \text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} </math>Softmax(xi)=∑jexp(xj)exp(xi)
代码示例:
python
softmax = nn.Softmax(dim=0)
input = torch.tensor([1.0, 2.0, 3.0])
output = softmax(input)
print(output)