import torch
import torch.nn as nn
import onnxruntime as ort
import numpy as np
class SimpleNet(nn.Module):
    """Toy model that applies ``torch.tril`` to a boolean input.

    ``forward(x)`` returns ``tril(x).float() + x.float()``.
    """

    def __init__(self):
        # Must be the real ``__init__`` so nn.Module is initialised and the
        # attribute below is actually set on construction.
        super().__init__()
        # Not used by ``forward``; kept from the original experiment.
        self.data1 = torch.ones((2, 3), dtype=torch.bool)

    def forward(self, x):
        tril_x = torch.tril(x).float()
        return tril_x + x.float()


def create_tril_onnx(path="tril.onnx", opset_version=None):
    """Run ``SimpleNet`` on an all-True (2, 3) bool tensor and export it to ONNX.

    Args:
        path: Destination file for the exported model.
        opset_version: Forwarded to ``torch.onnx.export``; ``None`` keeps
            torch's default. ``torch.tril`` maps to the ONNX ``Trilu`` op,
            which only exists from opset 14, so pass 14+ if your torch's
            default opset is older.

    Returns:
        The PyTorch output tensor for the export input, so it can be compared
        against the ONNX Runtime result.
    """
    model = SimpleNet()
    data = torch.ones((2, 3), dtype=torch.bool)
    output = model(data)
    print("output:")
    print(output)
    torch.onnx.export(
        model,
        data,
        path,
        input_names=["input"],
        output_names=["output"],
        opset_version=opset_version,
    )
    return output
def inference_onnx(path="tril.onnx", data=None):
    """Load an exported ONNX model with ONNX Runtime on CPU and run it once.

    Args:
        path: ONNX file to load (the default matches ``create_tril_onnx``).
        data: Array fed to the ``"input"`` tensor, converted to bool. Defaults
            to an all-True (2, 3) bool array, the same input used for export,
            so the printed result can be compared with the PyTorch output.

    Returns:
        The list returned by ``InferenceSession.run`` for ``["output"]``.
    """
    # ``providers`` (plural) is the InferenceSession parameter; the previous
    # ``provider=`` keyword did not select the CPU provider.
    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    if data is None:
        data = np.ones((2, 3), dtype=np.bool_)
    feed = {"input": np.asarray(data, dtype=np.bool_)}
    outputs = session.run(["output"], feed)
    print("outputs:", outputs)
    return outputs
def my_tril(key_size=5, *, verbose=True):
    """Build a lower-triangular boolean mask by hand, without ``torch.tril``.

    Starts from an all-True (key_size, key_size) tensor and, for each row i,
    sets the elements strictly right of the diagonal to False.

    Args:
        key_size: Side length of the square mask.
        verbose: When True (default), print each row index, the slice about
            to be cleared, and the mask after each row, then the final mask.

    Returns:
        The boolean mask, equal to
        ``torch.tril(torch.ones((key_size, key_size), dtype=torch.bool))``.
    """
    data = torch.ones((key_size, key_size), dtype=torch.bool)
    for i in range(key_size):
        if verbose:
            print("\n")
            print(i)
            # Printed before clearing, so this shows the still-True slice.
            print(data[i, i + 1:])
        data[i, i + 1:] = False
        if verbose:
            print(data)
    if verbose:
        print(data)
    return data
def main():
    """Export the tril model, run it through ONNX Runtime, then build the manual mask."""
    pipeline = (create_tril_onnx, inference_onnx, my_tril)
    for stage in pipeline:
        stage()
# Run the demo only when executed as a script; ``name``/``"main"`` (without
# the dunder underscores) raised NameError instead.
if __name__ == "__main__":
    main()
导出的 ONNX 模型如下：