pytorch为自己的extension backend添加profiler功能

pytorch为自己的extension backend添加profiler功能

本文演示了pytorch如何为自己的extension backend添加profiler功能
背景介绍

  • 1.没有CNLight、Profiling AscendCL API、ROC Tracer之类的Profiling功能,无法trace runtime、driver、kernel,也无法获取设备的metrics
  • 2.只有event功能,可以统计kernel耗时
  • 3.本文只是一种尝试,并不合理.
  • 4.torch原生的profiler框架,依赖kineto,kineto目前支持CUPTI和ROC Tracer,如果不修改torch源码,第三方设备不方便使用
  • 5.华为、寒武纪、habana都是采用torch.profile的接口形式及at::addThreadLocalCallback功能,但不依赖torch.profiler框架
    profiling原始数据都是私有格式,并且修改了TensorBoard的插件,用于可视化

实施步骤

  • 1.调用torch::profiler::impl::registerPrivateUse1Methods注册
  • 2.因为没有correlation ID去关联host api与kernel,因此export_chrome_trace出来的数据没有kernel信息
  • 3.获取prof.profiler.function_events里的数据,通过{ev.name}{ev.id}{ev.thread}拼成uuid与上面chrome trace中的events关联
  • 4.因为只有一个stream,可以根据Host launch时间、kernel耗时、launch latency(先验),推断出kernel的开始、结束时间,并用flow event进行关联(虽然并不准确)
  • 5.最后把kernel event以及flow event追加到chrome trace中

1.参考文档

2.your-extension-for-pytorch需要增加的代码

c 复制代码
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
 
using torch::profiler::impl::ProfilerStubs;
using torch::profiler::impl::ProfilerVoidEventStub;
  
namespace torch {
namespace profiler {
namespace impl {
 
// Profiler stub implementation for the out-of-tree device backend.
// Overrides the ProfilerStubs hooks with calls into the vendor runtime
// (xpurt*) so that device events can be recorded and timed.
//
// NOTE(review): every runtime call is wrapped in TORCH_CHECK, which fails when
// its condition evaluates to false. This assumes the xpurt* functions return a
// truthy value on success - confirm against the runtime API (CUDA-style APIs
// return 0 on success, which would make these checks fire on every call).
struct NPUMethods : public ProfilerStubs {
   // Creates a device event, stores it in *event, optionally captures a host
   // timestamp in *cpu_ns, and records the event on the current stream.
   void record(
        int* device,
        ProfilerVoidEventStub* event,
        int64_t* cpu_ns) const override
    {
      // Report the current device index to the caller when requested.
      if (device) {
          TORCH_CHECK(xpurtGetDevice((uint32_t*)device));
      }
      xpurtEvent_t xpurt_event;
      TORCH_CHECK(xpurtEventCreate(&xpurt_event));
      // The shared_ptr deleter destroys the runtime event once the profiler
      // drops its last reference to the stub.
      *event = std::shared_ptr<void>(xpurt_event, [](xpurtEvent_t ptr) {
          TORCH_CHECK(xpurtEventDestroy(ptr));
      });
      // NOTE(review): the stream is looked up via vastai::get_device(), not via
      // the `device` value written above - presumably both name the current
      // device; confirm.
      auto xpurt_stream = c10::xpu::getCurrentxpuStream(vastai::get_device());
      // Host timestamp is taken immediately before the event is enqueued.
      if (cpu_ns) {
          *cpu_ns = getTime();
      }
      TORCH_CHECK(xpurtEventRecord(xpurt_event, xpurt_stream)); 
    } 
    // Waits for both events to complete and returns the time between them.
    float elapsed(
        const ProfilerVoidEventStub* event1_,
        const ProfilerVoidEventStub* event2_) const override
    {
 
        auto event1 = static_cast<xpurtEvent_t>(event1_->get());
        TORCH_CHECK(xpurtEventSynchronize(event1));
        auto event2 = static_cast<xpurtEvent_t>(event2_->get());
        TORCH_CHECK(xpurtEventSynchronize(event2));
        // NOTE(review): the variable is named *_ms but the value is returned
        // unscaled; PyTorch's CUDA stub converts milliseconds to microseconds
        // before returning - confirm the unit xpurtEventElapsedTime reports.
        int64_t time_ms = 0;
        TORCH_CHECK(xpurtEventElapsedTime(&time_ms, event1, event2));
        return time_ms*1.0;
    } 
    // Invokes `op` once, with the index of the current device only.
    void onEachDevice(std::function<void(int)> op) const override
    {
        uint32_t device = 0;
        TORCH_CHECK(xpurtGetDevice(&device));
        op(device);
    } 
    // The remaining hooks are intentionally no-ops; enabled() always reports
    // that the stubs are available.
    void synchronize() const override { } 
    bool enabled() const override {return true;} 
    void mark(const char*name) const override { } 
    void rangePush(const char*name) const override { } 
    void rangePop() const override {}
};
 
// Registration helper: constructing an instance installs a single, function-
// local static NPUMethods object as the profiler's PrivateUse1 stub set.
struct RegisterNPUMethods {
    RegisterNPUMethods()
    {
        // Function-local static: the stubs object outlives this constructor,
        // so the pointer handed to the profiler stays valid.
        static NPUMethods methods;
        torch::profiler::impl::registerPrivateUse1Methods(&methods);
    }
};
// Namespace-scope instance whose constructor performs the registration during
// static initialisation of this translation unit.
RegisterNPUMethods reg;
}}}

