pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
果冻人工智能19 分钟前
2025 年将颠覆商业的 8 大 AI 应用场景
人工智能·ai员工
代码不行的搬运工21 分钟前
神经网络12-Time-Series Transformer (TST)模型
人工智能·神经网络·transformer
进击的六角龙21 分钟前
深入浅出:使用Python调用API实现智能天气预报
开发语言·python
檀越剑指大厂21 分钟前
【Python系列】浅析 Python 中的字典更新与应用场景
开发语言·python
石小石Orz23 分钟前
Three.js + AI:AI 算法生成 3D 萤火虫飞舞效果~
javascript·人工智能·算法
湫ccc29 分钟前
Python简介以及解释器安装(保姆级教学)
开发语言·python
孤独且没人爱的纸鹤32 分钟前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai
阿_旭35 分钟前
TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练
人工智能·深度学习·cnn·tensorflow
羊小猪~~35 分钟前
tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建
人工智能·python·深度学习·机器学习·cnn·tensorflow·neo4j
lzhlizihang38 分钟前
python如何使用spark操作hive
hive·python·spark