pytorch中,numel()函数用于获取张量中元素数目 ,其中 numel() 可以理解为是 number of elements 的缩写。
例如:
python
import torch
a = torch.randn(2,3)
b = a.numel()
print(a,b)
# tensor([[-0.4062, -0.8251, -2.2294],
# [ 0.5109, -1.4237, 0.8322]]) 6
比如实际应用,numel()函数可用于获取模型参数的总数目:
python
import logging
# model = ...
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")