主要区别总结
函数 | 输入要求 | 输出维度 | 支持广播 | 使用场景 |
---|---|---|---|---|
torch.matmul |
灵活 | 灵活 | ✅ | 通用矩阵乘法 |
torch.mm |
两个2D张量 | 2D | ❌ | 严格矩阵乘法 |
torch.mv |
2D矩阵 × 1D向量 | 1D | ❌ | 矩阵向量乘法 |
推荐用法
-
大多数情况 :使用
torch.matmul
或@
运算符(A @ B
) -
性能关键且确定维度 :使用特定函数(
mm
/mv
) -
批量运算 :必须使用
torch.matmul
python
# 现代PyTorch中推荐使用 @ 运算符
result = A @ B # 等同于 torch.matmul(A, B)
torch.matmul
是最通用的选择,而 mm
和 mv
或dot在你知道确切维度时可以提供更清晰的代码意图。
矩阵高级索引
y_hat[[0, 1], y]
是 PyTorch 中的**高级索引(advanced indexing)**操作
代码解析:
python
y = torch.tensor([0, 2]) # 形状: (2,)
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]) # 形状: (2, 3)
result = y_hat[[0, 1], y] # 结果: tensor([0.1000, 0.5000])
索引操作详解:
这个操作相当于:
-
取
y_hat[0, y[0]]
→y_hat[0, 0]
→0.1
-
取
y_hat[1, y[1]]
→y_hat[1, 2]
→0.5
逐步分解:
python
# 行索引: [0, 1] 列索引: [0, 2]
# 对应位置配对:
# 第1对: 行0, 列0 → 值0.1
# 第2对: 行1, 列2 → 值0.5
关于向量类型:
y_hat
是矩阵(2D张量):
python
print(y_hat.shape) # torch.Size([2, 3])
# 这是一个 2×3 的矩阵,不是行向量也不是列向量
y
是1D张量:
python
print(y.shape) # torch.Size([2])
# 这是一个包含2个元素的一维张量
结果是1D张量:
python
print(result.shape) # torch.Size([2])
# 结果是一个包含2个元素的一维张量
这种索引的实用场景:
这在机器学习中很常见,特别是在计算交叉熵损失时:
python
# 假设:
# y_hat 是预测的概率分布 (batch_size, num_classes)
# y 是真实标签 (batch_size,)
# 这种索引用于获取每个样本对应真实标签的预测概率
predicted_probs = y_hat[range(len(y)), y]
# 这在交叉熵损失计算中很有用
总结:
-
y_hat[[0, 1], y]
是配对索引操作 -
返回的是每个
(行, 列)
对对应的元素 -
y_hat
是2D矩阵 ,y
是1D向量 ,结果是1D向量 -
这种操作在机器学习中常用于根据真实标签索引预测概率
最大值的索引位置
这是一个在机器学习和深度学习中使用非常频繁的代码片段。
基本含义
y_hat.argmax(axis=1)
表示:在第二个维度(axis=1)上找出最大值的索引位置。
具体解释
在分类问题中的典型用法
假设 y_hat
是一个预测概率矩阵:
-
每一行代表一个样本
-
每一列代表一个类别的预测概率
python
import numpy as np
# 示例:3个样本,4个类别的预测概率
y_hat = np.array([
[0.1, 0.2, 0.6, 0.1], # 样本1:第3个类别概率最高(0.6)
[0.7, 0.1, 0.1, 0.1], # 样本2:第1个类别概率最高(0.7)
[0.05, 0.05, 0.1, 0.8] # 样本3:第4个类别概率最高(0.8)
])
predictions = y_hat.argmax(axis=1)
print(predictions) # 输出:[2, 0, 3]
维度说明
-
axis=0
:沿着行方向(垂直) -
axis=1
:沿着列方向(水平)
对于二维数组:
python
[[x00, x01, x02], ← axis=1(列方向)
[x10, x11, x12],
[x20, x21, x22]]
↑
axis=0(行方向)
实际应用场景
python
# 在神经网络分类中
import torch
# 模型输出(批量大小=4,类别数=3)
outputs = torch.tensor([
[1.2, 0.5, -0.3],
[0.1, 2.1, 0.8],
[-0.5, 0.3, 1.7],
[0.9, 0.6, 0.4]
])
# 获取预测类别
predicted_classes = outputs.argmax(dim=1) # PyTorch中用dim
print(predicted_classes) # 输出:tensor([0, 1, 2, 0])
总结
y_hat.argmax(axis=1)
的主要作用是:
-
将概率分布转换为具体的类别预测
-
找出每个样本最可能的类别
-
常用于计算准确率和模型评估
这是在分类任务中从模型输出获取最终预测结果的常用方法。