统计神经网络参数量、MAC、FLOPs等信息

0、基础提示

1、FLOPS是用来衡量硬件算力的指标,FLOPs用来衡量模型复杂度。

2、MAC 一般为 FLOPs的2倍

3、并非FLOPs越小在硬件上就一定运行更快,还与模型占用的内存,带宽,等有关

1、FLOPs计算

神经网络参数量。用于衡量模型大小。一般卷积计算方式为:
F L O P s = 2 ∗ H W ( K h ∗ K w ∗ C i n + 1 ) C o u t FLOPs = 2*HW(Kh*Kw*Cin+1)Cout FLOPs=2∗HW(Kh∗Kw∗Cin+1)Cout

其中,

H,W表示该层卷积的高和宽

Kh,Kw表示卷积核的高和宽

2 表示一次乘操作 + 一次加操作

+1 表示bias操作

2、统计工具-THOP

源代码链接

2.1 安装

python 复制代码
pip install thop

python 复制代码
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

2.2 基础使用

python 复制代码
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))

2.3 定义自己的规则

python 复制代码
class YourModule(nn.Module):
    # your definition
	def count_your_model(model, x, y):
	    # your rule here

input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ), 
                        custom_ops={YourModule: count_your_model})

2.4 模型包含多个输入

修改input就好

python 复制代码
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input1 = input2 = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input1, input2,))

3、 统计工具-torchstat

这个是我更中意的,因为他统计信息更加丰富,包含params,memory, Madd, FLOPs等。缺点在于已经不更新了,且不支持多输入,好在我们可以修改代码支持。
源代码链接

3.1 安装

python 复制代码
pip install torchstat

3.2 基础使用

python 复制代码
from torchstat import stat
import torchvision.models as models
model = models.resnet18()
stat(model, (3, 224, 224))

3.3 输入多个Input

将torchstat 库安装目录下的 torchstat/statistics.py 中按如下修改:

python 复制代码
class ModelStat(object):
	def __init__(self, model, input_size, query_granularity=1):
 		assert isinstance(model, nn.Module)
		# 删除输入长度为3的限制
		# assert isinstance(input_size, (tuple, list)) and len(input_size) == 3
		assert isinstance(input_size, (tuple, list))
		self._model = model
		self._input_size = input_size
		self._query_granularity = query_granularity

将torchstat 库安装目录下的 torchstat/model_hook.py 中按如下修改:

python 复制代码
class ModelHook(object):
	def __init__(self, model, input_size):
		assert isinstance(model, nn.Module)
		assert isinstance(input_size, (list, tuple))
		self._model = model
		# 原始是通过单个输入的尺寸,再构建输入tensor,我们可以修改为在网络外构建输入tensor后直接送入网络
		# self._input_size = input_size
		self._origin_call = dict() # sub module call hook
		self._hook_model()
		# x = torch.rand(1, *self._input_size) # add module duration time
		self._model.eval()
		# self._model(x)
		self._model(*self._input_size)

使用时候测试代码

python 复制代码
from torchstat import stat
import torchvision.models as models
model = models.resnet18()
input1, input2 = torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224)
stat(model, (input1, input2))

大致改动就是这样了,还有什么bug可以自己稍微修改一下哈。另外找修改地方可以看报错提示torchstat安装路径修改。

4、fvcore

stat有个很麻烦的问题是,他不支持transformer,因此包含transformer的网络可以使用fvcore,他是Facebook开源的一个轻量级的核心库。

4.1、 安装

python 复制代码
pip install fvcore

4.2、 基础使用

python 复制代码
from fvcore.nn import FlopCountAnalysis, parameter_count_table
# 创建网络
model = MobileViTBlock(in_channels=32, transformer_dim=64, ffn_dim=256)

# 创建输入网络的tensor
tensor = (torch.rand(1, 32, 64, 64),)

# 分析FLOPs
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())

# 分析parameters
print(parameter_count_table(model))

参考来自:https://zhuanlan.zhihu.com/p/583106030

欢迎交流补充

相关推荐
埃菲尔铁塔_CV算法11 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR12 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️18 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子35 分钟前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python39 分钟前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠1 小时前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
Debroon1 小时前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
羊小猪~~1 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨1 小时前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测