深入浅出Pytorch函数——torch.nn.init.kaiming_normal_

分类目录:《深入浅出Pytorch函数》总目录

相关文章:

· 深入浅出Pytorch函数------torch.nn.init.calculate_gain

· 深入浅出Pytorch函数------torch.nn.init.uniform_

· 深入浅出Pytorch函数------torch.nn.init.normal_

· 深入浅出Pytorch函数------torch.nn.init.constant_

· 深入浅出Pytorch函数------torch.nn.init.ones_

· 深入浅出Pytorch函数------torch.nn.init.zeros_

· 深入浅出Pytorch函数------torch.nn.init.eye_

· 深入浅出Pytorch函数------torch.nn.init.dirac_

· 深入浅出Pytorch函数------torch.nn.init.xavier_uniform_

· 深入浅出Pytorch函数------torch.nn.init.xavier_normal_

· 深入浅出Pytorch函数------torch.nn.init.kaiming_uniform_

· 深入浅出Pytorch函数------torch.nn.init.kaiming_normal_

· 深入浅出Pytorch函数------torch.nn.init.trunc_normal_

· 深入浅出Pytorch函数------torch.nn.init.orthogonal_

· 深入浅出Pytorch函数------torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据He, K等人于2015年在《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》中描述的方法,用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自 N ( 0 , std 2 ) N(0, \text{std}^2) N(0,std2),其中:
std = gain fan_mode \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} std=fan_mode gain

这种方法也被称为He initialisation。

语法

复制代码
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数

  • tensor:[Tensor] 一个 N N N维张量torch.Tensor
  • a:[float] 这层之后使用的rectifier的斜率系数(ReLU的默认值为0)
  • mode:[str] 可以为fan_infan_out。若为fan_in则保留前向传播时权值方差的量级,若为fan_out则保留反向传播时的量级,默认值为fan_in
  • nonlinearity:[str] 一个非线性函数,即一个nn.functional的名称,推荐使用relu或者leaky_relu,默认值为leaky_relu

返回值

一个torch.Tensor且参数tensor也会更新

实例

复制代码
w = torch.empty(3, 5)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

函数实现

复制代码
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
相关推荐
一百天成为python专家10 分钟前
【项目】自然语言处理——情感分析 <上>
人工智能·rnn·自然语言处理·数据分析·lstm·pandas·easyui
新智元14 分钟前
独家!DeepSeek 最新模型上线,全新注意力机制基于北大 ACL 最佳论文
人工智能·openai
丁学文武20 分钟前
大模型原理与实践:第一章-NLP基础概念完整指南_第1部分-概念和发展历史
人工智能·自然语言处理·基础概念·大模型应用·发展历史
新智元27 分钟前
刚刚,Claude Sonnet 4.5 重磅发布,编程新王降临!
人工智能·openai
汽车仪器仪表相关领域29 分钟前
南华 NHXJ-02 汽车悬架检验台:技术特性与实操应用指南
人工智能·算法·汽车·安全性测试·稳定性测试·汽车检测·年检站
云澈ovo35 分钟前
量子计算预备役:AI辅助设计的下一代算力架构
人工智能·架构·量子计算
大千AI助手43 分钟前
MATH-500:大模型数学推理能力评估基准
人工智能·大模型·llm·强化学习·评估基准·数学推理能力·math500
hans汉斯1 小时前
【人工智能与机器人研究】一种库坝系统水下成像探查有缆机器人系统设计模式
大数据·数据库·论文阅读·人工智能·设计模式·机器人·论文笔记
之歆1 小时前
LangGraph构建多智能体
人工智能·python·llama
rhy200605201 小时前
SAM的低秩特性
人工智能·算法·机器学习·语言模型