Adding profiler support to a custom PyTorch extension backend

This article demonstrates how to add profiler support to a custom PyTorch extension backend.

Background

  • 1. There is no profiling facility like CNLight, the Profiling AscendCL API, or ROC Trace, so the runtime, driver, and kernels cannot be traced, and device metrics cannot be collected.
  • 2. Only an event facility is available, which can measure kernel elapsed time (see the sketch after this list).
  • 3. This article is only an experiment, not a proper solution.
  • 4. torch's native profiler framework depends on Kineto, which currently supports CUPTI and ROC Tracer; without modifying the torch source, third-party devices cannot easily use it.
  • 5. Huawei, Cambricon, and Habana all adopt the torch.profiler interface style together with at::addThreadLocalCallback, but do not depend on the torch.profiler framework. Their raw profiling data is in private formats, and they modify the TensorBoard plugin for visualization.
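
As a reference point for item 2, event-based kernel timing works like the sketch below; it uses torch.cuda.Event purely for illustration, and a PrivateUse1 backend would expose an equivalent record/synchronize/elapsed interface:

python
import torch

# Minimal sketch of event-based kernel timing. A custom backend's event API
# plays the same role that torch.cuda.Event plays here.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")
start.record()                 # mark a point on the stream before the kernel
y = x @ x                      # the kernel to be timed
end.record()                   # mark a point on the stream after the kernel
torch.cuda.synchronize()       # wait until both events have completed
print(f"kernel time: {start.elapsed_time(end):.3f} ms")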

Implementation steps

  • 1. Call torch::profiler::impl::registerPrivateUse1Methods to register the backend's profiler stubs.
  • 2. Because there is no correlation ID to associate host APIs with kernels, the data produced by export_chrome_trace contains no kernel information.
  • 3. Take the data from prof.profiler.function_events and compose a uuid as {ev.name}_{ev.id}_{ev.thread} to match it against the events in the chrome trace above.
  • 4. Because there is only a single stream, each kernel's start and end time can be inferred from the host launch time, the kernel duration, and an assumed launch latency, and linked back to the host with flow events (even though this is not exact); see the sketch after this list.
  • 5. Finally, append the kernel events and flow events to the chrome trace.
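
A minimal sketch of the uuid matching in step 3 and the start-time inference in step 4 (the function and variable names here are illustrative, not a torch API):

python
# Step 3: the uuid that ties a profiler function_event to a chrome trace event.
def make_uuid(name, external_id, tid):
    return f"{name}_{external_id}_{tid}"

# Step 4: with a single stream, a kernel starts either when it is launched
# (queue idle) or when the previous kernel finishes (queue busy).
def infer_kernel_span(launch_ts_us, queue_free_at_us, kernel_dur_us,
                      launch_latency_us=0):
    start = max(launch_ts_us + launch_latency_us, queue_free_at_us)
    return start, start + kernel_dur_us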

1. Reference documents

2. Code to add to your-extension-for-pytorch

cpp
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <c10/util/irange.h>
 
using torch::profiler::impl::ProfilerStubs;
using torch::profiler::impl::ProfilerVoidEventStub;
  
namespace torch {
namespace profiler {
namespace impl {
 
// Profiler stubs for the custom backend. The profiler framework calls these
// hooks to create/record device events and to measure elapsed time.
struct NPUMethods : public ProfilerStubs {
    void record(
        int* device,
        ProfilerVoidEventStub* event,
        int64_t* cpu_ns) const override
    {
        if (device) {
            TORCH_CHECK(xpurtGetDevice((uint32_t*)device));
        }
        // Create a device event and wrap it in a shared_ptr whose deleter
        // destroys the event.
        xpurtEvent_t xpurt_event;
        TORCH_CHECK(xpurtEventCreate(&xpurt_event));
        *event = std::shared_ptr<void>(xpurt_event, [](xpurtEvent_t ptr) {
            TORCH_CHECK(xpurtEventDestroy(ptr));
        });
        auto xpurt_stream = c10::xpu::getCurrentxpuStream(vastai::get_device());
        // Capture the host timestamp as close to the event record as possible.
        if (cpu_ns) {
            *cpu_ns = getTime();
        }
        TORCH_CHECK(xpurtEventRecord(xpurt_event, xpurt_stream));
    }
    float elapsed(
        const ProfilerVoidEventStub* event1_,
        const ProfilerVoidEventStub* event2_) const override
    {
        // Both events must have completed before their timestamps can be read.
        auto event1 = static_cast<xpurtEvent_t>(event1_->get());
        TORCH_CHECK(xpurtEventSynchronize(event1));
        auto event2 = static_cast<xpurtEvent_t>(event2_->get());
        TORCH_CHECK(xpurtEventSynchronize(event2));
        int64_t time_ms = 0;
        TORCH_CHECK(xpurtEventElapsedTime(&time_ms, event1, event2));
        return static_cast<float>(time_ms);
    }
    void onEachDevice(std::function<void(int)> op) const override
    {
        // Only the current device is visited here.
        uint32_t device = 0;
        TORCH_CHECK(xpurtGetDevice(&device));
        op(device);
    }
    void synchronize() const override {}
    bool enabled() const override { return true; }
    // Range/mark annotations (NVTX-style) are not supported by this backend.
    void mark(const char* name) const override {}
    void rangePush(const char* name) const override {}
    void rangePop() const override {}
};
 
// Self-registration: constructed at library load time.
struct RegisterNPUMethods {
    RegisterNPUMethods()
    {
        static NPUMethods methods;
        torch::profiler::impl::registerPrivateUse1Methods(&methods);
    }
};
RegisterNPUMethods reg;
} // namespace impl
} // namespace profiler
} // namespace torch
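
This mirrors the self-registration pattern torch uses for its own CUDA profiler stubs: a static object whose constructor runs when the shared library is loaded, so simply importing the extension module makes the stubs available to the profiler, with no explicit initialization call needed from Python.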

3. PyTorch demo and how to adjust the chrome trace JSON file

python
import time
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
import copy
import math
import torch
from torch.profiler import profile
import json
import tqdm

def is_valid_kernel(name, duration, valid_kernel_threshold=100):
    '''Decide from an op's name and duration whether it is a real device kernel'''
    invalid_kernels=["aten::view","aten::reshape",
                    "aten::t","aten::empty",
                    "aten::transpose",
                    "aten::as_strided",
                    "aten::item",
                    "aten::_local_scalar_dense",
                    "aten::result_type",
                    "aten::_unsafe_view",
                    "aten::expand"]
    for k in invalid_kernels:
        if name.find(k)>=0:
            return False
    if duration<valid_kernel_threshold:
        return False    
    return True

def filter_ev(ev):
    '''Keep only trace events that carry an "External id"'''
    if 'args' in ev and "External id" in ev['args']:
        return True
    return False

def get_uuid(ev,tid_map):
    return f"{ev['name']}_{ev['args']['External id']}_{tid_map[ev['tid']]}"

def get_valid_kernels(traceEvents,kernel_event,tid_map):
    valid_kernels=[]
    device_memory_usage=0
    for ev in traceEvents:
        if filter_ev(ev):
            uuid=get_uuid(ev,tid_map)
            if uuid not in kernel_event:
                continue
            duration=kernel_event[uuid]['kernel_time']
            kernel_name=ev['name']
            if kernel_event[uuid]['device_memory_usage']>0:
                device_memory_usage=kernel_event[uuid]['device_memory_usage']
            if is_valid_kernel(kernel_name,duration):
                launch_beg=ev['ts']
                launch_end=ev['ts']+ev['dur']            
                valid_kernels.append({"name":kernel_name,
                                      "launch_beg":launch_beg,
                                      "launch_end":launch_end,
                                      "kernel_duration":duration,
                                      "host_pid":ev['pid'],
                                      "host_tid":ev['tid'],
                                      "device_memory_usage":device_memory_usage,
                                      "is_leaf_kernel":False})
                                      
    return sorted(valid_kernels,key=lambda x:x['launch_beg'])
    
def is_leaf_kernel(kernel,valid_kernels):
    '''Decide whether this is a leaf kernel (no other kernel nested inside its span)'''
    ret=True
    for k in valid_kernels:
        if k['is_leaf_kernel']:
            continue
        # another kernel falls entirely within this kernel's span
        if k['launch_beg']>kernel['launch_beg'] and k['launch_end']<kernel['launch_end']:
            ret=False
            break
    return ret

def create_tid_map(traceEvents):
    tids=set()
    for ev in traceEvents:
        if filter_ev(ev):
            tid=ev['tid']
            tids.add(tid)
    tid_map={}
    tids=sorted(tids,reverse=False)
    for i,v in enumerate(tids):
        tid_map[v]=i+1
    return tid_map
                                      
def merge_prof_timeline(prof_json,kernel_event_json,output_json):
    
    kernel_launch_latency=0  # assumed kernel launch latency (a prior), in us
    with open(prof_json,'r',encoding='utf-8') as f:
        prof = json.load(f)

    with open(kernel_event_json,'r',encoding='utf-8') as f:
        kernel_event = json.load(f)   
    
    traceEvents=prof['traceEvents']
    tid_map=create_tid_map(traceEvents)
    print(tid_map)
    # collect all candidate kernels
    valid_kernels=get_valid_kernels(traceEvents,kernel_event,tid_map)
    print(len(valid_kernels))
    # keep only the kernels that will actually execute on the device (leaf kernels)
    on_device_kernels=[]
    for kernel in tqdm.tqdm(valid_kernels):
        if is_leaf_kernel(kernel,valid_kernels):
            on_device_kernels.append(kernel)
    
    kernel_start_offset=0
    kernel_index=0

    for kernel in on_device_kernels:
        name=kernel['name']
        kernel_duration=kernel["kernel_duration"]
        launch_time=kernel["launch_beg"]
        host_pid=kernel['host_pid']
        host_tid=kernel['host_tid']
        device_memory_usage=kernel['device_memory_usage']
        
        # initialize the timeline at the first launch
        if kernel_start_offset==0:
            kernel_start_offset=launch_time
            
        if launch_time>kernel_start_offset: # the kernel queue is idle
            kernel_start_offset=launch_time
        
        # append the inferred kernel event
        traceEvents.append({"ph": "X", "cat": "device_kernel", "name":name, "pid": 10, "tid": 10,"ts": kernel_start_offset, "dur": kernel_duration})
        
        # append the device memory counter event
        traceEvents.append({"ph": "C", "cat": "memory", "name":"memory", "pid": 11, "tid": 11,"ts": launch_time, "args": {"value":device_memory_usage}})
        
        # append flow events linking the host launch to the inferred kernel
        traceEvents.append({"ph": "s", "id": kernel_index, "pid": host_pid, "tid": host_tid, "ts": launch_time,"cat": "ac2g", "name": "ac2g"})
        traceEvents.append({"ph": "f", "id": kernel_index, "pid": 10,  "tid": 10,"ts": kernel_start_offset,"cat": "ac2g", "name": "ac2g", "bp": "e"})
        
        kernel_index+=1
        kernel_start_offset+=(kernel_duration+kernel_launch_latency)
    
    # save the merged result
    with open(output_json,'w',encoding='utf-8') as f:
        json.dump(prof, f,ensure_ascii=False,indent=4)
		
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
 
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
 
    def forward(self,query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = query@key.transpose(-2,-1) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e20)
        p_attn = F.softmax(scores, dim = -1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn@value, p_attn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()
 
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query=self.linears[0](query).view(nbatches, -1, self.h, self.d_k)
        query=query.transpose(1, 2)
        key=self.linears[1](key).view(nbatches, -1, self.h, self.d_k)
        key=key.transpose(1, 2)
        value=self.linears[2](value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        x, self.attn = self.attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
 
use_cuda=True
try:
    # If the custom backend is available, route cuda.* calls to it and
    # report PrivateUse1 activity under the name "xpu".
    import torch_xpu
    import torch_xpu.contrib.transfer_to_xpu
    torch.xpu.set_device(0)
    torch.profiler.ProfilerActivity.PrivateUse1="xpu"
    use_cuda=False
except Exception:
    # fall back to plain CUDA when the extension is not available
    pass
 
import os
os.environ['LOCAL_RANK']="0"
os.environ['RANK']="0"
os.environ['WORLD_SIZE']="1"
os.environ['MASTER_ADDR']="localhost"
os.environ['MASTER_PORT']="6006"

import torch.distributed as dist
dist.init_process_group(backend='vccl')
local_rank=int(os.environ['LOCAL_RANK'])
rank=torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
if not dist.is_available() or not dist.is_initialized():
    print("dist init error")
 
cross_attn = MultiHeadAttention(h=8, d_model=64).half().cuda()
cross_attn.eval()
q1 = torch.ones((1, 50, 64),dtype=torch.float32).half().cuda()
k1 = q1.clone()
v1 = q1.clone()
out = cross_attn.forward(q1,k1,v1).sum()
torch.cuda.synchronize()
 
activities=[torch.profiler.ProfilerActivity.CPU]
if use_cuda:
    activities.append(torch.profiler.ProfilerActivity.CUDA)
 
with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=3,
                repeat=1),
    record_shapes=True,
    with_stack=True,
    with_modules=True,
    with_flops=True,
    profile_memory=True,
   ) as prof:
        for i in range(10):
            out = cross_attn.forward(q1,k1,v1).sum()
            prof.step()
        torch.cuda.synchronize()
 
if not use_cuda:
    kernel_event={}
    for ev in prof.profiler.function_events:
        if ev.privateuse1_time>0:
            uuid=f"{ev.name}_{ev.id}_{ev.thread}"
            #print(uuid,ev.id,ev.name,ev.privateuse1_time,ev.time_range.start,ev.time_range.end-ev.time_range.start,ev.privateuse1_memory_usage)
            kernel_event[uuid]={"kernel_time":ev.privateuse1_time,
                                "device_memory_usage":ev.privateuse1_memory_usage,
                                "start_us":ev.time_range.start,
                                "host_dur":ev.time_range.end-ev.time_range.start,
                                "thread":ev.thread}
    import json
    with open(f"kernel_event_{rank}.json",'w',encoding='utf-8') as f:
        json.dump(kernel_event, f,ensure_ascii=False,indent=4)

    prof.export_chrome_trace(f"prof_{rank}.json")
    merge_prof_timeline(f"prof_{rank}.json",f"kernel_event_{rank}.json",f"prof_{rank}.json")
else:
    #print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    prof.export_chrome_trace(f"prof_{q1.device.type}.json")

4. Visualization
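
The merged prof_<rank>.json is a standard Chrome trace file and can be opened in chrome://tracing or at https://ui.perfetto.dev. The device_kernel track (pid 10) shows the inferred kernel timeline, the memory counter (pid 11) shows device memory usage, and the ac2g flow arrows link each host-side launch to its inferred kernel; keep in mind that the kernel start times are estimates, as noted above.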
