PyTorch矩阵乘法函数区别解析与矩阵高级索引说明——《动手学深度学习》3.6.3、3.6.4和3.6.5 (P79)

主要区别总结

函数 输入要求 输出维度 支持广播 使用场景
torch.matmul 灵活 灵活 通用矩阵乘法
torch.mm 两个2D张量 2D 严格矩阵乘法
torch.mv 2D矩阵 × 1D向量 1D 矩阵向量乘法

推荐用法

  • 大多数情况 :使用 torch.matmul@ 运算符(A @ B

  • 性能关键且确定维度 :使用特定函数(mm/mv

  • 批量运算 :必须使用 torch.matmul

python 复制代码
# 现代PyTorch中推荐使用 @ 运算符
result = A @ B  # 等同于 torch.matmul(A, B)

torch.matmul 是最通用的选择,而 mmmv 或dot在你知道确切维度时可以提供更清晰的代码意图。

矩阵高级索引

y_hat[[0, 1], y] 是 PyTorch 中的**高级索引(advanced indexing)**操作

代码解析:

python 复制代码
y = torch.tensor([0, 2])           # 形状: (2,)
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])  # 形状: (2, 3)

result = y_hat[[0, 1], y]          # 结果: tensor([0.1000, 0.5000])

索引操作详解:

这个操作相当于:

  • y_hat[0, y[0]]y_hat[0, 0]0.1

  • y_hat[1, y[1]]y_hat[1, 2]0.5

逐步分解:

python 复制代码
# 行索引: [0, 1]   列索引: [0, 2]
# 对应位置配对:
#   第1对: 行0, 列0 → 值0.1
#   第2对: 行1, 列2 → 值0.5

关于向量类型:

y_hat 是矩阵(2D张量):

python 复制代码
print(y_hat.shape)  # torch.Size([2, 3])
# 这是一个 2×3 的矩阵,不是行向量也不是列向量

y 是1D张量:

python 复制代码
print(y.shape)      # torch.Size([2])
# 这是一个包含2个元素的一维张量

结果是1D张量:

python 复制代码
print(result.shape) # torch.Size([2])
# 结果是一个包含2个元素的一维张量

这种索引的实用场景:

这在机器学习中很常见,特别是在计算交叉熵损失时:

python 复制代码
# 假设:
# y_hat 是预测的概率分布 (batch_size, num_classes)
# y 是真实标签 (batch_size,)

# 这种索引用于获取每个样本对应真实标签的预测概率
predicted_probs = y_hat[range(len(y)), y]
# 这在交叉熵损失计算中很有用

总结:

  • y_hat[[0, 1], y]配对索引操作

  • 返回的是每个 (行, 列) 对对应的元素

  • y_hat2D矩阵y1D向量 ,结果是1D向量

  • 这种操作在机器学习中常用于根据真实标签索引预测概率

最大值的索引位置

这是一个在机器学习和深度学习中使用非常频繁的代码片段。

基本含义

y_hat.argmax(axis=1) 表示:在第二个维度(axis=1)上找出最大值的索引位置

具体解释

在分类问题中的典型用法

假设 y_hat 是一个预测概率矩阵:

  • 每一行代表一个样本

  • 每一列代表一个类别的预测概率

python 复制代码
import numpy as np

# 示例:3个样本,4个类别的预测概率
y_hat = np.array([
    [0.1, 0.2, 0.6, 0.1],  # 样本1:第3个类别概率最高(0.6)
    [0.7, 0.1, 0.1, 0.1],  # 样本2:第1个类别概率最高(0.7)
    [0.05, 0.05, 0.1, 0.8] # 样本3:第4个类别概率最高(0.8)
])

predictions = y_hat.argmax(axis=1)
print(predictions)  # 输出:[2, 0, 3]

维度说明

  • axis=0:沿着行方向(垂直)

  • axis=1:沿着列方向(水平)

对于二维数组:

python 复制代码
[[x00, x01, x02],  ← axis=1(列方向)
 [x10, x11, x12],
 [x20, x21, x22]]
 ↑
axis=0(行方向)

实际应用场景

python 复制代码
# 在神经网络分类中
import torch

# 模型输出(批量大小=4,类别数=3)
outputs = torch.tensor([
    [1.2, 0.5, -0.3],
    [0.1, 2.1, 0.8],
    [-0.5, 0.3, 1.7],
    [0.9, 0.6, 0.4]
])

# 获取预测类别
predicted_classes = outputs.argmax(dim=1)  # PyTorch中用dim
print(predicted_classes)  # 输出:tensor([0, 1, 2, 0])

总结

y_hat.argmax(axis=1) 的主要作用是:

  • 将概率分布转换为具体的类别预测

  • 找出每个样本最可能的类别

  • 常用于计算准确率和模型评估

这是在分类任务中从模型输出获取最终预测结果的常用方法。

相关推荐
CLubiy3 小时前
【研究生随笔】Pytorch中的线性代数(微分)
人工智能·pytorch·深度学习·线性代数·梯度·微分
美狐美颜SDK开放平台3 小时前
直播美颜SDK功能开发实录:自然妆感算法、人脸跟踪与AI美颜技术
人工智能·深度学习·算法·美颜sdk·直播美颜sdk·美颜api
郝学胜-神的一滴4 小时前
矩阵的奇异值分解(SVD)及其在计算机图形学中的应用
程序人生·线性代数·算法·矩阵·图形渲染
明朝百晓生5 小时前
强化学习【Monte Carlo Learning][MC Basic 算法]
人工智能·机器学习
AI云原生5 小时前
云原生系列Bug修复:Docker镜像无法启动的终极解决方案与排查思路
运维·服务器·python·docker·云原生·容器·bug
万粉变现经纪人6 小时前
如何解决 pip install -r requirements.txt 私有索引未设为 trusted-host 导致拒绝 问题
开发语言·python·scrapy·flask·beautifulsoup·pandas·pip
查士丁尼·绵8 小时前
笔试-九宫格三阶积幻方
python·九宫格·三阶积幻方
云知谷9 小时前
【C++基本功】C++适合做什么,哪些领域适合哪些领域不适合?
c语言·开发语言·c++·人工智能·团队开发
rit843249910 小时前
基于MATLAB实现基于距离的离群点检测算法
人工智能·算法·matlab