pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
luoganttcc2 分钟前
除了视觉伺服 还有哪些 方法
人工智能
ST小智3 分钟前
2025年创作历程回顾与个人生活平衡
大数据·linux·人工智能
NiceAsiv6 分钟前
VSCode之打开python终端 取消conda activate的powershell弹窗
vscode·python·conda
weixin_437988129 分钟前
范式智能发布“风控哨兵”大模型 引领金融风控新范式
人工智能
哥本哈士奇10 分钟前
使用Gradio构建AI前端 - RAG的QA模块
前端·人工智能·状态模式
5G全域通13 分钟前
面向5G复杂性的下一代运维技术体系:架构、工具与实践
大数据·运维·人工智能·5g·架构
你们补药再卷啦15 分钟前
人工智能算法概览
人工智能·算法
蔚说17 分钟前
is 与 == 的区别 python
python
悟纤18 分钟前
续写卡在 2 秒?解决方案全解析|Suno 进阶指南|第 13 篇
人工智能·suno·suno ai·suno api·ai music
cnxy18821 分钟前
围棋对弈Python程序开发完整指南:步骤3 - 气(Liberties)的计算算法设计
python·算法·深度优先