3.pytorch demo及如何调整chrome trace json文件

python 复制代码
import time
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
import copy
import math
import torch
from torch.profiler import profile
import json
import tqdm

def is_valid_kernel(name,duration,valid_kernel_threshold=100):
    """Decide from an operator's name and device time whether it is a device kernel.

    An operator is rejected when its name contains one of the known
    metadata-only / bookkeeping operator names, or when its measured duration
    is below ``valid_kernel_threshold``.

    The name match is token-aware: a pattern only counts when it is not
    directly followed by another letter or digit.  A plain substring test made
    ``"aten::t"`` also reject unrelated compute operators such as
    ``aten::tanh``, ``aten::topk`` or ``aten::to``.  Suffixed variants such as
    ``aten::empty_like`` or ``aten::view_as`` are still rejected, as before.

    Args:
        name: Operator name taken from the trace event.
        duration: Measured device time of the operator.
        valid_kernel_threshold: Minimum duration for an operator to count as a
            kernel (same unit as ``duration``).

    Returns:
        ``True`` if the operator should be treated as a device kernel.
    """
    invalid_kernels = (
        "aten::view", "aten::reshape",
        "aten::t", "aten::empty",
        "aten::transpose",
        "aten::as_strided",
        "aten::item",
        "aten::_local_scalar_dense",
        "aten::result_type",
        "aten::_unsafe_view",
        "aten::expand",
    )
    for pattern in invalid_kernels:
        start = name.find(pattern)
        while start >= 0:
            end = start + len(pattern)
            # A hit only counts when the pattern ends on a token boundary
            # (end of string, '_', ' ', ...), not in the middle of a longer
            # identifier such as "aten::tanh".
            if end == len(name) or not name[end].isalnum():
                return False
            start = name.find(pattern, start + 1)
    if duration < valid_kernel_threshold:
        return False
    return True

def filter_ev(ev):
    """Return True when a trace event carries an ``"External id"`` argument.

    Only events that have an ``args`` mapping containing the key
    ``"External id"`` are considered by the rest of the merge pipeline.
    """
    return 'args' in ev and "External id" in ev['args']

def get_uuid(ev,tid_map):
    """Build the ``<name>_<External id>_<thread index>`` key for a trace event.

    ``tid_map`` translates the event's raw ``tid`` into the index that forms
    the last component of the key.
    """
    op_name = ev['name']
    external_id = ev['args']['External id']
    thread_index = tid_map[ev['tid']]
    return "{}_{}_{}".format(op_name, external_id, thread_index)

