python-pytorch 如何使用python库Netron查看模型结构(以pytorch官网模型为例)0.9.2

Netron查看模型结构

  • 2024年4月27日14:32:30----0.9.2

参照模型

以pytorch官网的tutorial为观察对象,链接是https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

模型代码如下

python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
        output = self.h2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)

安装Netron

pip install netron即可

其他安装方式参考链接

https://blog.csdn.net/m0_49963403/article/details/136242313

写netron代码

随便找一个地方打个点,如sample方法中

python 复制代码
import netron
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = rnn.initHidden()

        output_name = start_letter

        for i in range(max_length):
#             print("category_tensor",category_tensor.size())
#             print("input[0]",input[0].size())
#             print("hidden",hidden.size())
            
            output, hidden = rnn(category_tensor, input[0], hidden)
            torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx')   #导出 .onnx 文件
            netron.start('AlexNet1.onnx') #展示结构图
        
            break
#             print("output",output.size())
#             print("hidden",hidden.size())
#             print("====================")
        
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == n_letters - 1:
                break
            else:
                letter = all_letters[topi]
                output_name += letter
            input = inputTensor(letter)

        return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))
        break

samples('Russian', 'RUS')

运行查看结果

结果是在浏览器中,运行成功后会显示:

Serving 'AlexNet.onnx' at http://localhost:8080

打开这个网页就可以看见模型结构,如下图

需要关注的地方

  1. 关于参数
    如果模型是一个参数的情况下,如下使用就可以了
python 复制代码
import torch
from torchvision.models import AlexNet
import netron
model = AlexNet()
input = torch.ones((1,3,224,224))
torch.onnx.export(model, input, f='AlexNet.onnx')
netron.start('AlexNet.onnx')

如果模型有多个参数的情况下,则需要如下用括号括起来,如本文中的例子

python 复制代码
torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx')   #导出 .onnx 文件
netron.start('AlexNet1.onnx') #展示结构图
  1. 如果运行过程中发现报错找不到模型
    有可能是你手动删除了生成的模型,最好的方法是重新生成这个模型,再运行
相关推荐
醒了就刷牙2 分钟前
56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版
pytorch·深度学习·gru
炼丹师小米3 分钟前
Ubuntu24.04.1系统下VideoMamba环境配置
python·环境配置·videomamba
橙子小哥的代码世界3 分钟前
【深度学习】05-RNN循环神经网络-02- RNN循环神经网络的发展历史与演化趋势/LSTM/GRU/Transformer
人工智能·pytorch·rnn·深度学习·神经网络·lstm·transformer
GFCGUO9 分钟前
ubuntu18.04运行OpenPCDet出现的问题
linux·python·学习·ubuntu·conda·pip
快乐就好ya34 分钟前
Java多线程
java·开发语言
CS_GaoMing1 小时前
Centos7 JDK 多版本管理与 Maven 构建问题和注意!
java·开发语言·maven·centos7·java多版本
985小水博一枚呀2 小时前
【深度学习基础模型】神经图灵机(Neural Turing Machines, NTM)详细理解并附实现代码。
人工智能·python·rnn·深度学习·lstm·ntm
2401_858120532 小时前
Spring Boot框架下的大学生就业招聘平台
java·开发语言
转调2 小时前
每日一练:地下城游戏
开发语言·c++·算法·leetcode
Java探秘者2 小时前
Maven下载、安装与环境配置详解:从零开始搭建高效Java开发环境
java·开发语言·数据库·spring boot·spring cloud·maven·idea