pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
七月shi人3 小时前
【AI编程工具IDE/CLI/插件专栏】-国外IDE与Cursor能力对比
ide·人工智能·ai编程·代码助手
橙 子_4 小时前
基于 Amazon Nova Sonic 和 MCP 构建语音交互 Agent
python
2zcode5 小时前
基于Matlab的深度学习智能行人检测与统计系统
人工智能·深度学习·目标跟踪
宇寒风暖5 小时前
Flask 框架全面详解
笔记·后端·python·学习·flask·知识
哪 吒6 小时前
【2025C卷】华为OD机试九日集训第3期 - 按算法分类,由易到难,提升编程能力和解题技巧
python·算法·华为od·华为od机试·2025c卷
weixin_464078076 小时前
机器学习sklearn:过滤
人工智能·机器学习·sklearn
weixin_464078076 小时前
机器学习sklearn:降维
人工智能·机器学习·sklearn
数据与人工智能律师6 小时前
智能合约漏洞导致的损失,法律责任应如何分配
大数据·网络·人工智能·算法·区块链
张艾拉 Fun AI Everyday6 小时前
小宿科技:AI Agent 的卖铲人
人工智能·aigc·创业创新·ai-native
zhongqu_3dnest6 小时前
三维火灾调查重建:科技赋能,探寻真相
人工智能