拦截 PyTorch 算子，dump 输入输出

拦截 PyTorch 算子，dump 输入输出

希望 dump 出 PyTorch 每个算子的输入和输出，但 PyTorch 常规的 hook 机制只能拦截 nn.Module。下面提供一种方法，可以拦截 torch.add、torch.Tensor.add 这类算子。原理是：前向部分通过模板替换，劫持 torch 和 torch.Tensor 中的算子；反向部分则从 loss.grad_fn 出发遍历 next_functions，对每个节点调用 register_hook 来拦截 backward。

一、代码

Python 代码：
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading
import base64
from jinja2 import Template

# Target device for every tensor, module and dump-file prefix in this script.
device = "cuda"

class Attention(nn.Module):
    """Causal scaled dot-product attention over pre-projected q/k/v.

    With ``flash=True`` the fused ``scaled_dot_product_attention`` kernel is
    used. Otherwise attention is computed step by step (matmul, scale,
    additive causal mask, softmax, matmul) so that each individual operator
    is visible to the dumping hooks.

    Args:
        max_seq_len: side length of the precomputed causal mask (slow path
            only); the query/key lengths seen by ``forward`` must not exceed it.
        head_dim: per-head feature size, used for the 1/sqrt(head_dim) scale.
        flash: select the fused kernel instead of the explicit computation.
        mask_device: device for the causal-mask buffer; ``None`` (default)
            keeps the original behaviour of using the module-level ``device``.
    """

    def __init__(self, max_seq_len, head_dim, flash, mask_device=None):
        super().__init__()
        # Could be auto-detected with
        # hasattr(torch.nn.functional, 'scaled_dot_product_attention').
        self.flash = flash
        self.dropout = 0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim = head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            target = device if mask_device is None else mask_device
            # -inf strictly above the diagonal, 0 on/below it: adding this to
            # the scores blocks attention to future positions.
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(target)
            mask = torch.triu(mask, diagonal=1).half().to(target)
            self.register_buffer("mask", mask)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor):
        """Return causal attention of ``xq`` over ``xk``/``xv``.

        Inputs are indexed as (..., seq, head_dim): the key is transposed on
        dims 2/3 and the sequence length is read from dim -2.
        """
        if self.flash:
            return torch.nn.functional.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        # NOTE(review): the clone is not needed for the math; it presumably
        # exists so that a clone op shows up in the dump — confirm before removing.
        _xk = xk.clone()
        t = _xk.transpose(2, 3)
        scores = torch.matmul(xq, t)
        scores = scores / math.sqrt(self.head_dim)
        # Take the lengths from the inputs rather than from a module-level
        # ``seqlen`` global, which only existed by accident of the script.
        q_len, k_len = xq.size(-2), xk.size(-2)
        a = self.mask[:, :, :q_len, :k_len]
        scores = torch.add(scores, a)
        # Softmax in float32 for stability, then back to the input dtype.
        scores = F.softmax(scores.float(), dim=-1)
        scores = scores.type_as(xq)
        scores = self.attn_dropout(scores)
        return torch.matmul(scores, xv)

# Serialises the shared file counter and the file writes across threads.
lock = threading.Lock()
# Monotonic counter embedded in every dump filename so names never collide.
gindex = 0


def save_tensor(name, args, index=0):
    """Dump every tensor found in ``args`` to its own ``.pt`` file.

    A tensor is printed (name, index, shape) and saved to
    ``"<device>_<gindex>_<name>_<index>.pt"`` in the current directory. A
    tuple is walked recursively, with element ``i`` labelled ``index + i``.
    Anything else (scalars, None, lists, ...) is ignored.
    """
    global gindex
    if isinstance(args, torch.Tensor):
        print(name, index, args.shape)
        # ``with`` guarantees the lock is released even if torch.save (or the
        # path formatting) raises; a manual acquire/release would leave it
        # held and deadlock every later dump.
        with lock:
            torch.save(args, "{}_{}_{}_{}.pt".format(device, gindex, name, index))
            gindex += 1
    elif isinstance(args, tuple):
        for idx, x in enumerate(args):
            save_tensor(name, x, index + idx)

# Source template for one operator wrapper. Rendering it and exec'ing the
# result into this module's globals defines:
#   {{native}}_{{new_name}} -- the original operator, captured before patching
#   {{new_name}}            -- a wrapper that dumps positional and keyword
#                              tensor inputs, calls the original, dumps its output
# and then installs the wrapper on {{owner}} under the operator's own name.
op_template = Template('''
{{native}}_{{new_name}}=getattr({{owner}},'{{name}}')
def {{new_name}}(*args, **kwargs):
    save_tensor("{{name}}-input",args)
    for _key, _value in kwargs.items():
        save_tensor("{{name}}-input-" + _key, _value)
    ret={{native}}_{{new_name}}(*args, **kwargs)
    save_tensor("{{name}}-output",ret)
    return ret
setattr({{owner}}, '{{name}}', {{new_name}})
''')

