pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
CRUD酱几秒前
RabbitMQ是如何确保消息的可靠性的?
java·python·rabbitmq
Coding茶水间1 分钟前
基于深度学习的水稻虫害检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
试着2 分钟前
【投资学习】腾讯控股(0700.HK)
大数据·人工智能·业界资讯·腾讯
百泰派克生物科技2 分钟前
液相色谱-质谱(LC-MS)肽段分析
人工智能·生物学·质谱·实验外包
qq_411262422 分钟前
四博智联的`AI-01开发板`,基于乐鑫ESP32-C2 + 专属定制的离线语音模组
人工智能·物联网·四博智联
CodeCraft Studio7 分钟前
让项目管理更智能:基于 DHTMLX Gantt + AI 的自然语言项目构建方案
人工智能·项目管理·甘特图·dhtmlx·dhtmlx gantt·gantt
天若有情6737 分钟前
PyTorch与OpenCV 计算机视觉实战指南(入门篇)
pytorch·opencv·计算机视觉
合合技术团队7 分钟前
论文解读-潜在思维链推理的全面综述
大数据·人工智能·深度学习·大模型
sivdead8 分钟前
Agent平台消息节点输出设计思路
后端·python·agent
盼哥PyAI实验室8 分钟前
【超详细教程】Python 连接 MySQL 全流程实战
python·mysql·oracle