Pytorch的named_children, named_modules和named_children

PyTorch 中,named_childrennamed_modulesnamed_parameters 是用于获取神经网络模型组件和参数的三种不同的方法。下面是它们各自的作用和区别:

  • named_parameters:递归地列出所有参数名称和tensor
  • named_modules:递归地列出所有子层,其中第一个返回值就是模型本身
  • named_children:列出模型的第一层级的子层,不往下进行深入递归

1. named_children:

  • named_children 返回一个生成器,它包含模型中所有直接子模块的名称和模块对。
  • 它只返回一层级的子模块,不递归到更深层次的子模块。
  • 这个方法通常用于迭代模型的直接子模块,并对其进行操作或检查。

示例:

复制代码
from torchvision.models import resnet18

model = resnet18()
# Print each layer name and its module
# Note that named_children method only returns the first level submodules
for name, layer in model.named_children():
    print(name.ljust(10), '-->', type(layer))

输出:

复制代码
conv1      --> <class 'torch.nn.modules.conv.Conv2d'>
bn1        --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
relu       --> <class 'torch.nn.modules.activation.ReLU'>
maxpool    --> <class 'torch.nn.modules.pooling.MaxPool2d'>
layer1     --> <class 'torch.nn.modules.container.Sequential'>
layer2     --> <class 'torch.nn.modules.container.Sequential'>
layer3     --> <class 'torch.nn.modules.container.Sequential'>
layer4     --> <class 'torch.nn.modules.container.Sequential'>
avgpool    --> <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>
fc         --> <class 'torch.nn.modules.linear.Linear'>

2. named_modules:

  • named_modules 返回一个生成器,它包含模型中所有模块的名称和模块对,包括子模块的子模块。
  • 它递归地遍历整个模型,返回所有模块的名称和引用。
  • 这个方法适用于当你需要对模型中的所有模块进行操作或检查时,无论它们位于哪一层级。

示例:

复制代码
from torchvision.models import resnet18

model = resnet18()
# The first layer is the model itself
for name, layer in model.named_modules():
    print(name.ljust(15), '-->', type(layer))

输出:

复制代码
                --> <class 'torchvision.models.resnet.ResNet'>
conv1           --> <class 'torch.nn.modules.conv.Conv2d'>
bn1             --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
relu            --> <class 'torch.nn.modules.activation.ReLU'>
maxpool         --> <class 'torch.nn.modules.pooling.MaxPool2d'>
layer1          --> <class 'torch.nn.modules.container.Sequential'>
layer1.0        --> <class 'torchvision.models.resnet.BasicBlock'>
layer1.0.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer1.0.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer1.0.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer1.0.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer1.0.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer1.1        --> <class 'torchvision.models.resnet.BasicBlock'>
layer1.1.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer1.1.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer1.1.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer1.1.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer1.1.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer2          --> <class 'torch.nn.modules.container.Sequential'>
layer2.0        --> <class 'torchvision.models.resnet.BasicBlock'>
layer2.0.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer2.0.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer2.0.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer2.0.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer2.0.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer2.0.downsample --> <class 'torch.nn.modules.container.Sequential'>
layer2.0.downsample.0 --> <class 'torch.nn.modules.conv.Conv2d'>
layer2.0.downsample.1 --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer2.1        --> <class 'torchvision.models.resnet.BasicBlock'>
layer2.1.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer2.1.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer2.1.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer2.1.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer2.1.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer3          --> <class 'torch.nn.modules.container.Sequential'>
layer3.0        --> <class 'torchvision.models.resnet.BasicBlock'>
layer3.0.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer3.0.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer3.0.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer3.0.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer3.0.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer3.0.downsample --> <class 'torch.nn.modules.container.Sequential'>
layer3.0.downsample.0 --> <class 'torch.nn.modules.conv.Conv2d'>
layer3.0.downsample.1 --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer3.1        --> <class 'torchvision.models.resnet.BasicBlock'>
layer3.1.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer3.1.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer3.1.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer3.1.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer3.1.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer4          --> <class 'torch.nn.modules.container.Sequential'>
layer4.0        --> <class 'torchvision.models.resnet.BasicBlock'>
layer4.0.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer4.0.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer4.0.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer4.0.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer4.0.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer4.0.downsample --> <class 'torch.nn.modules.container.Sequential'>
layer4.0.downsample.0 --> <class 'torch.nn.modules.conv.Conv2d'>
layer4.0.downsample.1 --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer4.1        --> <class 'torchvision.models.resnet.BasicBlock'>
layer4.1.conv1  --> <class 'torch.nn.modules.conv.Conv2d'>
layer4.1.bn1    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
layer4.1.relu   --> <class 'torch.nn.modules.activation.ReLU'>
layer4.1.conv2  --> <class 'torch.nn.modules.conv.Conv2d'>
layer4.1.bn2    --> <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
avgpool         --> <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>
fc              --> <class 'torch.nn.modules.linear.Linear'>