# Names left unpatched. NOTE(review): these lists look hand-tuned -- several
# (size, stride, data_ptr, untyped_storage, storage_offset, save, ...) are
# presumably used while printing/saving a tensor, so wrapping them would
# recurse into save_tensor. Confirm before editing.
_TENSOR_SKIP = {
    "__iter__", "shape", "dim", "unbind", "normal_", "data",
    "item", "numel", "save", "has_names", "data_ptr", "untyped_storage",
    "storage_offset", "size", "stride", "triu", "half", "is_floating_point",
    "to", "ones", "randint", "ones_like",
}
_TORCH_SKIP = {
    "is_grad_enabled", "__iter__", "save", "has_names", "data_ptr",
    "untyped_storage", "storage_offset", "size", "stride", "triu",
    "is_floating_point", "to", "ones", "randint", "full", "reshape", "ones_like",
}


def _hook_ops(owner, owner_expr, native, skip, kind):
    """Replace every attribute of ``owner`` whose type name is ``kind`` (and
    that is not in ``skip``) with a dumping wrapper rendered from op_template.

    ``owner_expr`` is the source text naming ``owner`` inside the template
    (e.g. "torch.Tensor"); ``native`` prefixes the global that keeps the
    original operator.
    """
    for op in dir(owner):
        if op in skip:
            continue
        if getattr(owner, op).__class__.__name__ != kind:
            continue
        # base64 of an ASCII dotted identifier contains only [A-Za-z0-9] once
        # the '=' padding is stripped, giving a unique, valid Python name.
        new_name = base64.b64encode(f"{owner_expr}.{op}".encode("utf-8")).decode("utf-8").replace("=", "")
        # exec into the module globals so the generated names live at top
        # level and the wrapper resolves save_tensor from this module.
        exec(op_template.render(native=native, owner=owner_expr, name=op, new_name=new_name), globals())


_hook_ops(torch.Tensor, "torch.Tensor", "native1", _TENSOR_SKIP, "method_descriptor")
_hook_ops(torch, "torch", "native2", _TORCH_SKIP, "builtin_function_or_method")

def hook_backwards(loss, cached=None):
    """Attach dump hooks to every autograd node reachable from ``loss``.

    Each node gets a no-op pre-hook and a post-hook that passes the hook's
    arguments to ``save_tensor`` under the node's class name (for example
    "AddBackward0"). Nodes already in ``cached`` are skipped, and every newly
    hooked node is added to it.

    Args:
        loss: a ``grad_fn`` node to start from; ``None`` is a no-op.
        cached: set of already-hooked nodes, updated in place; a fresh set is
            used when omitted.
    """
    if loss is None:
        return
    if cached is None:
        cached = set()

    def make_posthook(node_name):
        # Capture only the class name: binding per node avoids the loop's
        # late-binding pitfall, and not holding the node itself avoids a
        # node <-> hook reference cycle.
        def posthook(*args, **kwargs):
            save_tensor(node_name, args)
        return posthook

    def prehook(*args, **kwargs):
        pass

    # Iterative DFS: recursion would exceed Python's recursion limit on
    # autograd graphs deeper than roughly a thousand nodes.
    stack = [loss]
    while stack:
        node = stack.pop()
        # next_functions contains None for inputs that do not require grad.
        if node is None or node in cached:
            continue
        node.register_prehook(prehook)
        node.register_hook(make_posthook(node.__class__.__name__))
        cached.add(node)
        stack.extend(child for child, _ in node.next_functions)

def main(flash,bs, n_local_heads, seqlen, head_dim):
    """Run the toy attention model for one SGD step with operator dumping on.

    Builds half-precision q/k/v leaf tensors of shape
    ``(bs, n_local_heads, seqlen, head_dim)`` on ``device``, runs ``Attention``
    forward, computes cross-entropy against random labels in [0, head_dim),
    hooks every backward node with ``hook_backwards``, backpropagates, steps
    SGD on q/k/v and prints their sums plus the loss.

    Args:
        flash: use the fused attention kernel instead of the explicit path.
        bs, n_local_heads, seqlen, head_dim: dimensions of q/k/v.
    """
    torch.random.manual_seed(1)

    q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)

    q.data.normal_(0, 0.1)
    k.data.normal_(0, 0.1)
    v.data.normal_(0, 0.1)

    # ``Variable`` is a deprecated alias; here it just yields leaf tensors
    # with requires_grad=True (the tensors are already on ``device``).
    q=Variable(q, requires_grad=True).to(device)
    k=Variable(k, requires_grad=True).to(device)
    v=Variable(v, requires_grad=True).to(device)

    gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)
    loss_func=nn.CrossEntropyLoss().to(device)

    model=Attention(seqlen,head_dim,flash).half().to(device)
    optim = torch.optim.SGD([q,k,v], lr=1.1)

    for _ in range(1):
        # Clear gradients from the previous step; without this they would
        # accumulate as soon as the iteration count is raised above 1.
        optim.zero_grad()
        output = model(q,k,v)
        loss=loss_func(output.reshape(-1,head_dim),gt)
        hook_backwards(loss.grad_fn, cached=set())
        loss.backward()
        optim.step()
        print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))

