Pytorch代码:打印模型每层的参数数量和总参数量

这个代码片段定义了一个函数 print_model_parameters,它的作用是打印每层的参数数量以及模型的总参数量。下面是对这个函数的详细解释,重点解释 named_parametersrequires_gradnumel 参数的含义:

python 复制代码
# 打印每层的参数数量和总参数量
def print_model_parameters(model):
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel()} parameters")
            total_params += param.numel()
            print(f"For now parameters: {total_params}")
    print(f"Total parameters: {total_params}")

具体步骤和解释

  1. 定义和初始化

    python 复制代码
    def print_model_parameters(model):
        total_params = 0

    这个函数接收一个模型对象 model,并初始化一个变量 total_params 用于累积总参数量。

  2. 遍历模型参数

    python 复制代码
    for name, param in model.named_parameters():

    这里使用了 model.named_parameters() 方法,该方法返回一个生成器,生成模型中所有参数的名称和参数张量。它返回的是 (name, parameter) 形式的元组。

    • named_parameters:这是一个PyTorch模型的方法,它返回模型中所有参数的名称和参数本身。参数的名称是字符串类型,而参数是一个 torch.Tensor 对象。
  3. 判断参数是否需要梯度更新

    python 复制代码
    if param.requires_grad:

    每个参数张量都有一个 requires_grad 属性,这个属性是一个布尔值。如果 requires_gradTrue,表示这个参数在训练过程中需要计算梯度并进行更新。

    • requires_grad:这是一个布尔值属性,表示该参数是否需要在训练过程中计算梯度。如果是 True,则该参数会在反向传播时计算并存储梯度。
  4. 打印参数数量并累加

    python 复制代码
    print(f"{name}: {param.numel()} parameters")
    total_params += param.numel()
    print(f"For now parameters: {total_params}")

    对于需要梯度的参数,打印其名称和参数数量,并将该参数的数量累加到 total_params 中。

    • numel:这是一个方法,返回张量中所有元素的数量。例如,一个形状为 (3, 4) 的张量调用 numel() 方法会返回 12,因为这个张量有12个元素。
  5. 打印总参数量

    python 复制代码
    print(f"Total parameters: {total_params}")

    最后,打印模型的总参数数量。

总结

这个函数通过 model.named_parameters() 遍历模型的所有参数,检查每个参数的 requires_grad 属性,只有在 requires_gradTrue 时才计算并打印参数数量,同时累加总参数量。 numel() 方法用于获取每个参数张量的元素数量,从而帮助统计参数数量。最后打印总参数量,提供了对模型规模的一个直观了解。

相关推荐
丶213612 小时前
【CUDA】【PyTorch】安装 PyTorch 与 CUDA 11.7 的详细步骤
人工智能·pytorch·python
羊小猪~~15 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
醒了就刷牙15 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习
Hoper.J1 天前
PyTorch 模型保存与加载的三种常用方式
人工智能·pytorch·python
没有余地 EliasJie1 天前
Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT
pytorch·windows·深度学习·ubuntu·pycharm·conda·tensorflow
被制作时长两年半的个人练习生1 天前
【pytorch】权重为0的情况
人工智能·pytorch·深度学习
GarryLau1 天前
使用pytorch进行迁移学习的两个步骤
pytorch·迁移学习·torchvision
醒了就刷牙1 天前
56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版
pytorch·深度学习·gru
橙子小哥的代码世界1 天前
【深度学习】05-RNN循环神经网络-02- RNN循环神经网络的发展历史与演化趋势/LSTM/GRU/Transformer
人工智能·pytorch·rnn·深度学习·神经网络·lstm·transformer
最近好楠啊2 天前
Pytorch实现RNN实验
人工智能·pytorch·rnn