pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
Ronin-Lotus1 小时前
深度学习篇--- ResNet-18
人工智能·深度学习·resnet
说私域2 小时前
基于开源 AI 智能名片链动 2+1 模式 S2B2C 商城小程序的新开非连锁品牌店开业引流策略研究
人工智能·小程序·开源
moonsims3 小时前
无人机和无人系统的计算机视觉-人工智能无人机
人工智能·计算机视觉·无人机
钓了猫的鱼儿3 小时前
无人机航拍数据集|第27期 无人机交通目标检测YOLO数据集3717张yolov11/yolov8/yolov5可训练
人工智能·yolo·目标检测
tzc_fly3 小时前
rbio1:以生物学世界模型为软验证器训练科学推理大语言模型
人工智能·语言模型·自然语言处理
AndrewHZ3 小时前
【python与生活】如何用Python写一个简单的自动整理文件的脚本?
开发语言·python·生活·脚本·文件整理
北方有星辰zz4 小时前
语音识别:概念与接口
网络·人工智能·语音识别
binbinaijishu884 小时前
Python爬虫入门指南:从零开始的网络数据获取之旅
开发语言·爬虫·python·其他
阿里-于怀4 小时前
携程旅游的 AI 网关落地实践
人工智能·网关·ai·旅游·携程·higress·ai网关
赴3354 小时前
神经网络和深度学习介绍
人工智能·深度学习·反向传播