在 PyTorch 中,Tensor.new_zeros
是一种用于创建与现有张量形状或设备匹配的新张量的方法。该方法生成一个全为零的张量,且其数据类型、设备等属性与调用它的张量一致,除非另行指定。
new_zeros
方法的语法
Tensor.new_zeros(size, *, dtype=None, device=None, requires_grad=False)
参数说明
-
size
(tuple)指定新张量的形状。例如
(2, 3)
表示创建一个形状为 2x3 的张量。 -
dtype
(torch.dtype, 可选)指定新张量的数据类型。如果未指定,将与原张量的数据类型一致。
-
device
(torch.device, 可选)指定新张量所在的设备(如 CPU 或 GPU)。如果未指定,将与原张量所在的设备一致。
-
requires_grad
(bool, 可选)指定新张量是否需要计算梯度(默认为
False
)。
new_zeros
的特性
- 新张量与原张量具有相同的设备 和默认数据类型(除非显式更改)。
- 新张量的内容为全零。
使用示例
1. 创建与现有张量形状匹配的零张量
import torch
x = torch.ones(2, 3, device='cuda') # 创建一个形状为 (2, 3) 的张量
zeros = x.new_zeros((2, 3)) # 创建一个全零张量,与 x 具有相同形状和设备
print(zeros)
# 输出(在 GPU 上):
# tensor([[0., 0., 0.],
# [0., 0., 0.]], device='cuda:0')
2. 创建具有不同形状的零张量
x = torch.ones(4, 5)
zeros = x.new_zeros((2, 3)) # 创建一个形状为 (2, 3) 的零张量
print(zeros)
# 输出:
# tensor([[0., 0., 0.],
# [0., 0., 0.]])
3. 指定数据类型
x = torch.ones(3, 3, dtype=torch.float32)
zeros = x.new_zeros((2, 2), dtype=torch.int32) # 显式指定数据类型
print(zeros)
# 输出:
# tensor([[0, 0],
# [0, 0]], dtype=torch.int32)
4. 指定设备
x = torch.ones(2, 2, device='cuda')
zeros = x.new_zeros((3, 3), device='cpu') # 在 CPU 上创建新张量
print(zeros)
# 输出:
# tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]])
与其他创建零张量的方法的对比
-
torch.zeros
zeros = torch.zeros((2, 3))
-
- 独立于已有张量。
- 需要显式指定数据类型和设备。
-
Tensor.new_zeros
zeros = x.new_zeros((2, 3))
-
与现有张量
x
共享设备和默认数据类型。
常见应用场景
-
快速创建与输入张量匹配的零张量 在深度学习中,可能需要创建与现有张量形状和设备匹配的零张量。例如,用于初始化中间结果或辅助计算。
-
动态操作 当输入张量的形状、设备不固定时,可以使用
new_zeros
动态生成匹配的零张量,无需手动指定设备或数据类型。
总结
Tensor.new_zeros
是一个高效、方便的方法,适合在动态模型或设备敏感的代码中使用。它避免了显式管理设备和数据类型的麻烦,有助于提高代码的简洁性和可维护性。