PyTorch矩阵乘法函数区别解析与矩阵高级索引说明——《动手学深度学习》3.6.3、3.6.4和3.6.5 (P79)

主要区别总结

函数 输入要求 输出维度 支持广播 使用场景
torch.matmul 灵活 灵活 通用矩阵乘法
torch.mm 两个2D张量 2D 严格矩阵乘法
torch.mv 2D矩阵 × 1D向量 1D 矩阵向量乘法

推荐用法

  • 大多数情况 :使用 torch.matmul@ 运算符(A @ B

  • 性能关键且确定维度 :使用特定函数(mm/mv

  • 批量运算 :必须使用 torch.matmul

python 复制代码
# 现代PyTorch中推荐使用 @ 运算符
result = A @ B  # 等同于 torch.matmul(A, B)

torch.matmul 是最通用的选择,而 mmmv 或dot在你知道确切维度时可以提供更清晰的代码意图。

矩阵高级索引

y_hat[[0, 1], y] 是 PyTorch 中的**高级索引(advanced indexing)**操作

代码解析:

python 复制代码
y = torch.tensor([0, 2])           # 形状: (2,)
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])  # 形状: (2, 3)

result = y_hat[[0, 1], y]          # 结果: tensor([0.1000, 0.5000])

索引操作详解:

这个操作相当于:

  • y_hat[0, y[0]]y_hat[0, 0]0.1

  • y_hat[1, y[1]]y_hat[1, 2]0.5

逐步分解:

python 复制代码
# 行索引: [0, 1]   列索引: [0, 2]
# 对应位置配对:
#   第1对: 行0, 列0 → 值0.1
#   第2对: 行1, 列2 → 值0.5

关于向量类型:

y_hat 是矩阵(2D张量):

python 复制代码
print(y_hat.shape)  # torch.Size([2, 3])
# 这是一个 2×3 的矩阵,不是行向量也不是列向量

y 是1D张量:

python 复制代码
print(y.shape)      # torch.Size([2])
# 这是一个包含2个元素的一维张量

结果是1D张量:

python 复制代码
print(result.shape) # torch.Size([2])
# 结果是一个包含2个元素的一维张量

这种索引的实用场景:

这在机器学习中很常见,特别是在计算交叉熵损失时:

python 复制代码
# 假设:
# y_hat 是预测的概率分布 (batch_size, num_classes)
# y 是真实标签 (batch_size,)

# 这种索引用于获取每个样本对应真实标签的预测概率
predicted_probs = y_hat[range(len(y)), y]
# 这在交叉熵损失计算中很有用

总结:

  • y_hat[[0, 1], y]配对索引操作

  • 返回的是每个 (行, 列) 对对应的元素

  • y_hat2D矩阵y1D向量 ,结果是1D向量

  • 这种操作在机器学习中常用于根据真实标签索引预测概率

最大值的索引位置

这是一个在机器学习和深度学习中使用非常频繁的代码片段。

基本含义

y_hat.argmax(axis=1) 表示:在第二个维度(axis=1)上找出最大值的索引位置

具体解释

在分类问题中的典型用法

假设 y_hat 是一个预测概率矩阵:

  • 每一行代表一个样本

  • 每一列代表一个类别的预测概率

python 复制代码
import numpy as np

# 示例:3个样本,4个类别的预测概率
y_hat = np.array([
    [0.1, 0.2, 0.6, 0.1],  # 样本1:第3个类别概率最高(0.6)
    [0.7, 0.1, 0.1, 0.1],  # 样本2:第1个类别概率最高(0.7)
    [0.05, 0.05, 0.1, 0.8] # 样本3:第4个类别概率最高(0.8)
])

predictions = y_hat.argmax(axis=1)
print(predictions)  # 输出:[2, 0, 3]

维度说明

  • axis=0:沿着行方向(垂直)

  • axis=1:沿着列方向(水平)

对于二维数组:

python 复制代码
[[x00, x01, x02],  ← axis=1(列方向)
 [x10, x11, x12],
 [x20, x21, x22]]
 ↑
axis=0(行方向)

实际应用场景

python 复制代码
# 在神经网络分类中
import torch

# 模型输出(批量大小=4,类别数=3)
outputs = torch.tensor([
    [1.2, 0.5, -0.3],
    [0.1, 2.1, 0.8],
    [-0.5, 0.3, 1.7],
    [0.9, 0.6, 0.4]
])

# 获取预测类别
predicted_classes = outputs.argmax(dim=1)  # PyTorch中用dim
print(predicted_classes)  # 输出:tensor([0, 1, 2, 0])

总结

y_hat.argmax(axis=1) 的主要作用是:

  • 将概率分布转换为具体的类别预测

  • 找出每个样本最可能的类别

  • 常用于计算准确率和模型评估

这是在分类任务中从模型输出获取最终预测结果的常用方法。

相关推荐
Mos_x20 小时前
关于我们的python日记本
开发语言·python
掘金安东尼20 小时前
被权重出卖的“脏数据”:GPT-oss 揭开的 OpenAI 中文训练真相
人工智能
十重幻想20 小时前
reshape的共享内存
python
Orange_sparkle20 小时前
关于dify中http节点下载文件时,文件名不为原始文件名问题解决
人工智能·http·chatgpt·dify
Juchecar21 小时前
设计模式不是Java专属,其他语言的使用方法
java·python·设计模式
王哈哈^_^21 小时前
【完整源码+数据集】蓝莓数据集,yolo11蓝莓成熟度检测数据集 3023 张,蓝莓成熟度数据集,目标检测蓝莓识别算法系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·ai·视觉检测
盘古开天166621 小时前
通俗易懂:YOLO模型原理详解,从零开始理解目标检测
人工智能·yolo·目标检测
OpenBuild.xyz21 小时前
x402 生态系统:Web3 与 AI 融合的支付新基建
人工智能·web3
王哈哈^_^21 小时前
【完整源码+数据集】高空作业数据集,yolo高空作业检测数据集 2076 张,人员高空作业数据集,目标检测高空作业识别系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·目标跟踪·视觉检测
猿小猴子21 小时前
主流 AI IDE 之一的 Comate IDE 介绍
ide·人工智能·comate