3. named_parameters:

  • named_parameters 返回一个生成器,它包含模型中所有参数的名称和参数值对。
  • 它递归地遍历模型,返回所有可训练参数的名称和参数张量。
  • 这个方法用于获取和检查模型中的参数,例如在打印模型参数、保存模型或加载模型时使用。

示例:

复制代码
from torchvision.models import resnet18

model = resnet18()
for name, param in model.named_parame

conv1.weight              --> torch.Size([64, 3, 7, 7])
bn1.weight                --> torch.Size([64])
bn1.bias                  --> torch.Size([64])
layer1.0.conv1.weight     --> torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight       --> torch.Size([64])
layer1.0.bn1.bias         --> torch.Size([64])
layer1.0.conv2.weight     --> torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight       --> torch.Size([64])
layer1.0.bn2.bias         --> torch.Size([64])
layer1.1.conv1.weight     --> torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight       --> torch.Size([64])
layer1.1.bn1.bias         --> torch.Size([64])
layer1.1.conv2.weight     --> torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight       --> torch.Size([64])
layer1.1.bn2.bias         --> torch.Size([64])
layer2.0.conv1.weight     --> torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight       --> torch.Size([128])
layer2.0.bn1.bias         --> torch.Size([128])
layer2.0.conv2.weight     --> torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight       --> torch.Size([128])
layer2.0.bn2.bias         --> torch.Size([128])
layer2.0.downsample.0.weight --> torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight --> torch.Size([128])
layer2.0.downsample.1.bias --> torch.Size([128])
layer2.1.conv1.weight     --> torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight       --> torch.Size([128])
layer2.1.bn1.bias         --> torch.Size([128])
layer2.1.conv2.weight     --> torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight       --> torch.Size([128])
layer2.1.bn2.bias         --> torch.Size([128])
layer3.0.conv1.weight     --> torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight       --> torch.Size([256])
layer3.0.bn1.bias         --> torch.Size([256])
layer3.0.conv2.weight     --> torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight       --> torch.Size([256])
layer3.0.bn2.bias         --> torch.Size([256])
layer3.0.downsample.0.weight --> torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight --> torch.Size([256])
layer3.0.downsample.1.bias --> torch.Size([256])
layer3.1.conv1.weight     --> torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight       --> torch.Size([256])
layer3.1.bn1.bias         --> torch.Size([256])
layer3.1.conv2.weight     --> torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight       --> torch.Size([256])
layer3.1.bn2.bias         --> torch.Size([256])
layer4.0.conv1.weight     --> torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight       --> torch.Size([512])
layer4.0.bn1.bias         --> torch.Size([512])
layer4.0.conv2.weight     --> torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight       --> torch.Size([512])
layer4.0.bn2.bias         --> torch.Size([512])
layer4.0.downsample.0.weight --> torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight --> torch.Size([512])
layer4.0.downsample.1.bias --> torch.Size([512])
layer4.1.conv1.weight     --> torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight       --> torch.Size([512])
layer4.1.bn1.bias         --> torch.Size([512])
layer4.1.conv2.weight     --> torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight       --> torch.Size([512])
layer4.1.bn2.bias         --> torch.Size([512])
fc.weight                 --> torch.Size([1000, 512])
fc.bias                   --> torch.Size([1000])

总结来说,named_children 用于获取模型的直接子模块,named_modules 用于获取模型的所有模块(包括嵌套的子模块),而 named_parameters 用于获取模型中的所有参数。这些方法在模型调试、分析和优化时非常有用。

相关推荐
RFdragon1 小时前
分享本周所学——三维重建算法3D Gaussian Splatting(3DGS)
人工智能·线性代数·算法·机器学习·计算机视觉·矩阵·paddlepaddle
星河耀银海1 小时前
3D效果:HTML5 WebGL结合AI实现智能3D场景渲染
前端·人工智能·深度学习·3d·html5·webgl
li99yo1 小时前
3DGS的复现
图像处理·pytorch·经验分享·python·3d·conda·pip
balmtv5 小时前
2026年多模态AI文件处理与联网搜索完全教程:国内镜像方案实测
人工智能
2501_926978335 小时前
AI的三次起落发展分析,及未来预测----理论5.0的应用
人工智能·经验分享·笔记·ai写作·agi
前网易架构师-高司机5 小时前
带标注的瓶盖识别数据集,识别率99.5%,可识别瓶盖,支持yolo,coco json,pascal voc xml格式
人工智能·yolo·数据集·瓶盖
Dontla5 小时前
用pip install -e .开发Python包时,Python项目目录结构(项目结构)(可编辑安装editable install)
python·pip
Thomas.Sir5 小时前
第三章:Python3 之 字符串
开发语言·python·字符串·string
软件供应链安全指南5 小时前
以AI治理AI|问境AIST首家通过信通院大模型安全扫描产品能力评估!
人工智能·安全·ai安全·问境aist·aist·智能体安全