Pytorch查看神经网络结构和参数量

基本方法

python 复制代码
print(model)
print(type(model))

# 模型参数
numEl_list = [p.numel() for p in model.parameters()]
total_params_mb = sum(numEl_list) / 1e6

print(f'Total parameters: {total_params_mb:.2f} MB')
# sum(numEl_list), numEl_list
print(sum(numEl_list))
print(numEl_list)
python 复制代码
# 查看模型参数的基本方法
def get_param_count(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

param_count = get_param_count(model)
print(f"Model Parameter Count: {param_count}")

# 计算每层参数量和大小
def print_layer_params_count(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} : {param.numel()}")
            print(f"{name} : {param.shape}")

print_layer_params_count(model)

使用Pytorch中的torchsummary包

python 复制代码
from torchsummary import summary
summary(model, input_size=(1, 1, 128, 128, 32))

使用第三方库torchinfo

python 复制代码
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128, 32))
相关推荐
玩大数据的龙威11 小时前
农经权二轮延包—批量替换签名盖章页扫描页
python
lqqjuly11 小时前
《AI Agent智能体与MCP开发实战》之构建个性化的arXiv科研论文MCP服务实战
人工智能·深度学习
羊仔AI探索11 小时前
GLM-4.6接入Claude Code插件,国内丝滑编程
ide·人工智能·ai·aigc·ai编程
Bdygsl11 小时前
数字图像处理总结 Day 1
人工智能·算法·计算机视觉
墨染星辰云水间11 小时前
机器学习(一)
人工智能·机器学习
张彦峰ZYF11 小时前
Coze文章仿写:智能体 + 工作流实现内容自动生成与插图输出
人工智能·ai·coze dify
AI视觉网奇11 小时前
手部检测 yolov5 实战笔记
python·深度学习·计算机视觉
Jerry.张蒙11 小时前
SAP传输请求流程:从开发到生产的安全流转
大数据·网络·人工智能·学习·职场和发展·区块链·运维开发
WXG101111 小时前
【Flask-7】前后端数据交互
python·ios·flask
Lethehong11 小时前
openGauss在教育领域的AI实践:基于Java JDBC的学生成绩预测系统
java·开发语言·人工智能·sql·rag