如何知道PyTorch中某个Parameter是在哪个Module中创建的
在定位PyTorch精度问题时,发现optimizer中某些Parameter的值异常,需要知道它们属于哪个模块。本文提供两种方法:
1. 全局搜索:从全局变量出发递归遍历所有可达对象,找出持有该Parameter的模块。
2. 添加属性:在创建Parameter的地方给它加一个属性,记录所在的模块名,需要时直接读取。
代码(Python):
import torch
import sys
# search_recursive below walks object graphs recursively. Make sure at least
# 1000 frames are available (CPython's usual default), but never lower a limit
# that other code has already raised.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 1000))
# Class names whose instances search_recursive never descends into: scalars,
# strings, bytes, types/metaclasses, callables and descriptors, modules, and
# tensors/parameters. "float" and "complex" are listed because their
# .real/.imag attributes are floats again, so descending into them only ever
# produces more floats.
_SKIP_CLASS_NAMES = frozenset({
    "module", "type", "NoneType",
    "str", "int", "float", "complex", "function", "method-wrapper",
    "builtin_function_or_method",
    "method", "_TensorMeta",
    "Tensor", "method_descriptor",
    "bool", "device", "dtype",
    "getset_descriptor", "layout",
    "wrapper_descriptor", "property",
    "_ParameterMeta", "mappingproxy",
    "Parameter", "_abc_data", "SourceFileLoader",
    "code", "bytes", "ABCMeta",
    "ForwardRef", "ellipsis", "TypeVar",
})


def _owns_parameter(module, param_id):
    """Return True if *module* directly registers a parameter whose id() is *param_id*."""
    try:
        return any(id(p) == param_id for p in module.parameters(recurse=False))
    except Exception:  # e.g. a Module whose __init__ never ran (no _parameters)
        return False


def _attribute_values(obj):
    """Yield getattr(obj, name) for each name in dir(obj), skipping names that raise."""
    try:
        names = dir(obj)
    except Exception:
        return
    for name in names:
        try:
            value = getattr(obj, name)
        except Exception:  # properties/descriptors may raise on access
            continue
        yield value


def search_recursive(var, stack, _id, depth, *, _owner=None, _seen=None):
    """Find which object holds the nn.Module that registers a given Parameter.

    Walks everything reachable from *var* depth-first: dict values, list/tuple
    items, and every attribute listed by ``dir()``. Objects whose class name is
    in ``_SKIP_CLASS_NAMES`` are not descended into. Each object is visited at
    most once, so reference cycles terminate. The walk is best-effort: any
    attribute access or sub-walk that raises is skipped. Note that ``getattr``
    runs property getters along the way.

    When the walk reaches an ``nn.Module`` that directly registers a parameter
    (e.g. ``weight`` or ``bias``) whose ``id()`` is *_id*, it returns the class
    name of the nearest enclosing object on the path whose class name does not
    start with ``"_"``. For ``MyModel().mlp.weight`` that is ``"MyModel"``. If
    the module was reached with no such enclosing object, its own class name
    is returned instead.

    Args:
        var: Root of the search, typically ``globals()``.
        stack: Dict updated in place; ``stack[depth]`` is set to the owner name
            in effect at each non-container object visited.
        _id: ``id()`` of the Parameter to locate.
        depth: Depth of *var*; pass 0 for the root call.
        _owner: Internal; owner name inherited from enclosing objects.
        _seen: Internal; maps id -> object for every object already visited.

    Returns:
        The owner's class name, or ``None`` if the parameter was not found.
    """
    if _seen is None:
        _seen = {}
    cls_name = var.__class__.__name__
    if cls_name in _SKIP_CLASS_NAMES or id(var) in _seen:
        return None
    # Keep a reference to every visited object. getattr() can return fresh
    # temporaries; if one were freed, its id could be reused by a later object,
    # which would then be wrongly treated as already visited.
    _seen[id(var)] = var

    if isinstance(var, torch.nn.Module) and _owns_parameter(var, _id):
        return _owner or cls_name

    if isinstance(var, dict):
        owner, children = _owner, list(var.values())
    elif isinstance(var, (list, tuple)):
        owner, children = _owner, list(var)
    else:
        # Objects with a private class name never become the reported owner;
        # they pass the nearest public owner down instead.
        owner = _owner if cls_name.startswith("_") else cls_name
        stack[depth] = owner
        children = _attribute_values(var)

    for child in children:
        try:
            found = search_recursive(
                child, stack, _id, depth + 1, _owner=owner, _seen=_seen
            )
        except Exception:  # best-effort walk; includes RecursionError on deep graphs
            continue
        if found:
            return found
    return None
class MyModel(torch.nn.Module):
    """Toy model with one Linear layer whose parameters are tagged with their owner.

    Every parameter created by ``self.mlp`` (its ``weight`` and ``bias``) gets a
    ``model_name`` attribute set to ``"MyModel"``. Code that later holds only
    the Parameter object can read that attribute to see which module created it.

    Args:
        in_features: Input size of the linear layer (default 5120).
        out_features: Output size of the linear layer (default 3850).
    """

    def __init__(self, in_features=5120, out_features=3850):
        super().__init__()
        self.mlp = torch.nn.Linear(in_features, out_features)
        # Tag every parameter this module creates, not only the weight, so the
        # bias can be traced back as well.
        for param in self.mlp.parameters():
            param.model_name = "MyModel"

    def forward(self, x):
        """Apply the linear layer to *x*."""
        return self.mlp(x)
class MyContainer:
    """Plain Python (non-nn.Module) object that holds a MyModel in ``self.obj``."""

    def __init__(self):
        model = MyModel()
        self.obj = model

    def get_param(self):
        """Return the ``weight`` Parameter of the wrapped model's ``mlp`` layer."""
        layer = self.obj.mlp
        return layer.weight
# Build the objects the demo inspects. They are module-level globals on
# purpose: the global search below starts from globals().
obj = MyContainer()
param_group = {"w0": obj.get_param()}
param_array = [param_group, obj]

# Approach: read the owner tag that was attached to the Parameter at creation.
print("GetModelName By getattr:", param_group["w0"].model_name)

# Approach: recursively search everything reachable from this module's globals.
model_name = search_recursive(globals(), {}, id(param_group["w0"]), 0)
print("GetModelName By search_recursive:", model_name)