【机器学习】Softmax 函数

Softmax 是机器学习中常用的函数,广泛用于多分类问题的输出层。它可以将一组实数转换为一个概率分布,使得结果满足"非负"和"总和为1"的要求。在分类问题中,Softmax 让模型预测的每个类别概率都易于解释。本文将详细讲解 Softmax 的原理、公式推导、Numpy 实现及其在 Pytorch 中的实际应用。

Softmax 原理

给定一个类别集合 { y 1 , y 2 , ... , y n } \{y_1, y_2, \dots, y_n\} {y1,y2,...,yn},Softmax 将模型输出的每个数值(称为"得分"或"logits")转换为概率值。假设模型输出 z i z_i zi 为第 i i i 类的得分,Softmax 将所有的得分映射到概率空间,使每个得分转化为该类的预测概率。

Softmax 函数的公式为:
P ( y i ) = e z i ∑ j = 1 n e z j P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} P(yi)=∑j=1nezjezi

其中 z i z_i zi 表示模型为第 i i i 类输出的得分, n n n 是类别的数量。通过对指数值的归一化处理,Softmax 函数输出的概率满足:

  1. 所有概率值都为非负数;
  2. 概率总和为 1。

Softmax 计算中的数值稳定性

在计算中,Softmax 可能会因为指数运算导致数值溢出,为了减小这种风险,可以对每个 (z_i) 值减去一个常数 max ⁡ ( z ) \max(z) max(z):
P ( y i ) = e z i − max ⁡ ( z ) ∑ j = 1 n e z j − max ⁡ ( z ) P(y_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^n e^{z_j - \max(z)}} P(yi)=∑j=1nezj−max(z)ezi−max(z)

这种转换不会改变概率的分布,避免了指数函数产生的大数值溢出问题。

Numpy 实现 Softmax 函数

下面通过 Numpy 实现 Softmax,并进行数据可视化以更直观地理解 Softmax 对得分的转换过程。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

# 定义 Softmax 函数
def softmax(logits):
    """
    使用数值稳定性的 Softmax 函数实现

    参数:
    - logits: 模型输出得分向量(shape: (n,),表示 n 个类别的得分)

    返回:
    - probs: 转换后的概率向量,shape: (n,)
    """
    exp_shifted = np.exp(logits - np.max(logits))  # 减去 max(logits) 以确保数值稳定性
    probs = exp_shifted / np.sum(exp_shifted)  # 归一化为概率
    return probs

# 示例输入的分类得分
logits = np.array([2.0, 1.0, 0.1])

# 使用 Softmax 函数计算各类别的概率
probs = softmax(logits)

# 输出各类的预测概率
print("分类得分:", logits)
print("预测概率:", probs)

Softmax 输出可视化

我们可以用图像展示 Softmax 如何将得分转化为概率,假设输入的分类得分范围为 -2 到 4。

python 复制代码
# 生成模拟的分类得分范围
logit_range = np.linspace(-2, 4, 100)
all_probs = np.array([softmax([l, 1.0, 0.1]) for l in logit_range])

# 可视化不同类别的预测概率随得分变化的趋势
plt.plot(logit_range, all_probs[:, 0], label="类别 1")
plt.plot(logit_range, all_probs[:, 1], label="类别 2")
plt.plot(logit_range, all_probs[:, 2], label="类别 3")
plt.xlabel("得分 (logits)")
plt.ylabel("概率")
plt.title("Softmax 函数输出的概率分布")
plt.legend()
plt.show()

Softmax 损失函数:交叉熵损失

在多分类任务中,常用 交叉熵损失函数 来衡量模型预测概率分布与真实标签的匹配程度。对于单个样本,交叉熵损失定义为:
L = − ∑ i = 1 n y i ⋅ log ⁡ ( P ( y i ) ) L = -\sum_{i=1}^{n} y_i \cdot \log(P(y_i)) L=−i=1∑nyi⋅log(P(yi))

其中 (y_i) 是真实标签的 one-hot 编码,(P(y_i)) 是 Softmax 转换后的预测概率。

python 复制代码
# 定义交叉熵损失函数
def cross_entropy_loss(probs, y_true):
    """
    计算交叉熵损失
    
    参数:
    - probs: Softmax 预测概率 (shape: (n,))
    - y_true: 实际标签 (shape: (n,)),one-hot 编码

    返回:
    - loss: 交叉熵损失
    """
    loss = -np.sum(y_true * np.log(probs + 1e-10))  # 加1e-10防止 log(0)
    return loss

# 示例计算
y_true = np.array([1, 0, 0])  # 假设类别 1 为正确类别
loss = cross_entropy_loss(probs, y_true)
print("交叉熵损失:", loss)

在 PyTorch 中使用 Softmax

在 PyTorch 中,我们可以直接调用 torch.nn.functional.softmax 来实现 Softmax。此外,PyTorch 提供的 torch.nn.CrossEntropyLoss 函数在内部自动包含了 Softmax 和交叉熵的计算,无需显式计算。

python 复制代码
import torch
import torch.nn.functional as F

# 示例:在 PyTorch 中实现 Softmax 和交叉熵损失
logits_torch = torch.tensor([2.0, 1.0, 0.1])

# 使用 PyTorch 的 Softmax 函数
probs_torch = F.softmax(logits_torch, dim=0)
print("PyTorch 预测概率:", probs_torch.numpy())

# 使用交叉熵损失函数
y_true_index = torch.tensor([0])  # 假设第一个类别为正确类别
loss_fn = torch.nn.CrossEntropyLoss()
loss_torch = loss_fn(logits_torch.unsqueeze(0), y_true_index)
print("PyTorch 交叉熵损失:", loss_torch.item())

在 PyTorch 中,torch.nn.CrossEntropyLoss 在传入 logits 后自动应用 Softmax 和交叉熵计算,为多分类问题提供了便捷的计算方式。

总结

本文介绍了 Softmax 的原理、公式、Numpy 实现、可视化以及在 PyTorch 中的使用。Softmax 是将得分转化为概率分布的关键函数,尤其适用于多分类任务。我们还探讨了数值稳定性的处理以及交叉熵损失在多分类中的作用,理解并实现 Softmax 有助于构建更稳定且易解释的分类模型。

相关推荐
有Li5 小时前
结合无监督表示学习与伪标签监督的自蒸馏方法,用于稀有疾病影像表型分类的分散感知失衡校正|文献速递-基于生成模型的数据增强与疾病监测应用
学习·分类·数据挖掘
STRANGEX-035 小时前
深度学习案例:带有一个隐藏层的平面数据分类
深度学习·平面·分类
YRr YRr6 小时前
深入解析最小二乘法:原理、应用与局限
算法·机器学习·最小二乘法
梭七y6 小时前
(自用)机器学习python代码相关笔记
笔记·机器学习·sklearn
T0uken7 小时前
【机器学习】逻辑回归
人工智能·机器学习·逻辑回归
weixin_307779137 小时前
研究深度神经网络优化稳定性,证明在一定条件下梯度下降和随机梯度下降方法能有效控制损失函数
深度学习·机器学习·dnn
向向20248 小时前
TIFS-2024 FIRe2:细粒度表示和重组在换衣行人重识别中的应用
人工智能·机器学习·支持向量机
墨@#≯8 小时前
回归与分类中的过拟合问题探讨与解决
机器学习·分类·回归·正则化·过拟合