pytorch小记(二):pytorch矩阵乘法:torch.cat(tensors, dim=0)
在 PyTorch 中,torch.cat()
是一种用于在指定维度上连接张量的操作。它能够将多个张量沿某个轴拼接成一个新的张量。
语法
python
torch.cat(tensors, dim=0)
tensors
:一个包含多个待拼接张量的列表或元组。这些张量在指定的dim
维度以外的所有维度上必须具有相同的形状。dim
:指定在哪个维度上进行拼接操作。
使用规则
- 在指定维度上,张量的形状可以不同(因为会拼接)。
- 在其他维度上,张量的形状必须相同。
示例 1:在第 0 维(行)拼接
python
x = torch.tensor([[1, 2],
[3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6],
[7, 8]]) # 形状 (2, 2)
result = torch.cat((x, y), dim=0) # 在第 0 维拼接
print(result)
输出:
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
- 原始张量
x
和y
在第 0 维上(行方向)拼接,因此新张量的形状为(4, 2)
。
示例 2:在第 1 维(列)拼接
python
x = torch.tensor([[1, 2],
[3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6],
[7, 8]]) # 形状 (2, 2)
result = torch.cat((x, y), dim=1) # 在第 1 维拼接
print(result)
输出:
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
- 原始张量
x
和y
在第 1 维上(列方向)拼接,因此新张量的形状为(2, 4)
。
示例 3:在高维张量上拼接
我们来创建两个高维张量 x
和 y
,并分别在不同维度(dim=0
, dim=1
, dim=2
)上使用 torch.cat
进行拼接,展示具体计算结果。
初始张量
python
x = torch.tensor([
[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]]
]) # 形状 (2, 2, 3)
y = torch.tensor([
[[13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24]]
]) # 形状 (2, 2, 3)
-
x
和y
是形状为(2, 2, 3)
的 3D 张量:x: [[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]] y: [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]
1. 在 dim=0
拼接
python
result_dim0 = torch.cat((x, y), dim=0)
print(result_dim0.shape) # torch.Size([4, 2, 3])
print(result_dim0)
拼接逻辑:
- 在第 0 维度(最外层)拼接,结果张量包含 4 个"块",每个"块"的形状仍然是
(2, 3)
。
结果:
result_dim0:
[[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[ 10, 11, 12]],
[[ 13, 14, 15],
[ 16, 17, 18]],
[[ 19, 20, 21],
[ 22, 23, 24]]]
2. 在 dim=1
拼接
python
result_dim1 = torch.cat((x, y), dim=1)
print(result_dim1.shape) # torch.Size([2, 4, 3])
print(result_dim1)
拼接逻辑:
- 在第 1 维度(每个"块"中的行)拼接,结果张量包含 2 个"块",每个"块"增加了 2 行,形状从
(2, 3)
变为(4, 3)
。
结果:
result_dim1:
[[[ 1, 2, 3],
[ 4, 5, 6],
[ 13, 14, 15],
[ 16, 17, 18]],
[[ 7, 8, 9],
[ 10, 11, 12],
[ 19, 20, 21],
[ 22, 23, 24]]]
3. 在 dim=2
拼接
python
result_dim2 = torch.cat((x, y), dim=2)
print(result_dim2.shape) # torch.Size([2, 2, 6])
print(result_dim2)
拼接逻辑:
- 在第 2 维度(每行中的列)拼接,结果张量包含 2 个"块",每个"块"有 2 行,但每行的列数增加了一倍,从 3 列变为 6 列。
结果:
result_dim2:
[[[ 1, 2, 3, 13, 14, 15],
[ 4, 5, 6, 16, 17, 18]],
[[ 7, 8, 9, 19, 20, 21],
[ 10, 11, 12, 22, 23, 24]]]
总结
dim 值 |
拼接维度 | 结果形状 | 拼接效果 |
---|---|---|---|
dim=0 |
最外层 | (4, 2, 3) |
增加块的数量(纵向堆叠) |
dim=1 |
每块的行数 | (2, 4, 3) |
增加每块的行数(横向堆叠行) |
dim=2 |
每行的列数 | (2, 2, 6) |
增加每行的列数(横向堆叠列) |
通过改变 dim
,torch.cat
可以在不同维度上灵活地拼接张量。
示例 4:拼接不同形状的张量(错误示范)
如果张量在非拼接维度上的形状不同,会抛出错误:
python
x = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6, 7]]) # 形状 (1, 3)
result = torch.cat((x, y), dim=0) # 抛出错误
错误信息:
RuntimeError: Sizes of tensors must match except in dimension 0. Got 2 and 3 in dimension 1
如果希望在行方向 dim=0
拼接,可以通过 补零
或 裁剪
等方式使列数一致。
补零torch.nn.functional.pad
:
python
import torch
import torch.nn.functional as F
x = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.tensor([[5, 6, 7]]) # 形状 (1, 3)
# 对 x 补零到列数 3
x_padded = F.pad(x, (0, 1)) # 在列方向右侧补 1 列零
# x_padded 形状: (2, 3)
# 在 dim=0 拼接
result = torch.cat((x_padded, y), dim=0)
print(result)
结果:
tensor([[1, 2, 0],
[3, 4, 0],
[5, 6, 7]])
但是
result = torch.cat((x_padded, y), dim=1)
则还是错误的!!!
总结
torch.cat()
用于连接张量,指定的dim
决定了在哪个维度上进行拼接。- 拼接维度的大小是累加的,其他维度的大小必须一致。
- 如果不满足上述规则,会抛出错误。
通过这种操作,你可以灵活地调整和组织张量的数据结构。