深入浅出Pytorch函数——torch.nn.init.trunc_normal_

分类目录:《深入浅出Pytorch函数》总目录

相关文章:

· 深入浅出Pytorch函数------torch.nn.init.calculate_gain

· 深入浅出Pytorch函数------torch.nn.init.uniform_

· 深入浅出Pytorch函数------torch.nn.init.normal_

· 深入浅出Pytorch函数------torch.nn.init.constant_

· 深入浅出Pytorch函数------torch.nn.init.ones_

· 深入浅出Pytorch函数------torch.nn.init.zeros_

· 深入浅出Pytorch函数------torch.nn.init.eye_

· 深入浅出Pytorch函数------torch.nn.init.dirac_

· 深入浅出Pytorch函数------torch.nn.init.xavier_uniform_

· 深入浅出Pytorch函数------torch.nn.init.xavier_normal_

· 深入浅出Pytorch函数------torch.nn.init.kaiming_uniform_

· 深入浅出Pytorch函数------torch.nn.init.kaiming_normal_

· 深入浅出Pytorch函数------torch.nn.init.trunc_normal_

· 深入浅出Pytorch函数------torch.nn.init.orthogonal_

· 深入浅出Pytorch函数------torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

该函数用截断正态分布中的值填充输入张量。这些值实际上是从正态分布 N ( mean , std 2 ) N(\text{mean}, \text{std}^2) N(mean,std2)中得出的,其中 a , b a, b a,b之外的值被重新绘制,直到它们在边界内。用于生成随机值的方法在 a ≤ mean ≤ b a\leq\text{mean}\leq b a≤mean≤b情况下效果最佳。

语法

复制代码
torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=- 2.0, b=2.0)

参数

  • tensor`Tensor` 一个 N N N维张量torch.Tensor
  • mean `float` 正态分布的均值
  • std `float` 正态分布的标准差
  • a`float` 截断边界的最小值
  • b`float` 截断边界的最大值

返回值

一个torch.Tensor且参数tensor也会更新

实例

复制代码
w = torch.empty(3, 5)
nn.init.trunc_normal_(w)

函数实现

复制代码
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
相关推荐
双翌视觉1 分钟前
工业AI视觉检测中的“小样本困境”
人工智能·计算机视觉·视觉检测
CoderIsArt6 分钟前
声纹识别与音频AI领域
人工智能·音视频
tedcloud1238 分钟前
HyperFrames部署教程:用HTML生成MP4视频
前端·数据库·人工智能·html·音视频
jixunwulian15 分钟前
AI+边缘计算,工业智能网关智慧交通IoT解决方案
人工智能·物联网·边缘计算
启程在掘金15 分钟前
LangGraph 执行流程解析
人工智能
清辞85321 分钟前
Coze从入门到实战---第一、二章
大数据·人工智能·学习·语言模型
质造者29 分钟前
LangChain + Ollama + Tavily 实现旅游问答系统
linux·人工智能·python·langchain·rag
追梦人电立电子35 分钟前
X、Y电容的分类与选择
人工智能·分类·数据挖掘·追梦人电力电子
美狐美颜SDK开放平台36 分钟前
直播APP开发实战:第三方美颜sdk接入步骤与注意事项
人工智能·音视频·美颜sdk·第三方美颜sdk·短视频美颜sdk
yychen_java40 分钟前
当算法成为武器:AI泛滥时代的多维危机透视与治理路径
网络·人工智能·ai