PyTorch-----torch.nn.Softmax()函数

Softmax原理

Softmax 函数是一种常用的激活函数,通常用于多分类问题中。它将一个含有多个实数值的向量(通常称为 logits)转换成一个概率分布,使得每个元素都在 (0, 1) 区间内,并且所有元素的和为 1。

假设我们有一个实数值向量 z,其中 z = [z1, z2, ..., zn],其中 zi 是向量 z 的第 i 个元素。Softmax 函数将向量 z 转换为一个概率分布向量 p = [p1, p2, ..., pn],其中 pi 表示类别 i 的概率。

Softmax 函数的定义如下:

其中,zi 是 logits 向量 z 的第 i 个元素,n 是 logits 向量 z 的长度(即类别的数量),e 是自然对数的底(约等于 2.71828)。

Softmax 函数的计算过程如下:

  1. 对 logits 向量 z 中的每个元素进行指数化(即计算 e 的 z 次方)。
  2. 计算所有指数化的值的和(即分母部分)。
  3. 将每个指数化的值除以总和,得到归一化后的概率值。

Softmax 函数的一个关键特性是它的输出是一个概率分布,即所有输出值的和为 1,因此可以用于表示多个互斥的类别的概率。

在神经网络中,Softmax 函数通常作为输出层的激活函数使用,用于将网络的最后一层输出转换为概率分布,以便进行多分类任务的训练和预测。

softmax应用

torch.nn.Softmax 是 PyTorch 中的一个类,用于计算 softmax 函数。softmax 函数常用于多分类问题中,将一个具有任意实数值的向量转换为一个概率分布,使得每个元素都在 (0, 1) 之间,并且所有元素的和为 1。

在 PyTorch 中,torch.nn.Softmax 可以作为一个层(Layer)添加到神经网络模型中,也可以作为一个函数使用。它的语法如下:

python 复制代码
torch.nn.Softmax(dim=None)
  • dim(可选):指定 softmax 函数计算的维度。默认值为 -1,表示最后一个维度。

torch.nn.Softmax 类初始化后可以调用其 forward 方法来计算 softmax 函数。另外,你也可以直接使用 torch.softmax() 函数来计算 softmax。

下面是使用 torch.nn.Softmax 类的一个示例:

python 复制代码
import torch
import torch.nn as nn

# 创建一个 3x4 的输入张量
input_tensor = torch.randn(3, 4)

# 创建 Softmax 层
softmax_layer = nn.Softmax(dim=1)

# 对输入张量应用 Softmax 层
output_tensor = softmax_layer(input_tensor)

print(output_tensor)

这里,我们首先创建了一个 3x4 的输入张量 input_tensor,然后创建了一个 softmax 层,并将其应用于输入张量。最终得到的 output_tensor 是一个概率分布,其中每一行的元素都在 (0, 1) 之间,并且每一行的元素之和为 1。

你也可以使用 torch.softmax() 函数直接计算 softmax,示例如下:

python 复制代码
output_tensor = torch.softmax(input_tensor, dim=1)

这与使用 softmax 层的结果是相同的。

相关推荐
ZH15455891312 分钟前
Flutter for OpenHarmony Python学习助手实战:面向对象编程实战的实现
python·学习·flutter
玄同7653 分钟前
SQLite + LLM:大模型应用落地的轻量级数据存储方案
jvm·数据库·人工智能·python·语言模型·sqlite·知识图谱
User_芊芊君子8 分钟前
CANN010:PyASC Python编程接口—简化AI算子开发的Python框架
开发语言·人工智能·python
白日做梦Q18 分钟前
Anchor-free检测器全解析:CenterNet vs FCOS
python·深度学习·神经网络·目标检测·机器学习
喵手32 分钟前
Python爬虫实战:公共自行车站点智能采集系统 - 从零构建生产级爬虫的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集公共自行车站点·公共自行车站点智能采集系统·采集公共自行车站点导出csv
喵手40 分钟前
Python爬虫实战:地图 POI + 行政区反查实战 - 商圈热力数据准备完整方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·地区poi·行政区反查·商圈热力数据采集
熊猫_豆豆1 小时前
YOLOP车道检测
人工智能·python·算法
nimadan121 小时前
**热门短剧小说扫榜工具2025推荐,精准捕捉爆款趋势与流量
人工智能·python
默默前行的虫虫1 小时前
MQTT.fx实际操作
python
YMWM_1 小时前
python3继承使用
开发语言·python