pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
互联网全栈架构19 分钟前
遨游Spring AI:第一盘菜Hello World
java·人工智能·后端·spring
m0_4652157920 分钟前
大语言模型解析
人工智能·语言模型·自然语言处理
张较瘦_1 小时前
[论文阅读] 人工智能+软件工程 | 结对编程中的知识转移新图景
人工智能·软件工程·结对编程
小Q小Q2 小时前
cmake编译LASzip和LAStools
人工智能·计算机视觉
yzx9910132 小时前
基于 Q-Learning 算法和 CNN 的强化学习实现方案
人工智能·算法·cnn
token-go2 小时前
[特殊字符] 革命性AI提示词优化平台正式开源!
人工智能·开源
cooldream20093 小时前
华为云Flexus+DeepSeek征文|基于华为云Flexus X和DeepSeek-R1打造个人知识库问答系统
人工智能·华为云·dify
老胖闲聊6 小时前
Python Copilot【代码辅助工具】 简介
开发语言·python·copilot
Blossom.1186 小时前
使用Python和Scikit-Learn实现机器学习模型调优
开发语言·人工智能·python·深度学习·目标检测·机器学习·scikit-learn
曹勖之7 小时前
基于ROS2,撰写python脚本,根据给定的舵-桨动力学模型实现动力学更新
开发语言·python·机器人·ros2