pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
l12345sy12 小时前
Day24_【深度学习—广播机制】
人工智能·pytorch·深度学习·广播机制
IT古董12 小时前
【第五章:计算机视觉-项目实战之图像分类实战】1.经典卷积神经网络模型Backbone与图像-(4)经典卷积神经网络ResNet的架构讲解
人工智能·计算机视觉·cnn
向往鹰的翱翔12 小时前
BKY莱德因:5大黑科技逆转时光
大数据·人工智能·科技·生活·健康医疗
L.fountain12 小时前
机器学习shap分析案例
人工智能·机器学习
weixin_4296302612 小时前
机器学习-第一章
人工智能·机器学习
Cedric111312 小时前
机器学习中的距离总结
人工智能·机器学习
大模型真好玩12 小时前
深入浅出LangGraph AI Agent智能体开发教程(五)—LangGraph 数据分析助手智能体项目实战
人工智能·python·mcp
测试老哥12 小时前
Selenium 使用指南
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
IT_陈寒12 小时前
React性能优化:这5个Hook技巧让我的组件渲染效率提升50%(附代码对比)
前端·人工智能·后端