pytorch小记(二十):深入解析 PyTorch 的 `torch.randn_like`:原理、参数与实战示例

pytorch小记(二十):深入解析 PyTorch 的 `torch.randn_like`:原理、参数与实战示例

    • 一、函数签名与参数详解
    • [二、`torch.randn_like` vs `torch.randn`](#二、torch.randn_like vs torch.randn)
    • 三、基础示例
    • 四、进阶用法与参数覆盖
      • [4.1 覆盖数据类型(dtype)](#4.1 覆盖数据类型(dtype))
      • [4.2 覆盖设备(device)](#4.2 覆盖设备(device))
      • [4.3 开启梯度追踪(requires\_grad)](#4.3 开启梯度追踪(requires_grad))
      • [4.4 覆盖内存格式(memory\_format)](#4.4 覆盖内存格式(memory_format))
    • 五、典型应用场景
      • [1. 给模型参数添加噪声](#1. 给模型参数添加噪声)
      • [2. 数据增强:图像高斯噪声](#2. 数据增强:图像高斯噪声)
      • [3. 扩散模型(DDPM)中的噪声采样](#3. 扩散模型(DDPM)中的噪声采样)
    • 六、多种等价写法
    • 七、小结

在深度学习模型中,我们经常需要在已有张量的基础上生成与之「同形状」「同设备」「同或不同数据类型」的随机噪声,用于参数扰动、数据增强、扩散模型等场景。PyTorch 为我们提供了一个高效便捷的工具------torch.randn_like,它能一步完成上述需求。本文将从函数定义、参数详解、典型应用场景,到进阶用法,全面剖析 torch.randn_like,并通过丰富示例帮助你快速上手。


一、函数签名与参数详解

python 复制代码
torch.randn_like(
    input: Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
    memory_format: Optional[torch.memory_format] = None
) → Tensor
  • input (必选)

    源张量,randn_like 会读取它的 .shape.dtype.device.layout、以及 memory_format(如果未显式指定覆盖项)。

  • dtype (可选)

    生成张量的数据类型,如 torch.float32torch.int64 等。若不指定,则继承 input.dtype

  • device (可选)

    指定在 CPU 还是 GPU 上创建新张量,如 "cpu""cuda:0"。若不指定,则继承 input.device

  • requires_grad (可选)

    是否对新张量开启梯度追踪,默认为 False

  • 其他

    • layout:张量内存布局,一般使用默认;
    • memory_format:指定内存格式,如 torch.contiguous_format

二、torch.randn_like vs torch.randn

方法 参数 优点
torch.randn(size) 必须手动传入 size、可选传入 dtypedevice 简单直观,适合只关心形状的场景
torch.randn_like(input) 自动继承 input.shapedtypedevicelayout 等属性 减少样板代码,保证输出张量与输入环境一致

三、基础示例

python 复制代码
import torch

# 1. 构造一个形状为 (2, 3) 的零张量
x = torch.zeros(2, 3)
print("x:", x.shape, x.dtype, x.device)
# x: torch.Size([2, 3]) torch.float32 cpu

# 2. 生成与 x 同形状同属性的标准正态随机张量
noise = torch.randn_like(x)
print("noise:", noise)
# 示例输出:
# tensor([[-0.1245,  0.5487, -0.3221],
#         [ 0.8723, -1.0054,  0.0392]])
  • 新张量 noisex 的形状、数据类型、设备保持一致。

四、进阶用法与参数覆盖

4.1 覆盖数据类型(dtype)

python 复制代码
# 强制生成 float64 类型
noise_fp64 = torch.randn_like(x, dtype=torch.float64)
print(noise_fp64.dtype)  # torch.float64

4.2 覆盖设备(device)

python 复制代码
if torch.cuda.is_available():
    noise_gpu = torch.randn_like(x, device=torch.device('cuda:0'))
    print(noise_gpu.device)  # cuda:0

4.3 开启梯度追踪(requires_grad)

python 复制代码
noise_grad = torch.randn_like(x, requires_grad=True)
print(noise_grad.requires_grad)  # True

4.4 覆盖内存格式(memory_format)

python 复制代码
noise_contig = torch.randn_like(x, memory_format=torch.contiguous_format)
# 通常无需显式指定,除非对内存布局有特殊需求

五、典型应用场景

1. 给模型参数添加噪声

在对抗训练、参数平滑或元学习中,需要对权重做微小扰动:

python 复制代码
import torch.nn as nn

class NoisyLinear(nn.Linear):
    def forward(self, input):
        # 为权重张量添加微小高斯噪声
        weight_noise = torch.randn_like(self.weight) * 0.01
        return nn.functional.linear(input, self.weight + weight_noise, self.bias)

layer = NoisyLinear(128, 64)
x = torch.randn(32, 128)
out = layer(x)  # 前向过程中,自动生成同形状噪声

2. 数据增强:图像高斯噪声

对图像 Batch 注入随机噪声,提升模型鲁棒性:

python 复制代码
# 假设 images 形状为 [B, C, H, W]
images = torch.randn(16, 3, 224, 224)  # 示例输入
noise_std = 0.1

noisy_images = images + torch.randn_like(images) * noise_std
# 这样可以保证噪声形状 / dtype / device 与 images 完全一致

3. 扩散模型(DDPM)中的噪声采样

在扩散模型中,需要不断向数据添加标准正态噪声,且噪声张量形状与数据完全对齐:

python 复制代码
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    # 根据时间步 t 计算噪声比例等后续操作...
    return x_start * alpha_t[t] + noise * beta_t[t]

六、多种等价写法

  • tensor.long()tensor.to(torch.int64)
  • tensor.type(torch.float32) 等方法,均可对已有张量做类型转换,与 randn_like 结合时常用于进一步处理。

七、小结

  1. 功能torch.randn_like 快速生成与指定张量同形状、同设备的标准正态分布随机张量。
  2. 参数覆盖 :可选 dtypedevicerequires_gradmemory_format 等,灵活适配各种需求。
  3. 典型场景:参数扰动、数据增强、扩散模型、随机索引等。
  4. 最佳实践 :在不关心形状等属性细节时,用 randn_like 省去 boilerplate;在需要覆盖属性时,通过关键字参数一次性完成。
相关推荐
白白白飘22 分钟前
pytorch 15.1 学习率调度基本概念与手动实现方法
人工智能·pytorch·学习
深度学习入门27 分钟前
机器学习,深度学习,神经网络,深度神经网络之间有何区别?
人工智能·python·深度学习·神经网络·机器学习·机器学习入门·深度学习算法
张彦峰ZYF1 小时前
走出 Demo,走向现实:DeepSeek-VL 的多模态工程路线图
人工智能
森哥的歌1 小时前
Python uv包管理器使用指南:从入门到精通
python·开发工具·uv·虚拟环境·包管理
qq_214782611 小时前
给你的matplotlib images添加scale Bar
python·数据分析·matplotlib
Johny_Zhao2 小时前
Vmware workstation安装部署微软SCCM服务系统
网络·人工智能·python·sql·网络安全·信息安全·微软·云计算·shell·系统运维·sccm
waterHBO2 小时前
python + flask 做一个图床
python
动感光博2 小时前
Unity(URP渲染管线)的后处理、动画制作、虚拟相机(Virtual Camera)
开发语言·人工智能·计算机视觉·unity·c#·游戏引擎
IT古董2 小时前
【漫话机器学习系列】259.神经网络参数的初始化(Initialization Of Neural Network Parameters)
人工智能·神经网络·机器学习
tyatyatya2 小时前
神经网络在MATLAB中是如何实现的?
人工智能·神经网络·matlab