Pytorch查看神经网络结构和参数量

基本方法

python 复制代码
print(model)
print(type(model))

# 模型参数
numEl_list = [p.numel() for p in model.parameters()]
total_params_mb = sum(numEl_list) / 1e6

print(f'Total parameters: {total_params_mb:.2f} MB')
# sum(numEl_list), numEl_list
print(sum(numEl_list))
print(numEl_list)
python 复制代码
# 查看模型参数的基本方法
def get_param_count(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

param_count = get_param_count(model)
print(f"Model Parameter Count: {param_count}")

# 计算每层参数量和大小
def print_layer_params_count(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name} : {param.numel()}")
            print(f"{name} : {param.shape}")

print_layer_params_count(model)

使用Pytorch中的torchsummary包

python 复制代码
from torchsummary import summary
summary(model, input_size=(1, 1, 128, 128, 32))

使用第三方库torchinfo

python 复制代码
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128, 32))
相关推荐
王中阳Go2 分钟前
05 Go Eino AI应用开发实战 | Docker 部署指南
人工智能·后端·go
木头左4 分钟前
机器学习辅助的LSTM交易策略特征工程与入参筛选技巧
python
腾讯云开发者7 分钟前
当10年架构师拿起AI:不是写不动了,是写得太快了
人工智能
Lenyiin7 分钟前
《 Linux 修炼全景指南: 八 》别再碎片化学习!掌控 Linux 开发工具链:gcc、g++、GDB、Bash、Python 与工程化实践
linux·python·bash·gdb·gcc·g++·lenyiin
Swizard12 分钟前
告别“意大利面条”:FastAPI 生产级架构的最佳实践指南
python·fastapi
小马过河R17 分钟前
RAG检索增强生成:通过重排序提升AI信息检索精准度
人工智能·语言模型
不惑_18 分钟前
通俗理解卷积神经网络
人工智能·windows·python·深度学习·机器学习
滴啦嘟啦哒26 分钟前
【机械臂】【总览】基于VLA结构的指令驱动式机械臂
python·ros2·vla
rayufo37 分钟前
自定义数据在深度学习中的应用方法
人工智能·深度学习
写代码的【黑咖啡】37 分钟前
深入理解 Python 中的函数
开发语言·python