pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names=["input"], output_names=["output"])

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider=["CPUExecutionProvider"])

outputs = model.run(["output"], {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(data[i,i+1:])

data[i,i+1:] = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
Y1nhl1 小时前
搜广推校招面经六十四
人工智能·深度学习·leetcode·广告算法·推荐算法·搜索算法
禁默1 小时前
智能体开发基础:从概念到实现
人工智能·大模型·智能体
Y1nhl2 小时前
Pyspark学习一:概述
数据库·人工智能·深度学习·学习·spark·pyspark·大数据技术
维度攻城狮4 小时前
实现在Unity3D中仿真汽车,而且还能使用ros2控制
python·unity·docker·汽车·ros2·rviz2
简简单单做算法4 小时前
基于mediapipe深度学习和限定半径最近邻分类树算法的人体摔倒检测系统python源码
人工智能·python·深度学习·算法·分类·mediapipe·限定半径最近邻分类树
hvinsion5 小时前
基于PyQt5的自动化任务管理软件:高效、智能的任务调度与执行管理
开发语言·python·自动化·自动化任务管理
就决定是你啦!5 小时前
机器学习 第一章 绪论
人工智能·深度学习·机器学习
飞飞翼6 小时前
python-flask
后端·python·flask
有个人神神叨叨7 小时前
OpenAI发布的《Addendum to GPT-4o System Card: Native image generation》文件的详尽笔记
人工智能·笔记
林九生7 小时前
【Python】Browser-Use:让 AI 替你掌控浏览器,开启智能自动化新时代!
人工智能·python·自动化