PyTorch-----torch.nn.Softmax()函数

Softmax原理

Softmax 函数是一种常用的激活函数,通常用于多分类问题中。它将一个含有多个实数值的向量(通常称为 logits)转换成一个概率分布,使得每个元素都在 (0, 1) 区间内,并且所有元素的和为 1。

假设我们有一个实数值向量 z,其中 z = [z1, z2, ..., zn],其中 zi 是向量 z 的第 i 个元素。Softmax 函数将向量 z 转换为一个概率分布向量 p = [p1, p2, ..., pn],其中 pi 表示类别 i 的概率。

Softmax 函数的定义如下:

其中,zi 是 logits 向量 z 的第 i 个元素,n 是 logits 向量 z 的长度(即类别的数量),e 是自然对数的底(约等于 2.71828)。

Softmax 函数的计算过程如下:

  1. 对 logits 向量 z 中的每个元素进行指数化(即计算 e 的 z 次方)。
  2. 计算所有指数化的值的和(即分母部分)。
  3. 将每个指数化的值除以总和,得到归一化后的概率值。

Softmax 函数的一个关键特性是它的输出是一个概率分布,即所有输出值的和为 1,因此可以用于表示多个互斥的类别的概率。

在神经网络中,Softmax 函数通常作为输出层的激活函数使用,用于将网络的最后一层输出转换为概率分布,以便进行多分类任务的训练和预测。

softmax应用

torch.nn.Softmax 是 PyTorch 中的一个类,用于计算 softmax 函数。softmax 函数常用于多分类问题中,将一个具有任意实数值的向量转换为一个概率分布,使得每个元素都在 (0, 1) 之间,并且所有元素的和为 1。

在 PyTorch 中,torch.nn.Softmax 可以作为一个层(Layer)添加到神经网络模型中,也可以作为一个函数使用。它的语法如下:

python 复制代码
torch.nn.Softmax(dim=None)
  • dim(可选):指定 softmax 函数计算的维度。默认值为 -1,表示最后一个维度。

torch.nn.Softmax 类初始化后可以调用其 forward 方法来计算 softmax 函数。另外,你也可以直接使用 torch.softmax() 函数来计算 softmax。

下面是使用 torch.nn.Softmax 类的一个示例:

python 复制代码
import torch
import torch.nn as nn

# 创建一个 3x4 的输入张量
input_tensor = torch.randn(3, 4)

# 创建 Softmax 层
softmax_layer = nn.Softmax(dim=1)

# 对输入张量应用 Softmax 层
output_tensor = softmax_layer(input_tensor)

print(output_tensor)

这里,我们首先创建了一个 3x4 的输入张量 input_tensor,然后创建了一个 softmax 层,并将其应用于输入张量。最终得到的 output_tensor 是一个概率分布,其中每一行的元素都在 (0, 1) 之间,并且每一行的元素之和为 1。

你也可以使用 torch.softmax() 函数直接计算 softmax,示例如下:

python 复制代码
output_tensor = torch.softmax(input_tensor, dim=1)

这与使用 softmax 层的结果是相同的。

相关推荐
小唐C++32 分钟前
C++小病毒-1.0勒索
开发语言·c++·vscode·python·算法·c#·编辑器
北 染 星 辰1 小时前
Python网络自动化运维---用户交互模块
开发语言·python·自动化
codists1 小时前
《CPython Internals》阅读笔记:p336-p352
python
Мартин.1 小时前
[Meachines] [Easy] GoodGames SQLI+Flask SSTI+Docker逃逸权限提升
python·docker·flask
日日行不惧千万里1 小时前
如何用YOLOv8训练一个识别安全帽的模型?
python·yolo
LuiChun2 小时前
Flutter接django后台文件通道
python·flutter·django
阿俊仔(摸鱼版)3 小时前
Python 常用运维模块之Shutil 模块
linux·服务器·python·自动化·云服务器
MarsBighead3 小时前
(二)PosrgreSQL: Python3 连接Pgvector出错排查
python·postgresql·向量数据库·pgvector
深蓝海拓3 小时前
Pyside6(PyQT5)中的QTableView与QSqlQueryModel、QSqlTableModel的联合使用
数据库·python·qt·pyqt
无须logic ᭄3 小时前
CrypTen项目实践
python·机器学习·密码学·同态加密