pytorch张量的new_zeros方法介绍

在 PyTorch 中,Tensor.new_zeros 是一种用于创建与现有张量形状或设备匹配的新张量的方法。该方法生成一个全为零的张量,且其数据类型、设备等属性与调用它的张量一致,除非另行指定。


new_zeros 方法的语法

复制代码
Tensor.new_zeros(size, *, dtype=None, device=None, requires_grad=False)

参数说明

  • size (tuple)

    指定新张量的形状。例如 (2, 3) 表示创建一个形状为 2x3 的张量。

  • dtype (torch.dtype, 可选)

    指定新张量的数据类型。如果未指定,将与原张量的数据类型一致。

  • device (torch.device, 可选)

    指定新张量所在的设备(如 CPU 或 GPU)。如果未指定,将与原张量所在的设备一致。

  • requires_grad (bool, 可选)

    指定新张量是否需要计算梯度(默认为 False)。


new_zeros 的特性

  • 新张量与原张量具有相同的设备默认数据类型(除非显式更改)。
  • 新张量的内容为全零。

使用示例

1. 创建与现有张量形状匹配的零张量

复制代码
import torch

x = torch.ones(2, 3, device='cuda')  # 创建一个形状为 (2, 3) 的张量
zeros = x.new_zeros((2, 3))          # 创建一个全零张量,与 x 具有相同形状和设备
print(zeros)
# 输出(在 GPU 上):
# tensor([[0., 0., 0.],
#         [0., 0., 0.]], device='cuda:0')

2. 创建具有不同形状的零张量

复制代码
x = torch.ones(4, 5)
zeros = x.new_zeros((2, 3))  # 创建一个形状为 (2, 3) 的零张量
print(zeros)
# 输出:
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])

3. 指定数据类型

复制代码
x = torch.ones(3, 3, dtype=torch.float32)
zeros = x.new_zeros((2, 2), dtype=torch.int32)  # 显式指定数据类型
print(zeros)
# 输出:
# tensor([[0, 0],
#         [0, 0]], dtype=torch.int32)

4. 指定设备

复制代码
x = torch.ones(2, 2, device='cuda')
zeros = x.new_zeros((3, 3), device='cpu')  # 在 CPU 上创建新张量
print(zeros)
# 输出:
# tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]])

与其他创建零张量的方法的对比

  1. torch.zeros

    zeros = torch.zeros((2, 3))

    • 独立于已有张量。
    • 需要显式指定数据类型和设备。
  • Tensor.new_zeros

    zeros = x.new_zeros((2, 3))

  • 与现有张量 x 共享设备和默认数据类型。


常见应用场景

  1. 快速创建与输入张量匹配的零张量 在深度学习中,可能需要创建与现有张量形状和设备匹配的零张量。例如,用于初始化中间结果或辅助计算。

  2. 动态操作 当输入张量的形状、设备不固定时,可以使用 new_zeros 动态生成匹配的零张量,无需手动指定设备或数据类型。


总结

Tensor.new_zeros 是一个高效、方便的方法,适合在动态模型或设备敏感的代码中使用。它避免了显式管理设备和数据类型的麻烦,有助于提高代码的简洁性和可维护性。

相关推荐
YJlio1 分钟前
[鸿蒙2025领航者闯关] 基于鸿蒙 6 的「隐私感知跨设备办公助手」实战:星盾安全 + AI防窥 + 方舟引擎优化全流程复盘
人工智能·安全·harmonyos
ghie90903 分钟前
线性三角波连续调频毫米波雷达目标识别
人工智能·算法·计算机视觉
闲人编程4 分钟前
Django中间件开发:从请求到响应的完整处理链
python·中间件·性能优化·django·配置·codecapsule
执笔论英雄6 分钟前
【RL】Slime异步 routout 过程7 AsyncLoopThread
开发语言·python
学习中的数据喵7 分钟前
可以看穿事物“本质“的LDA
人工智能·机器学习
fj_changing8 分钟前
Ubuntu 22.04部署CosyVoice
人工智能·python·深度学习·ubuntu·ai
z***02608 分钟前
Python大数据可视化:基于大数据技术的共享单车数据分析与辅助管理系统_flask+hadoop+spider
大数据·python·信息可视化
on_pluto_9 分钟前
【debug】解决 conda 和 镜像下载pytorch太慢的问题
人工智能·pytorch·conda
GIS程序媛—椰子9 分钟前
从后端到 AI/Agent:那些可迁移的系统思维(未完结)
人工智能·后端
雪域迷影11 分钟前
Python中通过get请求获取api.open-meteo.com网站的天气数据
开发语言·python·php