《动手学深度学习 Pytorch版》 5.3 延后初始化

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

下面实例化的多层感知机的输入维度是未知的,因此框架尚未初始化任何参数,显示为"UninitializedParameter"。

python 复制代码
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

net[0].weight
复制代码
c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\lazy.py:178: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '





<UninitializedParameter>

一旦指定了输入维度,框架就可以一层一层的延迟初始化。

python 复制代码
X = torch.rand(2, 20)
net(X)

net[0].weight.shape
复制代码
torch.Size([256, 20])

练习

(1)如果指定了第一层的输入维度,但没有指定后续层的维度,会发生什么?是否立即进行初始化?

python 复制代码
net = nn.Sequential(
    nn.Linear(20, 256), nn.ReLU(),
    nn.LazyLinear(128), nn.ReLU(),
    nn.LazyLinear(10)
)
net[0].weight, net[2].weight, net[4].weight
复制代码
c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\lazy.py:178: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '





(Parameter containing:
 tensor([[ 0.1332,  0.1372, -0.0939,  ..., -0.0579, -0.0911, -0.1820],
         [-0.1570, -0.0993, -0.0685,  ..., -0.0469, -0.0208,  0.0665],
         [ 0.0861,  0.1135,  0.1631,  ..., -0.1407,  0.1088, -0.2052],
         ...,
         [-0.1454, -0.0283, -0.1074,  ..., -0.2164, -0.2169,  0.1913],
         [-0.1617,  0.1206, -0.2119,  ..., -0.1862, -0.0951,  0.1535],
         [-0.0229, -0.2133, -0.1027,  ...,  0.1973,  0.1314,  0.1283]],
        requires_grad=True),
 <UninitializedParameter>,
 <UninitializedParameter>)
python 复制代码
net(X)  # 延迟初始化
net[0].weight.shape, net[2].weight.shape, net[4].weight.shape
复制代码
(torch.Size([256, 20]), torch.Size([128, 256]), torch.Size([10, 128]))

(2)如果指定了不匹配的维度会发生什么?

python 复制代码
X = torch.rand(2, 10)
net(X)  # 会报错
复制代码
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

Cell In[14], line 2
      1 X = torch.rand(2, 10)
----> 2 net(X)


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\container.py:139, in Sequential.forward(self, input)
    137 def forward(self, input):
    138     for module in self:
--> 139         input = module(input)
    140     return input


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x10 and 20x256)

(3)如果输入具有不同的维度,需要做什么?

调整维度,要么填充,要么降维。

相关推荐
康康的AI博客1 小时前
腾讯王炸:CodeMoment - 全球首个产设研一体 AI IDE
ide·人工智能
中达瑞和-高光谱·多光谱1 小时前
中达瑞和LCTF:精准调控光谱,赋能显微成像新突破
人工智能
mahtengdbb11 小时前
【目标检测实战】基于YOLOv8-DynamicHGNetV2的猪面部检测系统搭建与优化
人工智能·yolo·目标检测
Pyeako1 小时前
深度学习--BP神经网络&梯度下降&损失函数
人工智能·python·深度学习·bp神经网络·损失函数·梯度下降·正则化惩罚
清 澜2 小时前
大模型面试400问第一部分第一章
人工智能·大模型·大模型面试
哥布林学者2 小时前
吴恩达深度学习课程五:自然语言处理 第二周:词嵌入(四)分层 softmax 和负采样
深度学习·ai
不大姐姐AI智能体2 小时前
搭了个小红书笔记自动生产线,一句话生成图文,一键发布,支持手机端、电脑端发布
人工智能·经验分享·笔记·矩阵·aigc
虹科网络安全2 小时前
艾体宝方案 | 释放数据潜能 · 构建 AI 驱动的自动驾驶实时数据处理与智能筛选平台
人工智能·机器学习·自动驾驶
Deepoch3 小时前
Deepoc数学大模型:发动机行业的算法引擎
人工智能·算法·机器人·发动机·deepoc·发动机行业