pytorch笔记:named_parameters

  • named_parameters 是 PyTorch 中一个非常有用的函数,用于访问模型中所有定义的参数及其对应的名称。
  • 它是 torch.nn.Module 类的方法之一,返回一个生成器,生成 (name, parameter) 对,name 是参数的名称,parameter 是对应的参数张量。

1 举例

1.0 创建模型

python 复制代码
import torch
import torch.nn as nn

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 64, 5)
        self.fc1 = nn.Linear(64 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.view(-1, 64 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model_tst = SimpleModel()

1.1 应用1:打印模型的所有参数及其名称

python 复制代码
for name, param in model_tst.named_parameters():
    print(name, param.shape)

'''
conv1.weight torch.Size([20, 1, 5, 5])
conv1.bias torch.Size([20])
conv2.weight torch.Size([64, 20, 5, 5])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1024])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])
conv1.weight torch.Size([20, 1, 5, 5])
conv1.bias torch.Size([20])
conv2.weight torch.Size([64, 20, 5, 5])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1024])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])
'''

1.2 应用2:冻结特定层的参数

假设我们只想训练全连接层,而冻结卷积层的参数:

python 复制代码
for name, param in model_tst.named_parameters():
    if 'conv' in name:
        param.requires_grad = False

1.3 应用3:自定义优化器参数

可以使用 named_parameters 创建自定义的参数组,以便对不同的参数组应用不同的学习率:

python 复制代码
optimizer = torch.optim.SGD([
    {'params': [param for name, param in model_tst.named_parameters() if 'conv' in name], 'lr': 0.01},
    {'params': [param for name, param in model_tst.named_parameters() if 'fc' in name], 'lr': 0.1}
], momentum=0.9)
相关推荐
杏仁橙橙饼几秒前
2024自然语言处理期末回忆
人工智能·自然语言处理
YmgmY3 分钟前
推荐算法学习笔记2.2:基于深度学习的推荐算法-基于特征交叉组合+逻辑回归思路的深度推荐算法-Deep Crossing模型
笔记·学习·推荐算法
紫色沙4 分钟前
数据分析入门指南:从基础概念到实际应用(一)
大数据·人工智能·数据分析
三花AI7 分钟前
MotionClone: 视频运动克隆技术
人工智能·音视频
zhyhg14 分钟前
AI硬件加速版XVDPU入门
人工智能
Purepisces35 分钟前
深度学习笔记: 最详尽解释预测系统的分类指标(精确率、召回率和 F1 值)
人工智能·笔记·python·深度学习·机器学习·分类
科研小白 新人上路44 分钟前
ChatGPT-4o医学应用、论文撰写、数据分析与可视化、机器学习建模、病例自动化处理、病情分析与诊断支持
人工智能·chatgpt·自动化·论文撰写
猫头虎1 小时前
猫头虎博主全栈前沿AI技术领域矩阵社群
人工智能·职场和发展·创业创新·学习方法·业界资讯·程序员创富·改行学it
数据超市1 小时前
全国现状建筑数据,选中范围即可查询下载,富含建筑物位置、层数、建筑物功能、名称地址等信息!
大数据·人工智能·信息可视化·数据挖掘·数据分析
coolkidlan1 小时前
【AI原理解析】—k-means原理
人工智能·机器学习·kmeans