pytorch为自己的extension backend添加profiler功能
- 1.参考文档
- 2.your-extension-for-pytorch需要增加的代码
- 3.pytorch demo及如何调整chrome trace json文件
- 4.[可视化](https://ui.perfetto.dev/)
本文演示了pytorch如何为自己的extension backend添加profiler功能
背景介绍
- 1.没有类似CNLight、Profiling AscendCL API、ROC Tracer之类的Profiling功能,无法trace runtime、driver、kernel,也无法获取设备的metrics
- 2.只有event功能,可以统计kernel耗时
- 3.本文只是一种尝试,并不合理.
- 4.torch原生的profiler框架,依赖kineto,kineto目前支持CUPTI和ROC Tracer,如果不修改torch源码,第三方设备不方便使用
- 5.华为、寒武纪、habana都是采用torch.profiler的接口形式及at::addThreadLocalCallback功能,但不依赖torch.profiler框架
profiling原始数据都是私有格式,并通过修改TensorBoard的插件来实现可视化
实施步骤
- 1.调用torch::profiler::impl::registerPrivateUse1Methods注册
- 2.因为没有correlation ID去关联host api与kernel,因此export_chrome_trace出来的数据没有kernel信息
- 3.获取prof.profiler.function_events里的数据,通过f"{ev.name}_{ev.id}_{ev.thread}"拼成uuid,与上面chrome trace中的events关联
- 4.因为只有一个stream,可以根据Host launch时间、kernel耗时、launch latency(先验),推断出kernel的开始、结束时间,并用flow event进行关联(虽然并不准确)
- 5.最后把kernel event以及flow event追加到chrome trace中
1.参考文档
- ROC Tracer
- CUPTI
- 华为profiler_npu
- Profiling AscendCL API
- 寒武纪profile_mlu
- 寒武纪CNLight
- habana torch
- intel_extension_for_pytorch
- Make the kineto extendable for other runtime than CUDA
- pytorch_open_registration_example
- rename_privateuse1_backend
- Trace Event Format
2.your-extension-for-pytorch需要增加的代码
c
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
using torch::profiler::impl::ProfilerStubs;
using torch::profiler::impl::ProfilerVoidEventStub;
namespace torch {
namespace profiler {
namespace impl {
// Profiler stubs for the PrivateUse1 ("xpu") device: overrides of
// torch::profiler::impl::ProfilerStubs, registered below through
// registerPrivateUse1Methods. Device time is measured with xpurt events
// recorded on the current stream.
struct NPUMethods : public ProfilerStubs {
// Creates an xpurt event, hands ownership to *event, and records it on the
// current stream. Optionally writes the current device id into *device and a
// CPU timestamp into *cpu_ns.
// NOTE(review): TORCH_CHECK(cond) throws when cond is false. If the xpurt*
// functions return 0 on success (CUDA-style status codes), every check in
// this struct fires on the success path. Confirm the xpurt return convention.
void record(
int* device,
ProfilerVoidEventStub* event,
int64_t* cpu_ns) const override
{
if (device) {
// NOTE(review): reinterprets int* as uint32_t*; assumes equal width.
TORCH_CHECK(xpurtGetDevice((uint32_t*)device));
}
xpurtEvent_t xpurt_event;
TORCH_CHECK(xpurtEventCreate(&xpurt_event));
// The shared_ptr owns the event and destroys it when the last reference goes.
// NOTE(review): an exception thrown from this deleter escapes a noexcept
// shared_ptr destructor and calls std::terminate.
*event = std::shared_ptr<void>(xpurt_event, [](xpurtEvent_t ptr) {
TORCH_CHECK(xpurtEventDestroy(ptr));
});
// Stream of the device reported by vastai::get_device(). This is queried
// separately from xpurtGetDevice above; presumably both name the same device.
auto xpurt_stream = c10::xpu::getCurrentxpuStream(vastai::get_device());
// Take the CPU timestamp immediately before the device record so that the
// host time and the device event refer to (nearly) the same instant.
if (cpu_ns) {
*cpu_ns = getTime();
}
TORCH_CHECK(xpurtEventRecord(xpurt_event, xpurt_stream));
}
// Blocks until both events have completed, then returns the device time
// between them as reported by xpurtEventElapsedTime.
// NOTE(review): the value is returned unscaled and is held in an int64_t
// named time_ms, so any sub-unit precision is lost. PyTorch's built-in CUDA
// stub converts milliseconds to microseconds before returning. Confirm the
// unit xpurt reports and the unit the profiler expects here.
float elapsed(
const ProfilerVoidEventStub* event1_,
const ProfilerVoidEventStub* event2_) const override
{
auto event1 = static_cast<xpurtEvent_t>(event1_->get());
TORCH_CHECK(xpurtEventSynchronize(event1));
auto event2 = static_cast<xpurtEvent_t>(event2_->get());
TORCH_CHECK(xpurtEventSynchronize(event2));
int64_t time_ms = 0;
TORCH_CHECK(xpurtEventElapsedTime(&time_ms, event1, event2));
return time_ms*1.0;
}
// Invokes op for the current device only.
// NOTE(review): multi-device setups would need to iterate every device
// (c10/util/irange.h is included but not used in this struct).
void onEachDevice(std::function<void(int)> op) const override
{
uint32_t device = 0;
TORCH_CHECK(xpurtGetDevice(&device));
op(device);
}
// No-op: device-wide synchronization is not implemented for this backend.
// NOTE(review): confirm that nothing relies on this to flush device work.
void synchronize() const override { }
// Profiling is always reported as available.
bool enabled() const override {return true;}
// Marker and range annotations are accepted and ignored.
void mark(const char*name) const override { }
void rangePush(const char*name) const override { }
void rangePop() const override {}
};
struct RegisterNPUMethods {
RegisterNPUMethods()
{
static NPUMethods methods;
torch::profiler::impl::registerPrivateUse1Methods(&methods);
}
};
RegisterNPUMethods reg;
}}}
3.pytorch demo及如何调整chrome trace json文件
python
import time
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
import copy
import math
import torch
from torch.profiler import profile
import json
import tqdm
# Ops treated as non-device kernels (views, shape/metadata queries,
# allocations, scalar reads). An op name matches an entry when it is equal to
# it, or equal to it followed by "_" (variants such as aten::empty_strided,
# aten::view_as, aten::t_).
_INVALID_KERNEL_NAMES = (
    "aten::view", "aten::reshape",
    "aten::t", "aten::empty",
    "aten::transpose",
    "aten::as_strided",
    "aten::item",
    "aten::_local_scalar_dense",
    "aten::result_type",
    "aten::_unsafe_view",
    "aten::expand",
)


def is_valid_kernel(name, duration, valid_kernel_threshold=100):
    """Decide from an op's name and device time whether it counts as a Device Kernel.

    Matching is exact (or an exact match plus a "_"-suffixed variant) rather
    than substring-based. A plain substring test would let "aten::t" also
    reject real kernels such as aten::tanh, aten::topk or aten::tril.

    Args:
        name: op name as it appears in the trace (e.g. "aten::mm").
        duration: device time measured for the op.
        valid_kernel_threshold: ops whose duration is below this are rejected.

    Returns:
        True if the op is not in the non-kernel list and its duration is not
        below the threshold.
    """
    for op in _INVALID_KERNEL_NAMES:
        if name == op or name.startswith(op + "_"):
            return False
    return not duration < valid_kernel_threshold
def filter_ev(ev):
    """Return True when trace event *ev* has an "External id" entry in its args.

    Only such events can be matched against the profiler's function events.
    """
    return "External id" in ev.get('args', {})
def get_uuid(ev, tid_map):
    """Build the "<name>_<External id>_<mapped tid>" key for trace event *ev*.

    The thread part is the raw trace ``tid`` translated through *tid_map*.
    """
    op_name = ev['name']
    external_id = ev['args']['External id']
    mapped_tid = tid_map[ev['tid']]
    return f"{op_name}_{external_id}_{mapped_tid}"
def get_valid_kernels(traceEvents, kernel_event, tid_map):
    """Collect host ops from the trace that have device timing and pass is_valid_kernel.

    Args:
        traceEvents: the Chrome trace's event list.
        kernel_event: uuid -> dict holding at least 'kernel_time' and
            'device_memory_usage' (the schema written by the demo script).
        tid_map: raw trace tid -> sequential thread index, used to build uuids.

    Returns:
        Kernel dicts sorted by host launch start ('launch_beg'). The
        'device_memory_usage' of each entry is the most recent positive reading
        seen while scanning the trace in order.
    """
    found = []
    last_mem = 0  # latest positive device-memory reading seen so far
    for ev in traceEvents:
        if not filter_ev(ev):
            continue
        uuid = get_uuid(ev, tid_map)
        if uuid not in kernel_event:
            continue
        record = kernel_event[uuid]
        kernel_time = record['kernel_time']
        # Memory readings are tracked even for ops rejected below.
        if record['device_memory_usage'] > 0:
            last_mem = record['device_memory_usage']
        if not is_valid_kernel(ev['name'], kernel_time):
            continue
        start = ev['ts']
        found.append({
            "name": ev['name'],
            "launch_beg": start,
            "launch_end": start + ev['dur'],
            "kernel_duration": kernel_time,
            "host_pid": ev['pid'],
            "host_tid": ev['tid'],
            "device_memory_usage": last_mem,
            "is_leaf_kernel": False,
        })
    found.sort(key=lambda item: item['launch_beg'])
    return found
def is_leaf_kernel(kernel, valid_kernels):
    """Return True if no other candidate kernel is nested inside *kernel*.

    A candidate ``k`` counts as nested when both of these hold:
      * it was launched from the same host thread (same 'host_pid'/'host_tid');
      * its launch interval lies strictly inside *kernel*'s interval
        ('launch_beg' after and 'launch_end' before).
    Candidates whose 'is_leaf_kernel' flag is already True are skipped.

    Only leaves are kept as on-device kernels by merge_prof_timeline. Enclosing
    host ops would otherwise duplicate their children's device work. Nesting is
    only meaningful within one host thread: an op running concurrently on
    another thread may overlap in time without being a child.
    """
    host = (kernel.get('host_pid'), kernel.get('host_tid'))
    for k in valid_kernels:
        if k['is_leaf_kernel']:
            continue
        if (k.get('host_pid'), k.get('host_tid')) != host:
            continue
        if kernel['launch_beg'] < k['launch_beg'] and k['launch_end'] < kernel['launch_end']:
            return False
    return True
def create_tid_map(traceEvents):
    """Map each raw host tid seen on filter_ev-matching events to a 1-based index.

    Tids are numbered in ascending order, so the smallest tid maps to 1.
    """
    host_tids = {ev['tid'] for ev in traceEvents if filter_ev(ev)}
    return {tid: index for index, tid in enumerate(sorted(host_tids), start=1)}
def _device_timeline_events(on_device_kernels, kernel_launch_latency=0):
    """Build synthetic device-side Chrome trace events for leaf kernels.

    The device is modelled as a single in-order queue. Each kernel starts at
    the later of (host launch time + kernel_launch_latency) and the end of the
    previous kernel, then runs for its measured 'kernel_duration'. Four events
    are emitted per kernel, in this order:
      * "X" complete event on the device track (pid/tid 10);
      * "C" counter event carrying 'device_memory_usage' (pid/tid 11);
      * "s"/"f" flow pair linking the host launch to the device kernel.

    Args:
        on_device_kernels: kernel dicts as produced by get_valid_kernels,
            in launch order.
        kernel_launch_latency: assumed delay, in the trace's ts units, between
            a host launch and the earliest possible device start.

    Returns:
        A list of trace-event dicts to append to traceEvents.
    """
    events = []
    queue_free_at = None  # when the device queue finishes its previous kernel
    for flow_id, kernel in enumerate(on_device_kernels):
        launch_time = kernel['launch_beg']
        duration = kernel['kernel_duration']
        earliest = launch_time + kernel_launch_latency
        # If the queue is idle (or empty) the kernel starts as soon as it can.
        start = earliest if queue_free_at is None else max(queue_free_at, earliest)
        events.append({"ph": "X", "cat": "device_kernel", "name": kernel['name'], "pid": 10, "tid": 10, "ts": start, "dur": duration})
        events.append({"ph": "C", "cat": "memory", "name": "memory", "pid": 11, "tid": 11, "ts": launch_time, "args": {"value": kernel['device_memory_usage']}})
        events.append({"ph": "s", "id": flow_id, "pid": kernel['host_pid'], "tid": kernel['host_tid'], "ts": launch_time, "cat": "ac2g", "name": "ac2g"})
        events.append({"ph": "f", "id": flow_id, "pid": 10, "tid": 10, "ts": start, "cat": "ac2g", "name": "ac2g", "bp": "e"})
        queue_free_at = start + duration
    return events


def merge_prof_timeline(prof_json, kernel_event_json, output_json, kernel_launch_latency=0):
    """Append an inferred device-kernel timeline to an exported Chrome trace.

    Reads the trace written by prof.export_chrome_trace and the per-op device
    timings in kernel_event_json (uuid -> dict with at least 'kernel_time' and
    'device_memory_usage'). It keeps only leaf kernels, appends a synthetic
    device track with host->device flow arrows, and writes the result.

    Args:
        prof_json: path of the Chrome trace JSON to read.
        kernel_event_json: path of the uuid-keyed kernel timing JSON.
        output_json: path to write the merged trace to (may equal prof_json,
            since both inputs are fully read before writing).
        kernel_launch_latency: prior for the host->device launch delay, in the
            trace's ts units; defaults to 0 (no delay).
    """
    with open(prof_json, 'r', encoding='utf-8') as f:
        prof = json.load(f)
    with open(kernel_event_json, 'r', encoding='utf-8') as f:
        kernel_event = json.load(f)
    trace_events = prof['traceEvents']
    tid_map = create_tid_map(trace_events)
    print(tid_map)
    # All ops that have device timing and pass is_valid_kernel.
    valid_kernels = get_valid_kernels(trace_events, kernel_event, tid_map)
    print(len(valid_kernels))
    # Keep only leaf ops: these are taken to be the ones executing on the device.
    on_device_kernels = [k for k in tqdm.tqdm(valid_kernels) if is_leaf_kernel(k, valid_kernels)]
    trace_events.extend(_device_timeline_events(on_device_kernels, kernel_launch_latency))
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(prof, f, ensure_ascii=False, indent=4)
def clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of *module*."""
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self,query, key, value, mask=None, dropout=None):
d_k = query.size(-1)
scores = query@key.transpose(-2,-1) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e20)
p_attn = F.softmax(scores, dim = -1)
if dropout is not None:
p_attn = dropout(p_attn)
return p_attn@value, p_attn
class MultiHeadAttention(nn.Module):
    """Multi-head attention with h heads over a d_model-wide embedding.

    Four d_model -> d_model linear layers project query, key and value, and
    finally the concatenated head outputs. The attention itself is delegated
    to ScaledDotProductAttention.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h  # width of each head
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention probabilities from the latest forward()
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Insert a head axis (dim 1) so one mask is shared by every head.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # Project, split d_model into (h, d_k), then move heads to dim 1.
        query, key, value = [
            proj(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for proj, x in zip(self.linears, (query, key, value))
        ]
        x, self.attn = self.attention(query, key, value, mask=mask,
                                      dropout=self.dropout)
        # Merge heads back into a single d_model-wide feature axis.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
# Prefer the vendor XPU stack; fall back to CUDA when it cannot be set up.
use_cuda = True
try:
    import torch_xpu
    # Presumably redirects the torch.cuda.* calls below to the XPU device. TODO confirm.
    import torch_xpu.contrib.transfer_to_xpu
    torch.xpu.set_device(0)
    # NOTE(review): this rebinds an enum member to a string. Renaming the
    # PrivateUse1 backend is normally done with
    # torch.utils.rename_privateuse1_backend. Confirm this line has the
    # intended effect.
    torch.profiler.ProfilerActivity.PrivateUse1 = "xpu"
    use_cuda = False
except Exception:
    # Deliberate best effort: any failure while probing the XPU stack means
    # "use CUDA". This replaces a bare `except:` so Ctrl-C and SystemExit
    # still propagate.
    pass

import os
# Single-process "distributed" setup so that the rank can be used in file names.
os.environ['LOCAL_RANK'] = "0"
os.environ['RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "6006"
import torch.distributed as dist
dist.init_process_group(backend='vccl')
local_rank = int(os.environ['LOCAL_RANK'])
rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
if not dist.is_available() or not dist.is_initialized():
    print("dist init error")

cross_attn = MultiHeadAttention(h=8, d_model=64).half().cuda()
cross_attn.eval()
q1 = torch.ones((1, 50, 64), dtype=torch.float32).half().cuda()
k1 = q1.clone()
v1 = q1.clone()
# One un-profiled pass before profiling starts.
out = cross_attn.forward(q1, k1, v1).sum()
torch.cuda.synchronize()

activities = [torch.profiler.ProfilerActivity.CPU]
if use_cuda:
    activities.append(torch.profiler.ProfilerActivity.CUDA)
with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=3,
        repeat=1),
    record_shapes=True,
    with_stack=True,
    with_modules=True,
    with_flops=True,
    profile_memory=True,
) as prof:
    for _ in range(10):
        out = cross_attn.forward(q1, k1, v1).sum()
        prof.step()
    torch.cuda.synchronize()

if not use_cuda:
    # Dump the device time of every op that has one, keyed the same way
    # get_uuid keys the trace events, so that merge_prof_timeline can join them.
    kernel_event = {}
    for ev in prof.profiler.function_events:
        if ev.privateuse1_time > 0:
            uuid = f"{ev.name}_{ev.id}_{ev.thread}"
            kernel_event[uuid] = {"kernel_time": ev.privateuse1_time,
                                  "device_memory_usage": ev.privateuse1_memory_usage,
                                  "start_us": ev.time_range.start,
                                  "host_dur": ev.time_range.end - ev.time_range.start,
                                  "thread": ev.thread}
    with open(f"kernel_event_{rank}.json", 'w', encoding='utf-8') as f:
        json.dump(kernel_event, f, ensure_ascii=False, indent=4)
    prof.export_chrome_trace(f"prof_{rank}.json")
    # Rewrites prof_{rank}.json in place with the inferred device timeline.
    merge_prof_timeline(f"prof_{rank}.json", f"kernel_event_{rank}.json", f"prof_{rank}.json")
else:
    prof.export_chrome_trace(f"prof_{q1.device.type}.json")