# Demo configuration. Kept at module level so these names stay defined on
# import as well as when run as a script.
bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64

# Only run the GPU demo (and write dump files) when executed as a script,
# not as a side effect of importing this module.
if __name__ == "__main__":
    main(False,bs, n_local_heads, seqlen, head_dim)

二、输出

运行输出：
reshape-input 0 torch.Size([32768, 1])
reshape-output 0 torch.Size([32768])
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 64, 512])
matmul-input 0 torch.Size([8, 8, 512, 64])
matmul-input 1 torch.Size([8, 8, 64, 512])
matmul-output 0 torch.Size([8, 8, 512, 512])
__truediv__-input 0 torch.Size([8, 8, 512, 512])
__truediv__-output 0 torch.Size([8, 8, 512, 512])
add-input 0 torch.Size([8, 8, 512, 512])
add-input 1 torch.Size([1, 1, 512, 512])
add-output 0 torch.Size([8, 8, 512, 512])
float-input 0 torch.Size([8, 8, 512, 512])
float-output 0 torch.Size([8, 8, 512, 512])
softmax-input 0 torch.Size([8, 8, 512, 512])
softmax-output 0 torch.Size([8, 8, 512, 512])
type_as-input 0 torch.Size([8, 8, 512, 512])
type_as-input 1 torch.Size([8, 8, 512, 64])
type_as-output 0 torch.Size([8, 8, 512, 512])
matmul-input 0 torch.Size([8, 8, 512, 512])
matmul-input 1 torch.Size([8, 8, 512, 64])
matmul-output 0 torch.Size([8, 8, 512, 64])
reshape-input 0 torch.Size([8, 8, 512, 64])
reshape-output 0 torch.Size([32768, 64])
NllLossBackward0 0 torch.Size([32768, 64])
NllLossBackward0 1 torch.Size([])
LogSoftmaxBackward0 0 torch.Size([32768, 64])
LogSoftmaxBackward0 1 torch.Size([32768, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([32768, 64])
UnsafeViewBackward0 0 torch.Size([64, 512, 64])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 64])
BmmBackward0 0 torch.Size([64, 512, 512])
BmmBackward0 1 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 512])
ViewBackward0 1 torch.Size([64, 512, 512])
ExpandBackward0 0 torch.Size([8, 8, 512, 512])
ExpandBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 0 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
AddBackward0 0 torch.Size([8, 8, 512, 512])
AddBackward0 1 torch.Size([8, 8, 512, 512])
DivBackward0 0 torch.Size([8, 8, 512, 512])
DivBackward0 1 torch.Size([8, 8, 512, 512])
UnsafeViewBackward0 0 torch.Size([64, 512, 512])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 512])
BmmBackward0 0 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 64, 512])
BmmBackward0 1 torch.Size([64, 512, 512])
ReshapeAliasBackward0 0 torch.Size([8, 8, 64, 512])
ReshapeAliasBackward0 1 torch.Size([64, 64, 512])
ExpandBackward0 0 torch.Size([8, 8, 64, 512])
ExpandBackward0 1 torch.Size([8, 8, 64, 512])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
TransposeBackward0 0 torch.Size([8, 8, 512, 64])
TransposeBackward0 1 torch.Size([8, 8, 64, 512])
CloneBackward0 0 torch.Size([8, 8, 512, 64])
CloneBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
45.56250,-12.76562,121.68750,4.16016
相关推荐
小爬菜5 分钟前
Django学习笔记(项目默认文件)-02
前端·数据库·笔记·python·学习·django
XianxinMao13 分钟前
2024大模型双向突破:MoE架构创新与小模型崛起
人工智能·架构
Francek Chen25 分钟前
【深度学习基础】多层感知机 | 模型选择、欠拟合和过拟合
人工智能·pytorch·深度学习·神经网络·多层感知机·过拟合
Channing Lewis35 分钟前
python生成随机字符串
服务器·开发语言·python
pchmi1 小时前
C# OpenCV机器视觉:红外体温检测
人工智能·数码相机·opencv·计算机视觉·c#·机器视觉·opencvsharp
资深设备全生命周期管理1 小时前
以Python 做服务器,N Robot 做客户端,小小UI,拿捏
服务器·python·ui
洪小帅1 小时前
Django 的 `Meta` 类和外键的使用
数据库·python·django·sqlite
认知作战壳吉桔1 小时前
中国认知作战研究中心:从认知战角度分析2007年iPhone发布
大数据·人工智能·新质生产力·认知战·认知战研究中心
夏沫mds1 小时前
web3py+flask+ganache的智能合约教育平台
python·flask·web3·智能合约
去往火星2 小时前
opencv在图片上添加中文汉字(c++以及python)
开发语言·c++·python