【机器学习】Softmax 函数

Softmax 是机器学习中常用的函数,广泛用于多分类问题的输出层。它可以将一组实数转换为一个概率分布,使得结果满足"非负"和"总和为1"的要求。在分类问题中,Softmax 让模型预测的每个类别概率都易于解释。本文将详细讲解 Softmax 的原理、公式推导、Numpy 实现及其在 Pytorch 中的实际应用。

Softmax 原理

给定一个类别集合 { y 1 , y 2 , ... , y n } \{y_1, y_2, \dots, y_n\} {y1,y2,...,yn},Softmax 将模型输出的每个数值(称为"得分"或"logits")转换为概率值。假设模型输出 z i z_i zi 为第 i i i 类的得分,Softmax 将所有的得分映射到概率空间,使每个得分转化为该类的预测概率。

Softmax 函数的公式为:
P ( y i ) = e z i ∑ j = 1 n e z j P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} P(yi)=∑j=1nezjezi

其中 z i z_i zi 表示模型为第 i i i 类输出的得分, n n n 是类别的数量。通过对指数值的归一化处理,Softmax 函数输出的概率满足:

  1. 所有概率值都为非负数;
  2. 概率总和为 1。

Softmax 计算中的数值稳定性

在计算中,Softmax 可能会因为指数运算导致数值溢出,为了减小这种风险,可以对每个 (z_i) 值减去一个常数 max ⁡ ( z ) \max(z) max(z):
P ( y i ) = e z i − max ⁡ ( z ) ∑ j = 1 n e z j − max ⁡ ( z ) P(y_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^n e^{z_j - \max(z)}} P(yi)=∑j=1nezj−max(z)ezi−max(z)

这种转换不会改变概率的分布,避免了指数函数产生的大数值溢出问题。

Numpy 实现 Softmax 函数

下面通过 Numpy 实现 Softmax,并进行数据可视化以更直观地理解 Softmax 对得分的转换过程。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

# 定义 Softmax 函数
def softmax(logits):
    """
    使用数值稳定性的 Softmax 函数实现

    参数:
    - logits: 模型输出得分向量(shape: (n,),表示 n 个类别的得分)

    返回:
    - probs: 转换后的概率向量,shape: (n,)
    """
    exp_shifted = np.exp(logits - np.max(logits))  # 减去 max(logits) 以确保数值稳定性
    probs = exp_shifted / np.sum(exp_shifted)  # 归一化为概率
    return probs

# 示例输入的分类得分
logits = np.array([2.0, 1.0, 0.1])

# 使用 Softmax 函数计算各类别的概率
probs = softmax(logits)

# 输出各类的预测概率
print("分类得分:", logits)
print("预测概率:", probs)

Softmax 输出可视化

我们可以用图像展示 Softmax 如何将得分转化为概率,假设输入的分类得分范围为 -2 到 4。

python 复制代码
# 生成模拟的分类得分范围
logit_range = np.linspace(-2, 4, 100)
all_probs = np.array([softmax([l, 1.0, 0.1]) for l in logit_range])

# 可视化不同类别的预测概率随得分变化的趋势
plt.plot(logit_range, all_probs[:, 0], label="类别 1")
plt.plot(logit_range, all_probs[:, 1], label="类别 2")
plt.plot(logit_range, all_probs[:, 2], label="类别 3")
plt.xlabel("得分 (logits)")
plt.ylabel("概率")
plt.title("Softmax 函数输出的概率分布")
plt.legend()
plt.show()

Softmax 损失函数:交叉熵损失

在多分类任务中,常用 交叉熵损失函数 来衡量模型预测概率分布与真实标签的匹配程度。对于单个样本,交叉熵损失定义为:
L = − ∑ i = 1 n y i ⋅ log ⁡ ( P ( y i ) ) L = -\sum_{i=1}^{n} y_i \cdot \log(P(y_i)) L=−i=1∑nyi⋅log(P(yi))

其中 (y_i) 是真实标签的 one-hot 编码,(P(y_i)) 是 Softmax 转换后的预测概率。

python 复制代码
# 定义交叉熵损失函数
def cross_entropy_loss(probs, y_true):
    """
    计算交叉熵损失
    
    参数:
    - probs: Softmax 预测概率 (shape: (n,))
    - y_true: 实际标签 (shape: (n,)),one-hot 编码

    返回:
    - loss: 交叉熵损失
    """
    loss = -np.sum(y_true * np.log(probs + 1e-10))  # 加1e-10防止 log(0)
    return loss

# 示例计算
y_true = np.array([1, 0, 0])  # 假设类别 1 为正确类别
loss = cross_entropy_loss(probs, y_true)
print("交叉熵损失:", loss)

在 PyTorch 中使用 Softmax

在 PyTorch 中,我们可以直接调用 torch.nn.functional.softmax 来实现 Softmax。此外,PyTorch 提供的 torch.nn.CrossEntropyLoss 函数在内部自动包含了 Softmax 和交叉熵的计算,无需显式计算。

python 复制代码
import torch
import torch.nn.functional as F

# 示例:在 PyTorch 中实现 Softmax 和交叉熵损失
logits_torch = torch.tensor([2.0, 1.0, 0.1])

# 使用 PyTorch 的 Softmax 函数
probs_torch = F.softmax(logits_torch, dim=0)
print("PyTorch 预测概率:", probs_torch.numpy())

# 使用交叉熵损失函数
y_true_index = torch.tensor([0])  # 假设第一个类别为正确类别
loss_fn = torch.nn.CrossEntropyLoss()
loss_torch = loss_fn(logits_torch.unsqueeze(0), y_true_index)
print("PyTorch 交叉熵损失:", loss_torch.item())

在 PyTorch 中,torch.nn.CrossEntropyLoss 在传入 logits 后自动应用 Softmax 和交叉熵计算,为多分类问题提供了便捷的计算方式。

总结

本文介绍了 Softmax 的原理、公式、Numpy 实现、可视化以及在 PyTorch 中的使用。Softmax 是将得分转化为概率分布的关键函数,尤其适用于多分类任务。我们还探讨了数值稳定性的处理以及交叉熵损失在多分类中的作用,理解并实现 Softmax 有助于构建更稳定且易解释的分类模型。

相关推荐
搏博38 分钟前
神经网络问题之二:梯度爆炸(Gradient Explosion)
人工智能·深度学习·神经网络
Chef_Chen1 小时前
从0开始学习机器学习--Day33--机器学习阶段总结
人工智能·学习·机器学习
搏博1 小时前
神经网络问题之:梯度不稳定
人工智能·深度学习·神经网络
Sxiaocai1 小时前
使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类
pytorch·深度学习·分类
databook1 小时前
『玩转Streamlit』--布局与容器组件
python·机器学习·数据分析
shansjqun1 小时前
教学内容全覆盖:航拍杂草检测与分类
人工智能·分类·数据挖掘
肖永威2 小时前
CentOS环境上离线安装python3及相关包
linux·运维·机器学习·centos
白光白光3 小时前
量子神经网络
人工智能·深度学习·神经网络
IT古董4 小时前
【人工智能】Python在机器学习与人工智能中的应用
开发语言·人工智能·python·机器学习
CV学术叫叫兽5 小时前
快速图像识别:落叶植物叶片分类
人工智能·分类·数据挖掘