Tracking the memory allocations of a PyTorch Module during training

Goal: track where memory is allocated while a PyTorch Module is being trained.
Approach (see the minimal sketch after this list):
1. Use module pre-hooks to mark the boundaries between modules.
2. Use __torch_dispatch__ to intercept every aten operator and sum the memory of the tensors newly created inside that operator.
3. Use tensor.data_ptr() to deduplicate tensors, so that each distinct memory block is counted only once.
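Before the full script, here is a minimal, self-contained sketch of the core mechanism: a TorchDispatchMode intercepts every aten operator, and newly created tensors are deduplicated by data_ptr() so each memory block is counted once. The class name MemLogger and the toy workload at the end are illustrative only; unlike the full script below, this sketch counts tensors on any device, not just CUDA.

python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

class MemLogger(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.seen = set()  # data_ptr() values that have already been counted

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        new_bytes = 0
        def count(t):
            nonlocal new_bytes
            if isinstance(t, torch.Tensor) and t.data_ptr() not in self.seen:
                self.seen.add(t.data_ptr())
                new_bytes += t.numel() * t.element_size()
        tree_map(count, out)  # the output may be a tensor, a tuple, a list, ...
        if new_bytes:
            print(f"{func.__name__}: {new_bytes} bytes newly allocated")
        return out

with MemLogger():
    x = torch.randn(4, 4)
    y = x @ x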

Code

python
import numpy as np
import torch
from torch.nn import Module, Linear
import torch.nn as nn
from torch.optim import Adam,SGD
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
import time

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

# Global state: the module currently being measured and the data_ptr()s already counted
current_module=None
tensor_cache=set()

def get_current_mem():
    # Print the total bytes attributed to the current module, then reset the slot
    global current_module
    print(f'[INFO]{current_module["name"]}:{np.sum(current_module["size"])}')
    current_module=None

class InputDescriptor:
    def __init__(self) -> None:
        self.total_input_size=0
    def _save_var(self,v):
        # Record v if it is a CUDA Tensor/Parameter not seen before, keyed by data_ptr()
        class_name=v.__class__.__name__
        if class_name in ["Tensor","Parameter"]:
            global tensor_cache
            tensorid=v.data_ptr()
            if v.device.type!="cuda":
                return
            if tensorid not in tensor_cache:
                tensor_cache.add(tensorid)
                sz=v.numel()*v.element_size()
                print(v.shape,v.dtype)
                self.total_input_size += sz
            if class_name=="Parameter" and v.grad is not None:
                # Count the gradient buffer of a Parameter as well
                tensorid=v.grad.data_ptr()
                if tensorid not in tensor_cache:
                    tensor_cache.add(tensorid)
                    sz=v.grad.numel()*v.grad.element_size()
                    print("grad",v.grad.shape,v.grad.dtype)
                    self.total_input_size += sz
        elif class_name in ["list","tuple"]:
            for t in v:
                self._save_var(t)
        else:
            pass
    def save_vars(self,ret,*args,**kwargs):
        # Sum the sizes of all tensors touched by one aten op and attribute
        # the total to the module that is currently active
        for arg in args:
            self._save_var(arg)
        for k,v in kwargs.items():
            self._save_var(v)
        self._save_var(ret)
        global current_module
        if current_module is None:
            # Allocations outside any hooked module are reported as "Other"
            current_module={"name":"Other","size":[]}
        current_module["size"].append(self.total_input_size)

# Caches for object names and per-class counters
object_cache = {}
class_name_count = {}

def get_unique_name(class_name, obj_id):
    # Generate a unique, readable name (e.g. Linear-1) for the object
    if class_name not in class_name_count:
        class_name_count[class_name] = 0
    uid = f"{class_name}_{obj_id}"
    if uid not in object_cache:
        class_name_count[class_name] += 1
        object_cache[uid] = {"idx": class_name_count[class_name]}
    return f'{class_name}-{object_cache[uid]["idx"]}'

def initialize_module_attributes(module):
    # Attach profiling attributes (uuid, per-pass size lists) to the module
    if not hasattr(module, 'uuid'):
        module.uuid = get_unique_name(module.__class__.__name__, id(module))
    if not hasattr(module, 'backward_mem'):
        module.backward_mem = []
    if not hasattr(module, 'forward_mem'):
        module.forward_mem = []

def pre_backward_hook(module, grad_input):
    # Pre-backward hook: flush the previous module's total and start collecting for this module's backward pass
    initialize_module_attributes(module)
    global current_module
    if current_module is not None and np.sum(current_module["size"])>0:
        print(f'[INFO]{current_module["name"]}:{np.sum(current_module["size"])}')
    module.backward_mem.clear()
    current_module={"name":f"backward-{module.uuid}","size":module.backward_mem}

def post_backward_hook(module, grad_input, grad_output):
    # Post-backward hook
    initialize_module_attributes(module)

def pre_forward_hook(module, input):
    # Pre-forward hook: flush the previous module's total and start collecting for this module's forward pass
    initialize_module_attributes(module)
    global current_module
    if current_module is not None and np.sum(current_module["size"])>0:
        print(f'[INFO]{current_module["name"]}:{np.sum(current_module["size"])}')
    module.forward_mem.clear()
    current_module={"name":f"forward-{module.uuid}","size":module.forward_mem}

def post_forward_hook(module, input, output):
    # Post-forward hook
    initialize_module_attributes(module)

def register_forward_hooks(module):
    # Register forward-pass hooks
    module.register_forward_pre_hook(pre_forward_hook)
    module.register_forward_hook(post_forward_hook)

def register_backward_hooks(module):
    # Register backward-pass hooks
    module.register_full_backward_pre_hook(pre_backward_hook)
    module.register_full_backward_hook(post_backward_hook)

class HookModel(object):
    def __init__(self, model):
        output_dict = {}
        self.get_submodule_recursive(model, "", output_dict)
        for name, module in output_dict.items():
            if name.endswith("Sequential"):
                # Skip container modules; their children are hooked individually
                continue
            register_forward_hooks(module)
            register_backward_hooks(module)
    def get_submodule_recursive(self,module, prefix, output_dict):
        # Walk the module tree and record every submodule under a path-like name
        prefix = prefix + "/" + type(module).__name__
        output_dict[prefix] = module
        for name, submodule in module.named_children():
            self.get_submodule_recursive(submodule, f"{prefix}[{name}]", output_dict)

class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self,parent):
        super().__init__()
        self.parent=parent
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Run the aten op, then record the memory of any tensors it created
        if kwargs is None:
            kwargs = {}
        ret= func(*args, **kwargs)
        desc=InputDescriptor()
        desc.save_vars(ret,*args,**kwargs)
        if desc.total_input_size>0:
            print(f"{func.__name__}:{desc.total_input_size}")
        return ret

class TorchDebugDumper:
    _CURRENT_Dumper = None
    def __init__(self):
        self.p= _ProfilerState(TorchDumpDispatchMode)

    def __enter__(self):
        assert TorchDebugDumper._CURRENT_Dumper is None
        TorchDebugDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            self.p.object.step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDebugDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)
            del self.p.object

class FeedForward(Module):
    def __init__(self,hidden_size,ffn_size):
        super().__init__()
        self.fc = nn.Sequential(
            Linear(in_features=hidden_size, out_features=ffn_size,bias=False),
            nn.ReLU(),
            Linear(in_features=ffn_size, out_features=ffn_size*2,bias=False),
            nn.Dropout(0.5),
            Linear(in_features=ffn_size*2, out_features=hidden_size,bias=False),
        )
        self.norm = nn.LayerNorm(normalized_shape=hidden_size, elementwise_affine=False)
 
    def forward(self, x):
        return x + self.fc(self.norm(x))
    
def main():
    model=FeedForward(100,128)
    model=model.float().cuda()
    model.train()
    obj=HookModel(model)
    global current_module
    with TorchDebugDumper():
        opt=Adam(model.parameters(),lr=0.001)
        input=torch.randn(1,100).float().cuda()
        output=model(input)
        get_current_mem()
        loss=-torch.log(output.sum())
        opt.zero_grad()
        loss.backward()
        get_current_mem()
        current_module=None
        opt.step()    
    get_current_mem()
    num_model_params = sum(p.numel() for p in model.parameters())
    print(f"[INFO]Number of model parameters: {num_model_params}")
main()

Output

bash
torch.Size([1, 100]) torch.float32
_to_copy.default:400
[INFO]Other:400
torch.Size([1, 100]) torch.float32
torch.Size([1, 1]) torch.float32
torch.Size([1, 1]) torch.float32
native_layer_norm.default:408
[INFO]forward-LayerNorm-1:408
torch.Size([128, 100]) torch.float32
t.default:51200
[INFO]forward-Linear-1:51200
torch.Size([256, 128]) torch.float32
t.default:131072
torch.Size([1, 256]) torch.float32
mm.default:1024
[INFO]forward-Linear-2:132096
torch.Size([1, 256]) torch.float32
native_dropout.default:1024
[INFO]forward-Dropout-1:1024
torch.Size([100, 256]) torch.float32
t.default:102400
torch.Size([1, 100]) torch.float32
add.Tensor:400
[INFO]forward-Linear-3:102800
torch.Size([]) torch.float32
log.default:4
torch.Size([]) torch.float32
neg.default:4
torch.Size([]) torch.float32
neg.default:4
torch.Size([]) torch.float32
div.Tensor:4
[INFO]Other:16
torch.Size([100, 256]) torch.float32
mm.default:102400
torch.Size([1, 256]) torch.float32
mm.default:1024
[INFO]backward-Linear-3:103424
torch.Size([128, 100]) torch.float32
mm.default:51200
[INFO]backward-Linear-1:51200
torch.Size([128, 100]) torch.float32
zeros_like.default:51200
torch.Size([128, 100]) torch.float32
zeros_like.default:51200
torch.Size([256, 128]) torch.float32
zeros_like.default:131072
torch.Size([256, 128]) torch.float32
zeros_like.default:131072
torch.Size([100, 256]) torch.float32
zeros_like.default:102400
torch.Size([100, 256]) torch.float32
zeros_like.default:102400
torch.Size([128, 100]) torch.float32
torch.Size([256, 128]) torch.float32
torch.Size([100, 256]) torch.float32
_foreach_sqrt.default:284672
[INFO]Other:854016
[INFO]Number of model parameters: 71168
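
As a rough sanity check on the totals above, the per-module numbers can be compared with the change in torch.cuda.memory_allocated() across the same region. The two will not match exactly, because the CUDA caching allocator rounds block sizes and releases memory lazily, but they should be of the same order of magnitude. A minimal sketch, with the helper name allocated_delta being illustrative:

python
import torch

def allocated_delta(fn):
    # Report the change in bytes that the CUDA caching allocator counts as
    # allocated before and after running fn()
    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    out = fn()
    torch.cuda.synchronize()
    after = torch.cuda.memory_allocated()
    print(f"allocated delta: {after - before} bytes")
    return out

# Example usage with the FeedForward model defined above:
# model = FeedForward(100, 128).cuda()
# x = torch.randn(1, 100, device="cuda")
# y = allocated_delta(lambda: model(x))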