pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
小王子10249 分钟前
设计模式Python版 组合模式
python·设计模式·组合模式
kakaZhui12 分钟前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20251 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥1 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin1 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客2 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
云空2 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代2 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt
山晨啊83 小时前
2025年美赛B题-结合Logistic阻滞增长模型和SIR传染病模型研究旅游可持续性-成品论文
人工智能·机器学习
RZer4 小时前
Hypium+python鸿蒙原生自动化安装配置
python·自动化·harmonyos