前一篇文章,Tensor 基本操作1 | PyTorch 深度学习实战
本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started
目录
Tensor 基本操作
torch.max
torch.max 实现降维运算,基于指定的 dim 选取子元素的最大值。
默认
a = torch.randn(1, 3)
print(a)
b = torch.max(a)
print(b)
Result:
tensor([[-0.5284, -1.5308, -0.2267]])
tensor(-0.2267)
指定维度
指定哪个维度,就是减去第几维:
假如有一个 Tensor Shape 是 AxBxCxD,那么有对应关系
A(dim0),B(dim1),C(dim2),D(dim3)。
假如沿着 dim = 0,则得到矩阵为 BxCxD,其中降维后的 D 中每个值,是 0 维 A 个原始元素最大的值。
假如沿着 dim = 1,则得到矩阵为 AxCxD,其中降维后的 D 中每个值,是 1 维 B 个原始元素最大的值。
a = torch.randn(4, 3, 2, 5) # 声明 4x3x2x5 的 Tensor
print(a)
max, max_indices = torch.max(a, 1) # 沿着第 1 维运算,得到的 max 是一个新的 Tensor, shape(4x2x5)
# 其中,新的 tensor 的第 2 维 有 5 个元素,每个元素是原来第 1 维的 3 个元素的对应位置的最大者
print(max)
print(max_indices)
运算过程:
运算的效果,就是将原来第 1 维的 三个 元素通过取最大值的方式消解了,剩下了 4x2x5 的新的 Tensor.
Detail result:
tensor([[[[ 1.6156, -0.3533, 0.5970, 1.0218, 0.3952], # 这是一个 4x3x2x5 的 Tensor
[ 0.2581, -1.3161, 0.3243, -0.9350, 0.6976]],
[[-0.6239, -0.8732, -0.2739, 1.3695, 0.9614],
[ 3.0117, -2.3211, 2.2359, -1.5275, 1.0230]],
[[ 0.2711, -0.5295, -0.9168, -0.9496, -0.5264],
[-0.0418, 1.4757, -0.3033, -0.5069, -0.6909]]],
[[[-0.3262, 1.0079, -0.2975, -0.9859, 1.6166],
[ 1.2771, -0.0456, 0.1857, 0.3275, 0.4207]],
[[ 0.2362, -0.0821, -0.0105, 1.7645, 0.0989],
[-0.1281, -1.0425, -0.5537, -0.0339, 1.3466]],
[[-1.3060, 1.0920, -0.9126, -0.3850, -0.7273],
[-0.0519, -0.3566, -0.5489, -3.6990, 0.6110]]],
[[[ 1.2422, -0.2393, 0.4786, 0.6107, -0.0252],
[ 0.2563, -0.4030, 1.8649, 0.3462, 0.7197]],
[[-0.6126, 0.7801, -0.6078, 0.1391, -0.8297],
[-1.8600, -0.2814, 0.2408, -0.9058, -0.0186]],
[[ 1.6242, 1.5925, -0.0591, -0.0107, -1.8332],
[ 0.9812, -3.2381, -1.7055, -1.3484, -1.3409]]],
[[[-0.3392, -0.4359, -0.0451, 2.4718, 1.9482],
[ 0.6110, -0.5543, 0.3466, 0.4199, -0.0319]],
[[-0.2322, -0.8355, -1.0138, 0.9620, -0.4311],
[-0.7799, 0.8414, 0.9293, -0.0322, 0.1638]],
[[ 0.6299, 0.7966, 1.8616, -1.8382, -0.1141],
[ 1.2325, -0.0446, -0.7722, 1.2540, -1.8609]]]])
tensor([[[ 1.6156, -0.3533, 0.5970, 1.3695, 0.9614], # 取 Max 之后得到的新的 Tensor
[ 3.0117, 1.4757, 2.2359, -0.5069, 1.0230]],
[[ 0.2362, 1.0920, -0.0105, 1.7645, 1.6166],
[ 1.2771, -0.0456, 0.1857, 0.3275, 1.3466]],
[[ 1.6242, 1.5925, 0.4786, 0.6107, -0.0252],
[ 0.9812, -0.2814, 1.8649, 0.3462, 0.7197]],
[[ 0.6299, 0.7966, 1.8616, 2.4718, 1.9482],
[ 1.2325, 0.8414, 0.9293, 1.2540, 0.1638]]])
tensor([[[0, 0, 0, 1, 1],
[1, 2, 1, 2, 1]],
[[1, 2, 1, 1, 0],
[0, 0, 0, 0, 1]],
[[2, 2, 0, 0, 0],
[2, 1, 0, 0, 0]],
[[2, 2, 2, 0, 0],
[2, 1, 1, 2, 1]]])