pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
zskj_zhyl2 分钟前
数字康养新范式:七彩喜平台重构智慧养老生态的深度实践
大数据·人工智能·物联网
白码低代码10 分钟前
橡胶制品行业质检管理的痛点 质检LIMS如何重构橡胶制品质检价值链
大数据·人工智能·重构·lims·实验室管理系统
Amo Xiang18 分钟前
Python 常用内置函数详解(十):help()函数——查看对象的帮助信息
python·内置函数·help
boooo_hhh20 分钟前
第J7周:对于ResNeXt-50算法的思考
开发语言·python·深度学习
tmiger24 分钟前
图像匹配导航定位技术 第 10 章
人工智能·算法·计算机视觉
小彭律师28 分钟前
电动汽车充电设施可调能力聚合评估与预测
人工智能·深度学习·机器学习
_waylau33 分钟前
【HarmonyOS NEXT+AI】问答05:ArkTS和仓颉编程语言怎么选?
人工智能·华为·harmonyos·arkts·鸿蒙·仓颉
老实人y37 分钟前
TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py
人工智能·python·机器学习·icl·icp
极客智谷40 分钟前
Spring AI 系列——使用大模型对文本内容分类归纳并标签化输出
人工智能·spring·分类
夏子曦1 小时前
AI——认知建模工具:ACT-R
人工智能·机器学习·ai