分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数------torch.Tensor
函数torch.sum
有两种形式:
torch.sum(input, *, dtype=None)
:返回输入张量input
所有元素的和。torch.sum(input, dim, keepdim=False, *, dtype=None)
:返回给定维度dim
中输入张量的每一行的总和。如果dim
是一个维度列表,则对所有维度进行缩小。如果keepdim
为True
,则输出张量的大小与输入的大小相同,但维度dim
的大小为1。否则,dim
会被挤压(参考torch.squeeze()
)。
语法
dart
torch.sum(input, *, dtype=None) -> Tensor
torch.sum(input, dim, keepdim=False, *, dtype=None) -> Tensor
参数
input
:输入张量dim
:[可选,int
/tuple
] 要减少的一个或多个维度。如果为None
,则所有维度都将被裁剪。keepdim
:[bool
] 输出张量是否保留了dim
。dtype
:[可选,torch.dtype
] 返回张量的所需数据类型。如果指定,则在执行操作之前将输入张量强制转换为dtype
。这对于防止数据类型溢出非常有用,默认值为None
。
实例
dart
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.1133, -0.9567, 0.2958]])
>>> torch.sum(a)
tensor(-0.5475)
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0569, -0.2475, 0.0737, -0.3429],
[-0.2993, 0.9138, 0.9337, -1.6864],
[ 0.1132, 0.7892, -0.1003, 0.5688],
[ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)
tensor([-0.4598, -0.1381, 1.3708, -2.6217])
>>> b = torch.arange(4 * 5 * 6).view(4, 5, 6)
>>> torch.sum(b, (2, 1))
tensor([ 435., 1335., 2235., 3135.])