torch.full_like
是 PyTorch 中用于创建一个具有特定值的新张量,其形状和数据类型与给定张量相同。
函数定义
torch.full_like(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)
参数说明
-
input
(Tensor):- 用来提供形状和其他属性(如设备、数据类型等)的参考张量。
-
fill_value
(float 或 int):- 用于填充新张量的值。
-
dtype
(torch.dtype, 可选):- 新张量的数据类型。如果未指定,则与
input
的数据类型相同。
- 新张量的数据类型。如果未指定,则与
-
layout
(torch.layout, 可选):- 新张量的内存布局。默认为
input
的布局。
- 新张量的内存布局。默认为
-
device
(torch.device, 可选):- 新张量所在的设备。如果未指定,则与
input
的设备相同。
- 新张量所在的设备。如果未指定,则与
-
requires_grad
(bool, 可选, 默认值:False
):- 如果为
True
,新张量将需要梯度计算。
- 如果为
-
memory_format
(torch.memory_format, 可选):- 新张量的内存格式。默认为
torch.preserve_format
,即与input
相同的内存格式。
- 新张量的内存格式。默认为
返回值
- 返回一个新张量,其形状、设备、数据类型等与
input
相同,但所有元素均为fill_value
。
示例
1. 基本用法
import torch
# 创建一个参考张量
x = torch.tensor([[1, 2], [3, 4]])
# 创建一个与 x 形状相同的新张量,元素全为 5
result = torch.full_like(x, 5)
print(result)
# tensor([[5, 5],
# [5, 5]])
2. 指定数据类型
result = torch.full_like(x, 5.0, dtype=torch.float32)
print(result)
# tensor([[5.0, 5.0],
# [5.0, 5.0]])
3. 指定设备
result = torch.full_like(x, 3, device='cuda')
print(result) # 张量在 GPU 上
4. 需要梯度
result = torch.full_like(x, 2, requires_grad=True)
print(result.requires_grad) # True
常见用途
- 快速初始化张量:在网络初始化、测试时创建具有固定值的张量。
- 占位符:生成形状与参考张量相同的占位张量。
- 兼容性计算:确保新张量与给定张量具有相同的数据类型和设备。
注意事项
-
与
torch.full
的区别 :torch.full
需要手动指定张量的形状,而torch.full_like
自动使用参考张量的形状。 -
支持广播 :
fill_value
可以是标量。 -
性能优化 :
torch.full_like
会自动优化设备和数据类型,便于高效地创建张量。