Pytorch查看神经网络结构和参数量

基本方法

python 复制代码
print(model)
print(type(model))

# 模型参数
numEl_list = [p.numel() for p in model.parameters()]
total_params_mb = sum(numEl_list) / 1e6

print(f'Total parameters: {total_params_mb:.2f} MB')
# sum(numEl_list), numEl_list
print(sum(numEl_list))
print(numEl_list)
python 复制代码
# 查看模型参数的基本方法
def get_param_count(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

param_count = get_param_count(model)
print(f"Model Parameter Count: {param_count}")

# 计算每层参数量和大小
def print_layer_params_count(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} : {param.numel()}")
            print(f"{name} : {param.shape}")

print_layer_params_count(model)

使用Pytorch中的torchsummary包

python 复制代码
from torchsummary import summary
summary(model, input_size=(1, 1, 128, 128, 32))

使用第三方库torchinfo

python 复制代码
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128, 32))
相关推荐
hello_ejb325 分钟前
聊聊Spring AI Alibaba的SentenceSplitter
人工智能·python·spring
新辞旧梦1 小时前
企业微信自建消息推送应用
服务器·python·企业微信
虎头金猫1 小时前
如何解决 403 错误:请求被拒绝,无法连接到服务器
运维·服务器·python·ubuntu·chatgpt·centos·bug
摸鱼仙人~2 小时前
机器学习常用评价指标
人工智能·机器学习
一点.点3 小时前
WiseAD:基于视觉-语言模型的知识增强型端到端自动驾驶——论文阅读
人工智能·语言模型·自动驾驶
fanstuck4 小时前
从知识图谱到精准决策:基于MCP的招投标货物比对溯源系统实践
人工智能·知识图谱
dqsh064 小时前
树莓派5+Ubuntu24.04 LTS串口通信 保姆级教程
人工智能·python·物联网·ubuntu·机器人
打小就很皮...5 小时前
编写大模型Prompt提示词方法
人工智能·语言模型·prompt
Aliano2175 小时前
Prompt(提示词)工程师,“跟AI聊天”
人工智能·prompt
weixin_445238125 小时前
第R8周:RNN实现阿尔兹海默病诊断(pytorch)
人工智能·pytorch·rnn