Pytorch查看神经网络结构和参数量

基本方法

python 复制代码
print(model)
print(type(model))

# 模型参数
numEl_list = [p.numel() for p in model.parameters()]
total_params_mb = sum(numEl_list) / 1e6

print(f'Total parameters: {total_params_mb:.2f} MB')
# sum(numEl_list), numEl_list
print(sum(numEl_list))
print(numEl_list)
python 复制代码
# 查看模型参数的基本方法
def get_param_count(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

param_count = get_param_count(model)
print(f"Model Parameter Count: {param_count}")

# 计算每层参数量和大小
def print_layer_params_count(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} : {param.numel()}")
            print(f"{name} : {param.shape}")

print_layer_params_count(model)

使用Pytorch中的torchsummary包

python 复制代码
from torchsummary import summary
summary(model, input_size=(1, 1, 128, 128, 32))

使用第三方库torchinfo

python 复制代码
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128, 32))
相关推荐
Virgil1392 小时前
【TrOCR】训练代码
人工智能·深度学习·ocr
锅挤2 小时前
深度学习3(向量化编程+ python中实现逻辑回归)
人工智能·深度学习
Deng9452013144 小时前
基于Python的职位画像系统设计与实现
开发语言·python·文本分析·自然语言处理nlp·scrapy框架·gensim应用
MARS_AI_6 小时前
云蝠智能 Voice Agent 落地展会邀约场景:重构会展行业的智能交互范式
人工智能·自然语言处理·重构·交互·语音识别·信息与通信
weixin_422456446 小时前
第N7周:调用Gensim库训练Word2Vec模型
人工智能·机器学习·word2vec
FreakStudio8 小时前
一文速通 Python 并行计算:13 Python 异步编程-基本概念与事件循环和回调机制
python·pycharm·协程·多进程·并行计算·异步编程
HuggingFace9 小时前
Hugging Face 开源机器人 Reachy Mini 开启预定
人工智能
豌豆花下猫10 小时前
让 Python 代码飙升330倍:从入门到精通的四种性能优化实践
后端·python·ai
夏末蝉未鸣0110 小时前
python transformers库笔记(BertForTokenClassification类)
python·自然语言处理·transformer
企企通采购云平台10 小时前
「天元宠物」×企企通,加速数智化升级,“链”接萌宠消费新蓝海
大数据·人工智能·宠物