拦截pytorch算子,dump输入输出

拦截pytorch算子,dump输入输出

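希望 dump 出 PyTorch 每个算子的输入和输出,但 PyTorch 自带的 hook 机制只能拦截 `nn.Module`。下面提供一种可以拦截 `torch.add`、`torch.Tensor.add` 这类算子的方法:前向部分用 jinja2 模板为每个算子生成一个包装函数,再通过 `setattr` 替换(劫持)`torch` 和 `torch.Tensor` 中的原始算子;反向部分从 `loss.grad_fn` 出发遍历 `next_functions`,对每个节点调用 `register_hook` 来拦截 backward。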

一.代码

Python 代码:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading
import base64
from jinja2 import Template

# Device string passed to every `.to(device)` call in this script; save_tensor
# also uses it as the first component of each dump file name.
device="cuda"

class Attention(nn.Module):
    """Causal attention over already-projected query/key/value tensors.

    Two code paths are available:

    * ``flash=True``  -- a single call to
      ``torch.nn.functional.scaled_dot_product_attention`` with
      ``is_causal=True``.
    * ``flash=False`` -- the explicit matmul -> scale -> additive mask ->
      softmax -> matmul sequence, built from individual torch operators.

    Args:
        max_seq_len: side length of the square causal mask that is pre-built
            for the slow path.
        head_dim: scores are divided by ``sqrt(head_dim)`` on the slow path.
        flash: selects the fused path when truthy.
    """
    def __init__(self,max_seq_len,head_dim,flash):
        super().__init__()
        # NOTE: could instead be auto-detected with
        # hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.flash = flash
        self.dropout=0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim=head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # (1, 1, L, L) additive mask: -inf strictly above the diagonal,
            # 0 on and below it, stored in half precision on `device`.
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)
            mask = torch.triu(mask, diagonal=1).half().to(device)
            self.register_buffer("mask", mask)
    def forward(
            self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):
        """Return the attention output for ``xq``/``xk``/``xv``.

        The slow path transposes dims 2 and 3 of ``xk``, so the inputs need
        at least four dimensions.
        NOTE(review): presumably (batch, heads, seq, head_dim), with ``xq``
        and ``xk`` sharing the same sequence length -- confirm with callers.
        """
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,
                                                                       attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            # Take the sequence length from the input itself rather than from
            # a module-level global, so the module works for any caller.
            seqlen = xq.shape[-2]
            _xk=xk.clone()
            t=_xk.transpose(2, 3)
            scores = torch.matmul(xq,t)
            scores = scores/math.sqrt(self.head_dim)
            # Crop the pre-built mask to the current sequence length.
            a=self.mask[:, :, :seqlen, :seqlen]
            scores = torch.add(scores,a)
            # Softmax runs in float32, then is cast back to the dtype of xq.
            scores = F.softmax(scores.float(), dim=-1)
            scores = scores.type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)
        return output

# Guards the dump counter and the file write so concurrent callers cannot
# produce two files with the same sequence number.
lock=threading.Lock()
# Sequence number embedded in every dump file name; incremented after each
# successful save.
gindex=0
def save_tensor(name,args,index=0):
    """Dump every tensor found in ``args`` to its own ``.pt`` file.

    A tensor is written with ``torch.save`` to
    ``<device>_<gindex>_<name>_<index>.pt`` and its shape is printed.
    Tuples and lists are walked recursively; element ``i`` of a container is
    visited with ``index + i``. Any other value is ignored.

    Args:
        name: label placed in the file name and in the printed line.
        args: a tensor, a (possibly nested) tuple/list, or anything else.
        index: position label for ``args`` within its parent container.

    Returns:
        None.
    """
    global gindex
    if isinstance(args,torch.Tensor):
        print(name,index,args.shape)
        # `with` releases the lock even if torch.save raises; a manual
        # acquire()/release() pair would leave it held forever in that case
        # and block every later dump.
        with lock:
            torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))
            gindex+=1
    elif isinstance(args,(tuple,list)):
        # Lists are walked as well as tuples so that tensors passed inside a
        # list argument (e.g. torch.cat([a, b])) are not silently skipped.
        for idx,x in enumerate(args):
            save_tensor(name,x,index+idx)

# Source text for one wrapper around a ``torch.Tensor`` method. Rendering it
# with the method name yields code that stashes the original method in a
# module-level variable, defines a wrapper that dumps the positional inputs
# and the result through save_tensor, and installs the wrapper on
# torch.Tensor under the original name.
op_template=Template('''      
native1_{{new_name}}=getattr(torch.Tensor,'{{name}}')
def {{new_name}}(*args, **kwargs):
    save_tensor("{{name}}-input",args)    
    global native1_{{new_name}}             
    ret=native1_{{new_name}}(*args, **kwargs)
    save_tensor("{{name}}-output",ret)   
    return ret
setattr(torch.Tensor, '{{name}}', {{new_name}})
''')

# Attribute names that are never wrapped.
# NOTE(review): presumably these are needed unhooked by torch.save or by the
# setup code of this script -- confirm before trimming the list.
tensor_ops_left_alone=["__iter__","shape","dim","unbind","normal_","data",
                       "item","numel","save","has_names","data_ptr","untyped_storage",
                       "storage_offset","size","stride","triu","half","is_floating_point",
                       "to","ones","randint","ones_like"]

for op in dir(torch.Tensor):
    # Only attributes whose type is named "method_descriptor" are wrapped;
    # everything else, and everything on the skip list, is left untouched.
    if op in tensor_ops_left_alone or getattr(torch.Tensor,op).__class__.__name__!="method_descriptor":
        continue
    # The wrapper's identifier is the base64 of the qualified name with the
    # "=" padding removed.
    qualified_name=f"torch.Tensor.{op}"
    encoded_name=base64.b64encode(qualified_name.encode('utf-8')).decode("utf-8")
    new_name=encoded_name.replace("=","")
    rendered_source=op_template.render(name=op,new_name=new_name)
    exec(rendered_source)

# Source text for one wrapper around a function in the ``torch`` namespace.
# Rendering it with the function name yields code that stashes the original
# function in a module-level variable, defines a wrapper that dumps the
# positional inputs and the result through save_tensor, and installs the
# wrapper on the torch module under the original name.
op_template=Template('''      
native2_{{new_name}}=getattr(torch,'{{name}}')
def {{new_name}}(*args, **kwargs):
    save_tensor("{{name}}-input",args)    
    global native2_{{new_name}}             
    ret=native2_{{new_name}}(*args, **kwargs)
    save_tensor("{{name}}-output",ret) 
    return ret
setattr(torch, '{{name}}', {{new_name}})
''')

# Names that are never wrapped.
# NOTE(review): presumably these are needed unhooked by torch.save or by the
# setup code of this script -- confirm before trimming the list.
torch_ops_left_alone=["is_grad_enabled","__iter__","save","has_names","data_ptr",
                      "untyped_storage","storage_offset","size","stride","triu",
                      "is_floating_point","to","ones","randint","full","reshape","ones_like"]

for op in dir(torch):
    # Only attributes whose type is named "builtin_function_or_method" are
    # wrapped; everything else, and everything on the skip list, is left
    # untouched.
    if op in torch_ops_left_alone or getattr(torch,op).__class__.__name__!="builtin_function_or_method":
        continue
    # The wrapper's identifier is the base64 of the qualified name with the
    # "=" padding removed.
    qualified_name=f"torch.{op}"
    encoded_name=base64.b64encode(qualified_name.encode('utf-8')).decode("utf-8")
    new_name=encoded_name.replace("=","")
    rendered_source=op_template.render(name=op,new_name=new_name)
    exec(rendered_source)

def _make_grad_dump_hook(node_name):
    """Return a node hook that dumps its positional arguments under ``node_name``."""
    def posthook(*args,**kwargs):
        save_tensor(node_name,args)
    return posthook

def _noop_prehook(*args,**kwargs):
    """Pre-hook that does nothing; kept as an extension point."""
    pass

def hook_backwards(loss, cached):
    """Register dump hooks on every autograd node reachable from ``loss``.

    Starting at ``loss`` (a ``grad_fn`` node, or ``None``), the graph is
    walked through ``next_functions`` in depth-first pre-order. On every node
    not yet in ``cached`` a no-op pre-hook and a hook that forwards its
    positional arguments to ``save_tensor`` (labelled with the node's class
    name) are registered, and the node is added to ``cached``.

    The walk uses an explicit stack instead of recursion, so graphs deeper
    than Python's recursion limit are handled. A start node that is already
    in ``cached`` is skipped, which avoids registering duplicate hooks.

    Args:
        loss: root autograd node, typically ``some_tensor.grad_fn``; ``None``
            is accepted and makes the call a no-op.
        cached: mutable set of already-hooked nodes; updated in place.

    Returns:
        None.
    """
    pending=[loss]
    while pending:
        node=pending.pop()
        # Entries of next_functions may hold None in place of a node.
        if node is None or node in cached:
            continue
        node.register_prehook(_noop_prehook)
        # The class name is bound now so the hook keeps no reference back to
        # the node (and no loop variable is captured late).
        node.register_hook(_make_grad_dump_hook(node.__class__.__name__))
        cached.add(node)
        # Push children reversed so they are popped in their original order,
        # giving the same visiting order as a recursive walk.
        pending.extend(child[0] for child in reversed(node.next_functions))

def main(flash,bs, n_local_heads, seqlen, head_dim):
    """Run one forward/backward/SGD step of Attention and print a summary.

    Three half-precision tensors q, k and v of shape
    ``(bs, n_local_heads, seqlen, head_dim)`` are created on ``device``,
    filled from a normal distribution (mean 0, std 0.1) under seed 1 and
    optimised directly with SGD (lr 1.1) against random integer targets
    through a cross-entropy loss. Backward hooks are attached to the loss
    graph before ``backward()`` is called. Finally the sums of q, k and v
    and the loss value are printed with five decimals.

    Args:
        flash: forwarded to ``Attention`` to pick its code path.
        bs, n_local_heads, seqlen, head_dim: dimensions of q, k and v.
    """
    torch.random.manual_seed(1)

    def make_leaf():
        # ones -> half -> device, then overwrite in place with N(0, 0.1)
        # values and wrap as a tensor that requires grad.
        t = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
        t.data.normal_(0, 0.1)
        return Variable(t, requires_grad=True).to(device)

    # Built in q, k, v order so the random stream is consumed in that order.
    q = make_leaf()
    k = make_leaf()
    v = make_leaf()

    # One integer class label in [0, head_dim) per (batch, head, position).
    targets = torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    attention = Attention(seqlen,head_dim,flash).half().to(device)
    sgd = torch.optim.SGD([q,k,v], lr=1.1)

    for _ in range(1):
        output = attention(q,k,v)
        loss = criterion(output.reshape(-1,head_dim),targets)
        # Hooks must be registered before backward() runs.
        hook_backwards(loss.grad_fn, cached=set())
        loss.backward()
        sgd.step()
        summary = (q.sum().item(),k.sum().item(),v.sum().item(),loss.item())
        print("{:.5f},{:.5f},{:.5f},{:.5f}".format(*summary))

# Problem size for the demo run; these stay module-level names.
bs = 8
n_local_heads = 8
seqlen = 512
head_dim = 64
# flash=False selects the explicit operator-by-operator attention path.
main(flash=False, bs=bs, n_local_heads=n_local_heads, seqlen=seqlen, head_dim=head_dim)

二.输出

运行上述脚本后的终端输出:
reshape-input 0 torch.Size([32768, 1])
reshape-output 0 torch.Size([32768])
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 64, 512])
matmul-input 0 torch.Size([8, 8, 512, 64])
matmul-input 1 torch.Size([8, 8, 64, 512])
matmul-output 0 torch.Size([8, 8, 512, 512])
__truediv__-input 0 torch.Size([8, 8, 512, 512])
__truediv__-output 0 torch.Size([8, 8, 512, 512])
add-input 0 torch.Size([8, 8, 512, 512])
add-input 1 torch.Size([1, 1, 512, 512])
add-output 0 torch.Size([8, 8, 512, 512])
float-input 0 torch.Size([8, 8, 512, 512])
float-output 0 torch.Size([8, 8, 512, 512])
softmax-input 0 torch.Size([8, 8, 512, 512])
softmax-output 0 torch.Size([8, 8, 512, 512])
type_as-input 0 torch.Size([8, 8, 512, 512])
type_as-input 1 torch.Size([8, 8, 512, 64])
type_as-output 0 torch.Size([8, 8, 512, 512])
matmul-input 0 torch.Size([8, 8, 512, 512])
matmul-input 1 torch.Size([8, 8, 512, 64])
matmul-output 0 torch.Size([8, 8, 512, 64])
reshape-input 0 torch.Size([8, 8, 512, 64])
reshape-output 0 torch.Size([32768, 64])
NllLossBackward0 0 torch.Size([32768, 64])
NllLossBackward0 1 torch.Size([])
LogSoftmaxBackward0 0 torch.Size([32768, 64])
LogSoftmaxBackward0 1 torch.Size([32768, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([32768, 64])
UnsafeViewBackward0 0 torch.Size([64, 512, 64])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 64])
BmmBackward0 0 torch.Size([64, 512, 512])
BmmBackward0 1 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 512])
ViewBackward0 1 torch.Size([64, 512, 512])
ExpandBackward0 0 torch.Size([8, 8, 512, 512])
ExpandBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 0 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
AddBackward0 0 torch.Size([8, 8, 512, 512])
AddBackward0 1 torch.Size([8, 8, 512, 512])
DivBackward0 0 torch.Size([8, 8, 512, 512])
DivBackward0 1 torch.Size([8, 8, 512, 512])
UnsafeViewBackward0 0 torch.Size([64, 512, 512])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 512])
BmmBackward0 0 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 64, 512])
BmmBackward0 1 torch.Size([64, 512, 512])
ReshapeAliasBackward0 0 torch.Size([8, 8, 64, 512])
ReshapeAliasBackward0 1 torch.Size([64, 64, 512])
ExpandBackward0 0 torch.Size([8, 8, 64, 512])
ExpandBackward0 1 torch.Size([8, 8, 64, 512])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
TransposeBackward0 0 torch.Size([8, 8, 512, 64])
TransposeBackward0 1 torch.Size([8, 8, 64, 512])
CloneBackward0 0 torch.Size([8, 8, 512, 64])
CloneBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
45.56250,-12.76562,121.68750,4.16016
相关推荐
风逸hhh1 分钟前
python打卡day58@浙大疏锦行
开发语言·python
盼小辉丶2 分钟前
PyTorch实战(14)——条件生成对抗网络(conditional GAN,cGAN)
人工智能·pytorch·生成对抗网络
Allen_LVyingbo44 分钟前
数智读书笔记系列035《未来医疗:医疗4.0引领第四次医疗产业变革》
人工智能·经验分享·笔记·健康医疗
zzc9211 小时前
时频图数据集更正程序,去除坐标轴白边及调整对应的标签值
人工智能·深度学习·数据集·标签·时频图·更正·白边
isNotNullX1 小时前
什么是数据分析?常见方法全解析
大数据·数据库·数据仓库·人工智能·数据分析
烛阴1 小时前
一文搞懂 Python 闭包:让你的代码瞬间“高级”起来!
前端·python
riveting1 小时前
明远智睿H618:开启多场景智慧生活新时代
人工智能·嵌入式硬件·智能硬件·lga封装·3506
JosieBook1 小时前
【Java编程动手学】Java中的数组与集合
java·开发语言·python
夜阑卧听风吹雨,铁马冰河入梦来1 小时前
Spring AI 阿里巴巴学习
人工智能·学习·spring
c7691 小时前
【文献笔记】Automatic Chain of Thought Prompting in Large Language Models
人工智能·笔记·语言模型·论文笔记