Tools for counting FLOPs and parameters of neural networks
1. FlopCountAnalysis
```bash
pip install fvcore
```
```python
import torch
from torchvision.models import resnet152
from fvcore.nn import FlopCountAnalysis, parameter_count_table

model = resnet152(num_classes=1000)
tensor = (torch.rand(1, 3, 224, 224),)

# Analyze FLOPs (fvcore counts one fused multiply-add as one flop)
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())

# Per-module parameter table provided by fvcore
print(parameter_count_table(model))

def print_model_parm_nums(model):
    # Sum the element counts of every parameter tensor
    total = sum(param.nelement() for param in model.parameters())
    print(' + Number of params: %.2fM' % (total / 1e6))

print_model_parm_nums(model)
```
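print_model_parm_nums only adds up nelement() over every parameter tensor, so the total follows from the layer shapes alone. As a check that needs neither torch nor a model download, the sketch below applies the same arithmetic to a hand-written list of torchvision AlexNet weight and bias shapes (an assumption standing in for model.parameters()); the result matches the Params line of the flopth AlexNet output.

```python
from math import prod

# Hypothetical stand-in for model.parameters(): weight and bias shapes of
# torchvision's AlexNet (conv weights are [out, in, kH, kW], linear [out, in]).
alexnet_param_shapes = [
    (64, 3, 11, 11), (64,),      # features.0
    (192, 64, 5, 5), (192,),     # features.3
    (384, 192, 3, 3), (384,),    # features.6
    (256, 384, 3, 3), (256,),    # features.8
    (256, 256, 3, 3), (256,),    # features.10
    (4096, 9216), (4096,),       # classifier.1
    (4096, 4096), (4096,),       # classifier.4
    (1000, 4096), (1000,),       # classifier.6
]

def count_params(shapes):
    # Same arithmetic as sum(p.nelement() for p in model.parameters()):
    # each tensor contributes the product of its dimensions.
    return sum(prod(shape) for shape in shapes)

total = count_params(alexnet_param_shapes)
print(total)                                         # 61100840
print(' + Number of params: %.2fM' % (total / 1e6))  # + Number of params: 61.10M
```

The weight term dominates: the first Linear layer alone holds 9216 × 4096 weights, which is why the classifier accounts for most of AlexNet's parameters.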
2. flopth
```bash
pip install flopth
```
Running on models in torchvision.models
```bash
$ flopth -m alexnet
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| module_name | module_type | in_shape | out_shape | params | params_percent | params_percent_vis | flops | flops_percent | flops_percent_vis |
+===============+===================+=============+=============+==========+==================+================================+==========+=================+=====================+
| features.0 | Conv2d | (3,224,224) | (64,55,55) | 23.296K | 0.0381271% | | 70.4704M | 9.84839% | #### |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.1 | ReLU | (64,55,55) | (64,55,55) | 0.0 | 0.0% | | 193.6K | 0.027056% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.2 | MaxPool2d | (64,55,55) | (64,27,27) | 0.0 | 0.0% | | 193.6K | 0.027056% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.3 | Conv2d | (64,27,27) | (192,27,27) | 307.392K | 0.50309% | | 224.089M | 31.3169% | ############### |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.4 | ReLU | (192,27,27) | (192,27,27) | 0.0 | 0.0% | | 139.968K | 0.0195608% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.5 | MaxPool2d | (192,27,27) | (192,13,13) | 0.0 | 0.0% | | 139.968K | 0.0195608% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.6 | Conv2d | (192,13,13) | (384,13,13) | 663.936K | 1.08662% | | 112.205M | 15.6809% | ####### |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.7 | ReLU | (384,13,13) | (384,13,13) | 0.0 | 0.0% | | 64.896K | 0.00906935% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.8 | Conv2d | (384,13,13) | (256,13,13) | 884.992K | 1.44841% | | 149.564M | 20.9018% | ########## |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.9 | ReLU | (256,13,13) | (256,13,13) | 0.0 | 0.0% | | 43.264K | 0.00604624% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.10 | Conv2d | (256,13,13) | (256,13,13) | 590.08K | 0.965748% | | 99.7235M | 13.9366% | ###### |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.11 | ReLU | (256,13,13) | (256,13,13) | 0.0 | 0.0% | | 43.264K | 0.00604624% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.12 | MaxPool2d | (256,13,13) | (256,6,6) | 0.0 | 0.0% | | 43.264K | 0.00604624% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| avgpool | AdaptiveAvgPool2d | (256,6,6) | (256,6,6) | 0.0 | 0.0% | | 9.216K | 0.00128796% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.0 | Dropout | (9216) | (9216) | 0.0 | 0.0% | | 0.0 | 0.0% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.1 | Linear | (9216) | (4096) | 37.7528M | 61.7877% | ############################## | 37.7487M | 5.27547% | ## |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.2 | ReLU | (4096) | (4096) | 0.0 | 0.0% | | 4.096K | 0.000572425% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.3 | Dropout | (4096) | (4096) | 0.0 | 0.0% | | 0.0 | 0.0% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.4 | Linear | (4096) | (4096) | 16.7813M | 27.4649% | ############# | 16.7772M | 2.34465% | # |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.5 | ReLU | (4096) | (4096) | 0.0 | 0.0% | | 4.096K | 0.000572425% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.6 | Linear | (4096) | (1000) | 4.097M | 6.70531% | ### | 4.096M | 0.572425% | |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
FLOPs: 715.553M
Params: 61.1008M
```
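The flops column can be reproduced by hand. The rules below were inferred from the rows of this table rather than taken from flopth's source, so treat them as an assumption: a Conv2d costs C_in × k × k multiply-accumulates per output element plus one bias add per output element, a Linear layer costs in_features × out_features, ReLU and pooling rows count one op per input element, and Dropout counts zero. In other words, flopth's "FLOPs" here sit on the multiply-accumulate scale, not the 2-ops-per-MAC scale. A pure-Python sketch using the shapes from the in_shape/out_shape columns:

```python
from math import prod

def conv2d_flops(c_in, k, out_shape):
    # c_in * k * k multiply-accumulates per output element,
    # plus one bias add per output element
    out_elems = prod(out_shape)
    return out_elems * c_in * k * k + out_elems

def linear_flops(n_in, n_out):
    # in_features * out_features multiply-accumulates (bias not counted)
    return n_in * n_out

def elementwise_flops(in_shape):
    # ReLU / pooling rows: one op per *input* element
    return prod(in_shape)

# Dropout rows (classifier.0, classifier.3) count 0 and are omitted
layers = {
    "features.0": conv2d_flops(3, 11, (64, 55, 55)),
    "features.1": elementwise_flops((64, 55, 55)),
    "features.2": elementwise_flops((64, 55, 55)),
    "features.3": conv2d_flops(64, 5, (192, 27, 27)),
    "features.4": elementwise_flops((192, 27, 27)),
    "features.5": elementwise_flops((192, 27, 27)),
    "features.6": conv2d_flops(192, 3, (384, 13, 13)),
    "features.7": elementwise_flops((384, 13, 13)),
    "features.8": conv2d_flops(384, 3, (256, 13, 13)),
    "features.9": elementwise_flops((256, 13, 13)),
    "features.10": conv2d_flops(256, 3, (256, 13, 13)),
    "features.11": elementwise_flops((256, 13, 13)),
    "features.12": elementwise_flops((256, 13, 13)),
    "avgpool": elementwise_flops((256, 6, 6)),
    "classifier.1": linear_flops(9216, 4096),
    "classifier.2": elementwise_flops((4096,)),
    "classifier.4": linear_flops(4096, 4096),
    "classifier.5": elementwise_flops((4096,)),
    "classifier.6": linear_flops(4096, 1000),
}

total = sum(layers.values())
print(layers["features.0"])  # 70470400  -> 70.4704M in the table
print(total)                 # 715552704 -> 715.553M in the table
```

The same rules also fit the custom-model output: each Conv2d(3, 3, kernel_size=3, padding=1) on a 3×224×224 input gives 150,528 outputs × 27 MACs + 150,528 bias adds = 4,214,784, the 4.21478M shown per layer.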
Running on custom models
```python
# file path: /tmp/my_model.py
# model name: MyModel
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x1):
        x1 = self.conv1(x1)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x1 = self.conv4(x1)
        return x1
```
```bash
$ flopth -m MyModel -p /tmp/my_model.py -i 3 224 224
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| module_name | module_type | in_shape | out_shape | params | params_percent | params_percent_vis | flops | flops_percent | flops_percent_vis |
+===============+===============+=============+=============+==========+==================+======================+==========+=================+=====================+
| conv1 | Conv2d | (3,224,224) | (3,224,224) | 84 | 25.0% | ############ | 4.21478M | 25.0% | ############ |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| conv2 | Conv2d | (3,224,224) | (3,224,224) | 84 | 25.0% | ############ | 4.21478M | 25.0% | ############ |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| conv3 | Conv2d | (3,224,224) | (3,224,224) | 84 | 25.0% | ############ | 4.21478M | 25.0% | ############ |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| conv4 | Conv2d | (3,224,224) | (3,224,224) | 84 | 25.0% | ############ | 4.21478M | 25.0% | ############ |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
FLOPs: 16.8591M
Params: 336.0
```
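flopth's -p flag points at a Python file and -m names the class inside it. If you want to load the same model class in your own script (for example to hand it to fvcore or thop), one standard-library way is importlib.util.spec_from_file_location. This is a sketch of that approach, not a description of how flopth loads the file internally; load_class_from_file is a hypothetical helper name.

```python
import importlib.util

def load_class_from_file(path, class_name, module_name="user_model"):
    # Hypothetical helper: build a module spec straight from the file path
    # (no sys.path changes needed)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    # Look the class up by name, analogous to the -m argument
    return getattr(module, class_name)
```

With torch installed, MyModel = load_class_from_file("/tmp/my_model.py", "MyModel") followed by model = MyModel() gives the same four-conv model profiled above.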
3. calflops
https://github.com/MrYxJ/calculate-flops.pytorch/tree/main
```bash
pip install calflops
```
```python
from calflops import calculate_flops
from torchvision import models

model = models.alexnet()
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(model=model,
                                      input_shape=input_shape,
                                      output_as_string=True,
                                      output_precision=4)
print("Alexnet FLOPs:%s MACs:%s Params:%s \n" % (flops, macs, params))
#Alexnet FLOPs:4.2892 GFLOPS MACs:2.1426 GMACs Params:61.1008 M
```
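calflops reports FLOPs and MACs separately. Under the usual convention one multiply-accumulate is one multiply plus one add, so FLOPs ≈ 2 × MACs: 2 × 2.1426 G = 4.2852 G, slightly below the 4.2892 G printed above, presumably because ops such as ReLU and pooling add FLOPs without adding MACs. The same convention puts fvcore's flops.total() and thop's macs on the MAC scale. The helpers below are hypothetical illustrations of the conversion and of SI-prefix formatting, not calflops' own API.

```python
def macs_to_flops(macs):
    # One multiply-accumulate = one multiply + one add = 2 floating-point ops
    return 2 * macs

def humanize(value, unit, precision=4):
    # Scale by the largest SI prefix that keeps the mantissa >= 1
    for factor, prefix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")):
        if abs(value) >= factor:
            return f"{value / factor:.{precision}f} {prefix}{unit}"
    return f"{value:.{precision}f} {unit}"

macs = 2.1426e9  # the MAC count calflops reports for AlexNet above
print(humanize(macs, "MACs"))                  # 2.1426 GMACs
print(humanize(macs_to_flops(macs), "FLOPs"))  # 4.2852 GFLOPs
```

When comparing numbers across the tools in this post, check which scale each one uses before concluding that one model is "twice as expensive" as another.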
4. thop (from thop import profile)
https://github.com/Lyken17/pytorch-OpCounter
```bash
pip install thop
```
```python
import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)
# thop reports multiply-accumulate operations (MACs), not FLOPs
macs, params = profile(model, inputs=(input, ))
```
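```python
import torch.nn as nn

class YourModule(nn.Module):
    # your definition (placeholder: element-wise square)
    def forward(self, x):
        return x * x

def count_your_model(m, x, y):
    # your rule here: x is the tuple of inputs, y the output; add the count to m.total_ops
    m.total_ops += torch.DoubleTensor([int(y.numel())])

input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ),
                       custom_ops={YourModule: count_your_model})
```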
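The hard part of a custom rule is working out the count from x (the tuple of inputs) and y (the output). As an illustration, assume a hypothetical custom module whose forward multiplies two tensors with torch.matmul: each output element then needs k multiply-accumulates, so the MAC count is batch × n × m × k. A pure-Python helper for that arithmetic, assuming both operands share the same batch dimensions (no broadcasting):

```python
from math import prod

def matmul_macs(a_shape, b_shape):
    # A: (*batch, n, k), B: (*batch, k, m) -> output (*batch, n, m);
    # every output element needs k multiply-accumulates.
    *batch, n, k = a_shape
    k_b, m = b_shape[-2:]
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    return prod(batch) * n * m * k  # prod([]) == 1 for plain 2-D matrices

print(matmul_macs((1, 196, 64), (1, 64, 196)))  # 2458624
```

Inside the rule it could be used as m.total_ops += torch.DoubleTensor([matmul_macs(x[0].shape, x[1].shape)]), since torch.Size is a tuple; for q @ k.transpose(-2, -1) with q of shape (1, 196, 64) that gives 196 × 196 × 64 = 2,458,624 MACs.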