def get_valid_kernels(traceEvents,kernel_event,tid_map):
    """Collect the trace events that correspond to measured device kernels.

    Each event that passes ``filter_ev`` is looked up in ``kernel_event`` by
    its uuid; events without a measurement are ignored.  The most recent
    positive ``device_memory_usage`` seen so far is carried forward and
    attached to every kernel that is kept.

    Returns:
        A list of kernel descriptors (dicts) ordered by host launch start.
    """
    collected = []
    last_memory_usage = 0
    for event in traceEvents:
        if not filter_ev(event):
            continue
        key = get_uuid(event, tid_map)
        if key not in kernel_event:
            continue
        measurement = kernel_event[key]
        kernel_time = measurement['kernel_time']
        op_name = event['name']
        # Remember the latest non-zero memory reading, even when the operator
        # itself is later rejected as a kernel.
        if measurement['device_memory_usage'] > 0:
            last_memory_usage = measurement['device_memory_usage']
        if not is_valid_kernel(op_name, kernel_time):
            continue
        begin = event['ts']
        finish = event['ts'] + event['dur']
        collected.append({
            "name": op_name,
            "launch_beg": begin,
            "launch_end": finish,
            "kernel_duration": kernel_time,
            "host_pid": event['pid'],
            "host_tid": event['tid'],
            "device_memory_usage": last_memory_usage,
            "is_leaf_kernel": False,
        })
    collected.sort(key=lambda item: item['launch_beg'])
    return collected
    
def is_leaf_kernel(kernel,valid_kernels):
    """Tell whether ``kernel`` is a leaf, i.e. has no other kernel nested inside it.

    A kernel is not a leaf when some entry of ``valid_kernels`` (ignoring
    entries already flagged ``is_leaf_kernel``) starts strictly after it and
    ends strictly before it on the host timeline.
    """
    begin = kernel['launch_beg']
    finish = kernel['launch_end']
    has_nested = any(
        other['launch_beg'] > begin and other['launch_end'] < finish
        for other in valid_kernels
        if not other['is_leaf_kernel']
    )
    return not has_nested

def create_tid_map(traceEvents):
    """Map every thread id found in the filtered trace events to a 1-based index.

    Thread ids are taken from events accepted by ``filter_ev`` and numbered in
    ascending order, starting at 1.
    """
    seen_tids = {event['tid'] for event in traceEvents if filter_ev(event)}
    return {tid: index for index, tid in enumerate(sorted(seen_tids), start=1)}
                                      
