Pytorch查看神经网络结构和参数量

基本方法

python 复制代码
print(model)
print(type(model))

# 模型参数
numEl_list = [p.numel() for p in model.parameters()]
total_params_mb = sum(numEl_list) / 1e6

print(f'Total parameters: {total_params_mb:.2f} MB')
# sum(numEl_list), numEl_list
print(sum(numEl_list))
print(numEl_list)
python 复制代码
# 查看模型参数的基本方法
def get_param_count(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

param_count = get_param_count(model)
print(f"Model Parameter Count: {param_count}")

# 计算每层参数量和大小
def print_layer_params_count(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} : {param.numel()}")
            print(f"{name} : {param.shape}")

print_layer_params_count(model)

使用Pytorch中的torchsummary包

python 复制代码
from torchsummary import summary
summary(model, input_size=(1, 1, 128, 128, 32))

使用第三方库torchinfo

python 复制代码
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128, 32))
相关推荐
程序员爱钓鱼4 分钟前
Python编程实战 - Python实用工具与库 - 正则表达式匹配(re 模块)
后端·python·面试
程序员爱钓鱼6 分钟前
Python编程实战 - Python实用工具与库 - 爬取并存储网页数据
后端·python·面试
bin91538 分钟前
PHP文档保卫战:AI自动生成下的创意守护与反制指南
开发语言·人工智能·php·工具·ai工具
AI 研究所15 分钟前
1024开发者节:开源发布,引领生态繁荣
人工智能·语言模型·开源·大模型·交互·agent
深圳市青牛科技实业有限公司 小芋圆23 分钟前
30V N 沟道 MOSFET SP30N06NK 全面解析:参数、特性与应用场景
人工智能·单片机·嵌入式硬件·无人机·高频dc-dc谐振变换器·笔记本电脑开合检测
leafff12329 分钟前
AI数据库研究:RAG 架构运行算力需求?
数据库·人工智能·语言模型·自然语言处理·架构
陈辛chenxin38 分钟前
【大数据技术01】数据科学的基础理论
大数据·人工智能·python·深度学习·机器学习·数据挖掘·数据分析
极客BIM工作室1 小时前
扩散模型核心机制解析:U-Net调用逻辑、反向传播时机与步骤对称性
人工智能·深度学习·机器学习
从零开始的奋豆1 小时前
计算机视觉(三):特征检测与光流法
人工智能·计算机视觉
蒋星熠1 小时前
爬虫中Cookies模拟浏览器登录技术详解
开发语言·爬虫·python·正则表达式·自动化·php·web