Tensor含义
Tensor(张量)可以看作是一个多维数组,它是标量、向量和矩阵向更高维度的扩展。
| 张量维度 | 数学等价物 | 实例说明(PyTorch创建示例) |
|---|---|---|
| 0维 | **标量 (Scalar)** | 单个数值,如损失值:tensor(3.1416) |
| 1维 | **向量 (Vector)** | 一维数组,如特征向量:tensor([1, 2, 3]) |
| 2维 | **矩阵 (Matrix)** | 二维数组,如全连接层权重:tensor([[1, 2], [3, 4]]) |
| 3维及以上 | 高阶张量 | 如RGB图像(3, 224, 224)、图像批次(16, 3, 224, 224) |
Tensor的关键属性
- 数据类型(dtype):指定张量中常见的数据类型,如torch.float32、torch.float64、torch.int64、torch.bool等
- 设备(device):表明张量当前存储在何处,是cpu还是cuda:0(GPU)等
- 形状(shape):一个元组,表示张量在每个维度上的大小。
- 是否需要梯度(requires_grad):一个布尔值,指示是否需要为张量计算梯度。
Tensor常见操作
- torch.cat(torsors, dim)
dim=0 表示拼接行,dim=1 拼接列
比如
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([[7, 8, 9], [10, 11, 12]])
torch.cat((A, B), dim=0)
"""
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
"""
torch.cat((A, B), dim=1)
"""
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
"""
参考资料: