Pytorch查看神经网络结构和参数量

基本方法

python 复制代码
print(model)
print(type(model))

# 模型参数
numEl_list = [p.numel() for p in model.parameters()]
total_params_mb = sum(numEl_list) / 1e6

print(f'Total parameters: {total_params_mb:.2f} MB')
# sum(numEl_list), numEl_list
print(sum(numEl_list))
print(numEl_list)
python 复制代码
# 查看模型参数的基本方法
def get_param_count(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

param_count = get_param_count(model)
print(f"Model Parameter Count: {param_count}")

# 计算每层参数量和大小
def print_layer_params_count(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} : {param.numel()}")
            print(f"{name} : {param.shape}")

print_layer_params_count(model)

使用Pytorch中的torchsummary包

python 复制代码
from torchsummary import summary
summary(model, input_size=(1, 1, 128, 128, 32))

使用第三方库torchinfo

python 复制代码
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128, 32))
相关推荐
Despacito0o1 分钟前
OpenCV图像增强实战教程:从理论到代码实现
人工智能·opencv·计算机视觉
东临碣石829 分钟前
【字节跳动AI论文】Seaweed-7B:视频生成基础模型的高成本效益培训
人工智能
流云一号10 分钟前
Python实现贪吃蛇二
开发语言·python
大话数据分析11 分钟前
现在AI大模型能帮做数据分析吗?
人工智能·数据分析
lixy57915 分钟前
深度学习之线性代数基础
人工智能·深度学习·线性代数
happyprince29 分钟前
LLM Post-Training
人工智能
听风吹等浪起33 分钟前
NLP实战(3):RNN英文名国家分类
人工智能·python·rnn·深度学习
奋斗者1号35 分钟前
机器学习中的分类算法与数据处理实践:从理论到应用
人工智能·机器学习·分类
编码小哥35 分钟前
OpenCV直方图均衡化全面解析:从灰度到彩色图像的增强技术
人工智能·opencv·计算机视觉
不爱吃于先生36 分钟前
机器学习概述自用笔记(李宏毅)
人工智能·笔记·机器学习