pytorch tril 用法并导出onnx demo

import torch

import torch.nn as nn

import onnxruntime as ort

import numpy as np

def create_tril_onnx():

class SimpleNet(nn.Module):

def init(self):

super(SimpleNet, self).init()

self.data1 = torch.ones((2,3), dtype=torch.bool)

def forward(self, x):

tril_x = torch.tril(x)

tril_x = tril_x.float()

x1 = x.float()

return tril_x+x1

model = SimpleNet()

data = torch.ones((2,3), dtype=torch.bool)

output = model(data)

print("output:")

print(output)

torch.onnx.export(model, data, "tril.onnx", input_names="input", output_names="output")

def inference_onnx():

model = ort.InferenceSession("tril.onnx", provider="CPUExecutionProvider")

outputs = model.run("output", {"input":np.random.randn(2,3).astype(np.bool_)})

print("outputs:", outputs)

def my_tril():

key_size = 5

data = torch.ones((key_size,key_size), dtype=torch.bool)

for i in range(key_size):

print("\n")

print(i)

print(datai,i+1:)

datai,i+1: = False

print(data)

print(data)

def main():

create_tril_onnx()

inference_onnx()

my_tril()

if name == "main":

main()


导出onnx如下:

相关推荐
小杨在厦门9 分钟前
从AI验布到智能质检:纺织企业智能化升级的三个台阶
人工智能·服装·服装厂·服装机械·铺布机
达之云*驭影10 分钟前
解锁流量密码:详解抖音AI智能推荐封面功能
人工智能
lpd_lt19 分钟前
AI Coding的常用Prompt技巧
python·ai·ai编程
小江的记录本21 分钟前
【JVM虚拟机】堆内存分代模型:年轻代(Eden+Survivor)、老年代、元空间Metaspace(附《思维导图》+《面试高频考点清单》)
java·前端·jvm·后端·python·spring·面试
火山引擎开发者社区22 分钟前
ArkClaw 投研助理 —— 零门槛做投研,从一句话开始产出你的第一份深度研报
人工智能
在繁华处24 分钟前
Java从零到熟练(三):流程控制
java·开发语言·python
码农小白AI26 分钟前
AI报告审核加速融入自动化实验室:IACheck破解智能设备时代报告管理新挑战
运维·人工智能·自动化
xingyuzhisuan32 分钟前
自建聚合网关VS第三方聚合平台,适配场景与数据实测
人工智能·ai·云计算·oneapi
tedcloud12334 分钟前
DeepSeek-TUI部署教程:打造CLI AI助手环境
服务器·人工智能·word·excel·dreamweaver
EnCi Zheng39 分钟前
09b-斯坦福CS336作业一-Transformer语言模型
人工智能