Single-operator PyTorch testing based on picklerpc [remote single-operator testing]
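The post "Test programs for mainstream large models, used to export the operator list" produced the operator types and their argument information. I want to compare the performance and numerical error of each operator across different hardware platforms. Saving every result to a file would take up too much space, so this post shows how to use picklerpc to send the operator type and its arguments to a remote server and run the test there.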
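To make the space argument concrete, the sketch below pickles one request (operator name plus input descriptors) and compares its size with a single tensor of the described shape. It is only an illustration: the dtype is written as the string `"float32"` so the snippet runs without torch, whereas the real client passes `torch.float32`, and the payload layout is an assumption rather than what picklerpc actually puts on the wire.

```python
import math
import pickle

# One request: operator name plus a (seed, shape, dtype, device) descriptor per input.
# The dtype is a plain string here only to keep this sketch torch-free (assumption).
op_type = "aten.gelu_backward.default"
input_desc = [(0, (1, 512, 40, 128), "float32", "cuda:0")] * 2

# Hypothetical payload layout, used only to measure the size of the description.
payload = pickle.dumps((op_type, input_desc))

# Bytes needed to store ONE float32 tensor of that shape:
# 1 * 512 * 40 * 128 elements * 4 bytes = 10,485,760 bytes (10 MiB).
tensor_bytes = math.prod((1, 512, 40, 128)) * 4

print(f"pickled request:    {len(payload)} bytes")
print(f"one float32 tensor: {tensor_bytes} bytes")
```

The request stays tiny because the server regenerates the inputs from the seed and shape; only the flattened output travels back, and nothing has to be written to disk.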
1. Server
```python
from functools import reduce
import time
import traceback

import numpy as np
import torch
from picklerpc import PickleRPCServer


class TorchOpRunner(PickleRPCServer):
    def __init__(self, addr=('localhost', 8080)):
        super().__init__(addr)

    def run(self, op_type, input_desc):
        # Rebuild each positional input from its (seed, shape, dtype, device) descriptor.
        input_args = []
        input_kwargs = {}
        for seed, shape, dtype, device in input_desc:
            torch.random.manual_seed(seed)
            # Generate on the CPU, then move to the target device, so the values
            # do not depend on the device's own random number generator.
            input_args.append(torch.rand(shape, dtype=dtype).to(device))

        # Resolve e.g. "aten.gelu_backward.default" under torch.ops by attribute
        # lookup instead of eval(), so the request string is never run as code.
        op = reduce(getattr, op_type.split("."), torch.ops)

        warmup_count = 1
        test_count = 3
        record = {"error": 0}
        try:
            for _ in range(warmup_count):
                output = op(*input_args, **input_kwargs)
            # Assumes a CUDA device, as in the client example.
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            for _ in range(test_count):
                output = op(*input_args, **input_kwargs)
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            latency = (t1 - t0) / test_count

            # Flatten every tensor output into one float32 NumPy vector.
            chunks = []
            if isinstance(output, torch.Tensor):
                chunks.append(output.detach().cpu().float().numpy().reshape(-1))
            elif isinstance(output, (list, tuple)):
                for out in output:
                    if isinstance(out, torch.Tensor):
                        chunks.append(out.detach().cpu().float().numpy().reshape(-1))
            else:
                print("error type:", type(output))
                record["error"] = 3  # unsupported output type

            if record["error"] == 0:
                if len(chunks) == 0:
                    record["error"] = 4  # no tensor outputs were collected
                else:
                    data = np.concatenate(chunks, axis=0)
                    if data.shape[0] > 0:
                        record["data"] = data
                    else:
                        record["error"] = 5  # outputs contain no elements
            record["latency"] = latency
            return record
        except Exception:
            traceback.print_exc()
            record["error"] = 6  # the operator raised an exception
            return record

    def raise_error(self):
        """Raise an error"""
        raise NotImplementedError('Not ready')


if __name__ == '__main__':
    srv = TorchOpRunner(addr=('localhost', 10001))
    srv.register_function(srv.run)
    srv.serve_forever()
```
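The timing part of `run` follows a warm-up-then-measure pattern: run the operator a few times untimed, synchronize, time a fixed number of calls, synchronize again, and divide. A minimal standalone sketch of that pattern is shown below. The helper name `measure_latency` and its `sync` parameter are hypothetical and not part of picklerpc or PyTorch; on a GPU one would pass `torch.cuda.synchronize` as `sync`.

```python
import time


def measure_latency(fn, sync=lambda: None, warmup_count=1, test_count=3):
    """Return the mean wall-clock seconds per call of fn().

    sync is called before starting and before stopping the clock, so that
    asynchronously queued work (e.g. GPU kernels) is included in the timing.
    """
    for _ in range(warmup_count):
        fn()  # untimed warm-up calls
    sync()  # make sure warm-up work has finished before starting the clock
    t0 = time.perf_counter()
    for _ in range(test_count):
        fn()
    sync()  # wait for the timed calls to actually complete
    t1 = time.perf_counter()
    return (t1 - t0) / test_count


# Usage with a plain CPU function, where no synchronization is needed.
latency = measure_latency(lambda: sum(range(10_000)))
print(f"{latency * 1e6:.1f} us per call")
```

Without the second synchronization, the clock could stop while kernels are still queued on the device, which would understate the latency.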
2. Client
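```python
import torch
import picklerpc


def main():
    op_type = "aten.gelu_backward.default"
    seed = 0
    shape = (1, 512, 40, 128)
    dtype = torch.float32
    device = "cuda:0"
    # One (seed, shape, dtype, device) descriptor per positional input.
    input_desc = [(seed, shape, dtype, device), (seed, shape, dtype, device)]
    client = picklerpc.PickleRPCClient(('localhost', 10001))
    output = client.run(op_type, input_desc)
    # "data" is only present when the run succeeded (error == 0).
    if output["error"] == 0:
        print(output["error"], output["data"].shape)
    else:
        print("error:", output["error"])


if __name__ == "__main__":
    main()
```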
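Once the same request has been sent to servers on two platforms, the returned records can be compared directly in memory. The following is a minimal sketch of such a comparison, assuming each record has the `error`, `latency`, and `data` fields produced by `run`. The function `compare_records` and the names of the returned metrics are hypothetical.

```python
import numpy as np


def compare_records(ref, other):
    """Compare two records returned by run() on different platforms.

    Returns a dict of metrics, or None if either run failed or the
    flattened outputs have different lengths.
    """
    if ref["error"] != 0 or other["error"] != 0:
        return None
    # Compare in float64 so the subtraction itself adds no extra rounding.
    a = np.asarray(ref["data"], dtype=np.float64)
    b = np.asarray(other["data"], dtype=np.float64)
    if a.shape != b.shape:
        return None
    abs_err = np.abs(a - b)
    return {
        "max_abs_err": float(abs_err.max()),
        "mean_abs_err": float(abs_err.mean()),
        # Below 1 means `other` was faster than `ref`.
        "latency_ratio": other["latency"] / ref["latency"] if ref["latency"] > 0 else None,
    }


# Usage with two hand-made records standing in for real server replies.
ref = {"error": 0, "latency": 2.0, "data": np.array([1.0, 2.0, 3.0], dtype=np.float32)}
other = {"error": 0, "latency": 1.0, "data": np.array([1.0, 2.0, 3.5], dtype=np.float32)}
print(compare_records(ref, other))
```

Only the summary metrics need to be kept per operator, which avoids storing the full outputs on disk.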