def merge_prof_timeline(prof_json,kernel_event_json,output_json):
    """Append estimated device-kernel events to an exported chrome trace.

    Reads the chrome trace (``prof_json``) and the per-operator device timings
    (``kernel_event_json``), keeps the leaf kernels, lays them out back to back
    on a single synthetic device track, and writes the augmented trace to
    ``output_json``.

    For every kernel four events are appended: a complete event on the device
    track (pid/tid 10), a memory counter sample (pid/tid 11) and a flow
    start/finish pair linking the host launch to the device kernel.  Kernel
    start times are inferred, not measured: a kernel starts at its host launch
    time when the device track is idle, otherwise right after the previous one.
    """
    launch_latency = 0

    with open(prof_json, 'r', encoding='utf-8') as source:
        trace = json.load(source)
    with open(kernel_event_json, 'r', encoding='utf-8') as source:
        measurements = json.load(source)

    events = trace['traceEvents']
    tid_lookup = create_tid_map(events)
    print(tid_lookup)

    # All operators that have a device measurement and look like kernels.
    candidates = get_valid_kernels(events, measurements, tid_lookup)
    print(len(candidates))

    # Keep only those that actually execute on the device (leaf kernels).
    device_kernels = [
        candidate
        for candidate in tqdm.tqdm(candidates)
        if is_leaf_kernel(candidate, candidates)
    ]

    device_cursor = 0
    for flow_id, kernel in enumerate(device_kernels):
        launch_ts = kernel["launch_beg"]

        # First kernel, or the device queue was idle at launch time: the
        # kernel starts when it is launched.
        if device_cursor == 0 or launch_ts > device_cursor:
            device_cursor = launch_ts

        # Device kernel event.
        events.append({"ph": "X", "cat": "device_kernel", "name": kernel['name'],
                       "pid": 10, "tid": 10,
                       "ts": device_cursor, "dur": kernel["kernel_duration"]})

        # Memory counter sample.
        events.append({"ph": "C", "cat": "memory", "name": "memory",
                       "pid": 11, "tid": 11,
                       "ts": launch_ts,
                       "args": {"value": kernel['device_memory_usage']}})

        # Flow arrow from the host launch to the device kernel.
        events.append({"ph": "s", "id": flow_id,
                       "pid": kernel['host_pid'], "tid": kernel['host_tid'],
                       "ts": launch_ts, "cat": "ac2g", "name": "ac2g"})
        events.append({"ph": "f", "id": flow_id, "pid": 10, "tid": 10,
                       "ts": device_cursor, "cat": "ac2g", "name": "ac2g",
                       "bp": "e"})

        device_cursor += (kernel["kernel_duration"] + launch_latency)

    # Persist the augmented trace.
    with open(output_json, 'w', encoding='utf-8') as target:
        json.dump(trace, target, ensure_ascii=False, indent=4)
		
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas
 
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``."""

    def __init__(self):
        super().__init__()

    def forward(self,query, key, value, mask=None, dropout=None):
        """Apply attention and return ``(output, attention_weights)``.

        Args:
            query, key, value: Tensors whose last dimension is the feature
                dimension; ``key`` is transposed over its last two dimensions.
            mask: Optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` receive (effectively) zero attention.
            dropout: Optional callable applied to the attention weights.

        Returns:
            A tuple of the attended values and the attention weights.
        """
        d_k = query.size(-1)
        scores = query@key.transpose(-2,-1) / math.sqrt(d_k)
        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A fixed -1e20 cannot be represented in float16 (which this demo
            # runs in) and is rejected by masked_fill as an overflowing scalar.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)
        p_attn = F.softmax(scores, dim = -1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn@value, p_attn
 
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from four equally sized linear projections.

    The first three projections produce the per-head queries, keys and values;
    the last one mixes the concatenated head outputs.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.h = h
        self.d_k = d_model // h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention weights of the most recent forward pass.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()

    def forward(self, query, key, value, mask=None):
        """Run attention over ``h`` heads and return the projected result."""
        batch = query.size(0)
        # The same mask is shared by all heads.
        head_mask = None if mask is None else mask.unsqueeze(1)

        def split_heads(projection, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            projected = projection(tensor)
            return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q_heads = split_heads(self.linears[0], query)
        k_heads = split_heads(self.linears[1], key)
        v_heads = split_heads(self.linears[2], value)

        context, self.attn = self.attention(
            q_heads, k_heads, v_heads, mask=head_mask, dropout=self.dropout
        )
        # Merge the heads back into a single feature dimension.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.h * self.d_k)
        return self.linears[-1](merged)
 
# Prefer the out-of-tree XPU backend when its extension can be imported and
# initialised; otherwise keep the CUDA path.
use_cuda=True
try:
    import torch_xpu
    import torch_xpu.contrib.transfer_to_xpu
    torch.xpu.set_device(0)
    torch.profiler.ProfilerActivity.PrivateUse1="xpu"
    use_cuda=False
except Exception as exc:
    # Best-effort fallback: stay on CUDA, but report why the XPU path was
    # skipped.  `Exception` (instead of a bare `except`) lets
    # KeyboardInterrupt / SystemExit propagate.
    print(f"XPU backend unavailable, falling back to CUDA: {exc!r}")
 
import os
# Single-process rendezvous settings, placed in the environment before the
# process group is created.
os.environ['LOCAL_RANK']="0"
os.environ['RANK']="0"
os.environ['WORLD_SIZE']="1"
os.environ['MASTER_ADDR']="localhost"
os.environ['MASTER_PORT']="6006"

import torch.distributed as dist
# NOTE(review): 'vccl' is not a built-in torch.distributed backend; presumably
# it is registered by the torch_xpu import above - confirm. This call is made
# on the CUDA path as well.
dist.init_process_group(backend='vccl')
local_rank=int(os.environ['LOCAL_RANK'])
rank=torch.distributed.get_rank()
# NOTE(review): presumably redirected to the XPU device by
# torch_xpu.contrib.transfer_to_xpu when that import succeeded - confirm.
torch.cuda.set_device(local_rank)
# This check runs after init_process_group and only prints a message.
if not dist.is_available() or not dist.is_initialized():
    print("dist init error")
 
# Build the half-precision model and inputs, then run one untimed forward pass
# before profiling starts.
cross_attn = MultiHeadAttention(h=8, d_model=64).half().cuda()
cross_attn.eval()
q1 = torch.ones((1, 50, 64),dtype=torch.float32).half().cuda()
k1 = q1.clone()
v1 = q1.clone()
out = cross_attn.forward(q1,k1,v1).sum()
torch.cuda.synchronize()
 
# CPU activity is always traced; CUDA activity only on the CUDA path.
activities=[torch.profiler.ProfilerActivity.CPU]
if use_cuda:
    activities.append(torch.profiler.ProfilerActivity.CUDA)
 
# Schedule: skip 1 step, warm up for 1 step, record 3 steps, single cycle.
with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=3,
                repeat=1),
    record_shapes=True,
    with_stack=True,
    with_modules=True,
    with_flops=True,
    profile_memory=True,
   ) as prof:
        for i in range(10):
            out = cross_attn.forward(q1,k1,v1).sum()
            # Advance the profiler schedule by one step.
            prof.step()
        torch.cuda.synchronize()
 
if not use_cuda:
    # Extension-backend path: collect the per-operator device time and memory
    # from the profiler's function events, keyed the same way get_uuid() keys
    # the chrome-trace events (<name>_<id>_<thread>).
    kernel_event={}
    for ev in prof.profiler.function_events:
        # Only operators with a non-zero device time are kept.
        if ev.privateuse1_time>0:
            uuid=f"{ev.name}_{ev.id}_{ev.thread}"
            #print(uuid,ev.id,ev.name,ev.privateuse1_time,ev.time_range.start,ev.time_range.end-ev.time_range.start,ev.privateuse1_memory_usage)
            kernel_event[uuid]={"kernel_time":ev.privateuse1_time,
								"device_memory_usage":ev.privateuse1_memory_usage,
								"start_us":ev.time_range.start,
								"host_dur":ev.time_range.end-ev.time_range.start,
								"thread":ev.thread} 
    import json
    with open(f"kernel_event_{rank}.json",'w',encoding='utf-8') as f:
        json.dump(kernel_event, f,ensure_ascii=False,indent=4)

    # Export the host-side trace, then rewrite the same file in place with the
    # inferred device kernel, memory and flow events added.
    prof.export_chrome_trace(f"prof_{rank}.json")
    merge_prof_timeline(f"prof_{rank}.json",f"kernel_event_{rank}.json",f"prof_{rank}.json")
else:
    # CUDA path: the exported trace is written as-is, without the merge step.
    #print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    prof.export_chrome_trace(f"prof_{q1.device.type}.json")

4.可视化

相关推荐
小二·12 分钟前
java基础面试题笔记(基础篇)
java·笔记·python
小喵要摸鱼2 小时前
Python 神经网络项目常用语法
python
一念之坤3 小时前
零基础学Python之数据结构 -- 01篇
数据结构·python
wxl7812273 小时前
如何使用本地大模型做数据分析
python·数据挖掘·数据分析·代码解释器
NoneCoder3 小时前
Python入门(12)--数据处理
开发语言·python
ZHOU_WUYI4 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1234 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界4 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221514 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2514 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台