pytorch如何知道某个Parameter是在哪一个Module中的创建的

pytorch如何知道某个Parameter是在哪一个Module中的创建的

在定位pytorch精度问题时,发现optimizer中某些Parameter值异常,想知道它属于哪个模块的.本文提供二种方法
1.全局搜索
2.在创建Parameter的地方加一个属性,写明所在的模块名,需要的时候直接获取

代码

python 复制代码
import torch
import sys
sys.setrecursionlimit(1000)
        
def search_recursive(var,stack,_id,depth):
    if var.__class__.__name__ in [
                                    "module","type","NoneType",
                                    "str","int","function","method-wrapper",
                                    "builtin_function_or_method",
                                    "method","_TensorMeta",
                                    "Tensor","method_descriptor",
                                    "bool","device","dtype",
                                    "getset_descriptor","layout",
                                    "wrapper_descriptor","property",
                                    "_ParameterMeta","mappingproxy",
                                    "Parameter","_abc_data","SourceFileLoader",
                                    "code","bytes","ABCMeta",
                                    "ForwardRef","ellipsis","TypeVar"
                                 ]:
        return False
    
    if isinstance(var,dict):
        for k,v in var.items():
            ret=search_recursive(v,stack,_id,depth+1)
            if ret:
                return ret
    elif isinstance(var,list) or isinstance(var,tuple):
        for i in var:
            ret=search_recursive(i,stack,_id,depth+1)
            if ret:
                return ret
    else:     
        if not var.__class__.__name__.startswith("_"):
            stack[depth]=var.__class__.__name__         
        for name in dir(var):
            try:
                obj=eval(f"var.{name}")
                if isinstance(obj,torch.nn.modules.linear.Linear) and id(obj.weight)==_id:
                    return stack[depth]
                ret=search_recursive(obj,stack,_id,depth+1)
                if ret:
                    return ret                 
            except:
                pass              
    return None

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.mlp=torch.nn.Linear(5120,3850)
        self.mlp.weight.__setattr__("model_name","MyModel") #方法一:通过添加属性
    def forward(self, x):
        out=self.mlp(x)
        return out
class MyContainer(object):
    def __init__(self):
        self.obj=MyModel()    
    def get_param(self):
        return self.obj.mlp.weight
obj = MyContainer()
param_group={}
param_group["w0"]=obj.get_param()
param_array=[param_group,obj]

print("GetModelName By getattr:",getattr(param_group["w0"],"model_name"))
# 方法二:递归搜索全局变量
model_name=search_recursive(globals(),{},id(param_group["w0"]),0)
print("GetModelName By search_recursive:",model_name)
相关推荐
算家计算1 分钟前
小鹏机器人真假难分引全网热议!而这只是开始......
人工智能·机器人·资讯
Hello_WOAIAI2 分钟前
2.4 python装饰器在 Web 框架和测试中的实战应用
开发语言·前端·python
百锦再12 分钟前
第1章 Rust语言概述
java·开发语言·人工智能·python·rust·go·1024程序员节
tokepson24 分钟前
chatgpt-to-md优化并重新复习
python·ai·技术·pypi·记录
说私域26 分钟前
开源AI智能名片链动2+1模式S2B2C商城系统下消费点评的信任构建机制研究
人工智能·开源
Victory_orsh32 分钟前
“自然搞懂”深度学习(基于Pytorch架构)——010203
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
java1234_小锋32 分钟前
PyTorch2 Python深度学习 - 模型保存与加载
开发语言·python·深度学习·pytorch2
Python图像识别35 分钟前
74_基于深度学习的垃圾桶垃圾溢出检测系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo
长桥夜波36 分钟前
机器学习日报10
人工智能·机器学习
MrSYJ43 分钟前
可以指定 Jupyter Notebook 使用的虚拟环境吗
python·llm·agent