深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_

分类目录:《深入浅出Pytorch函数》总目录

相关文章:

· 深入浅出Pytorch函数------torch.nn.init.calculate_gain

· 深入浅出Pytorch函数------torch.nn.init.uniform_

· 深入浅出Pytorch函数------torch.nn.init.normal_

· 深入浅出Pytorch函数------torch.nn.init.constant_

· 深入浅出Pytorch函数------torch.nn.init.ones_

· 深入浅出Pytorch函数------torch.nn.init.zeros_

· 深入浅出Pytorch函数------torch.nn.init.eye_

· 深入浅出Pytorch函数------torch.nn.init.dirac_

· 深入浅出Pytorch函数------torch.nn.init.xavier_uniform_

· 深入浅出Pytorch函数------torch.nn.init.xavier_normal_

· 深入浅出Pytorch函数------torch.nn.init.kaiming_uniform_

· 深入浅出Pytorch函数------torch.nn.init.kaiming_normal_

· 深入浅出Pytorch函数------torch.nn.init.trunc_normal_

· 深入浅出Pytorch函数------torch.nn.init.orthogonal_

· 深入浅出Pytorch函数------torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据He, K等人于2015年在《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》中描述的方法,用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自 U ( − bound , bound ) U(-\text{bound}, \text{bound}) U(−bound,bound),其中:
bound = gain × 3 fan_mode \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} bound=gain×fan_mode3

这种方法也被称为He initialisation。

语法

复制代码
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数

  • tensor`Tensor` 一个 N N N维张量torch.Tensor
  • a`float` 这层之后使用的rectifier的斜率系数(ReLU的默认值为0)
  • mode`str` 可以为fan_infan_out。若为fan_in则保留前向传播时权值方差的量级,若为fan_out则保留反向传播时的量级,默认值为fan_in
  • nonlinearity`str` 一个非线性函数,即一个nn.functional的名称,推荐使用relu或者leaky_relu,默认值为leaky_relu

返回值

一个torch.Tensor且参数tensor也会更新

实例

复制代码
w = torch.empty(3, 5)
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

函数实现

复制代码
def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
相关推荐
木雷坞14 小时前
让 AI 编程助手跑得起项目:Dev Container 实践记录
人工智能
腾讯云开发者15 小时前
港科大郭毅可谈Agentic AI时代的核心命题:人机共生,人不可能退场
人工智能
常丛丛15 小时前
5.6 LangGraph-Edges理解-Agent图的道路系统
人工智能
雪隐15 小时前
个人电脑玩AI-08让5060 Ti给你打工——我拿 Unlimited-OCR扫了 600 页书,然后悟了
人工智能·后端
Coffeeee15 小时前
Prompt要花心思写,与 AI 对话的七个技巧
人工智能·aigc·ai编程
蝎子莱莱爱打怪15 小时前
Claude Code 官宣新升级:子智能体默认后台跑,你边聊它边干活
人工智能
武子康15 小时前
调查研究-206 DeepSeek DSpark 深度解析:大模型推理加速,正在从“模型能力”转向“系统工程”
人工智能·agent·deepseek
甲维斯16 小时前
最佳work模型sonnet5来了,直接就能用!
人工智能
IT_陈寒16 小时前
React hooks 闭包陷阱把我的状态吃掉了,原来问题出在这里
前端·人工智能·后端