计算 Pytorch 和 Tensorflow 中的 Flops 不相等?

问题描述

给定相同的模型,我发现pytorch和tensorflow中计算的flops是不同的。 我在 tensorflow 中使用了 keras_flops (https://pypi.org/project/keras-flops/),在 pytorch 中使用了 ptflops (https://pypi.org/project/ptflops/) 来计算 flops。 pytorch 中的 flops 看来是靠我自己亲手计算的了。 tensorflow 是否有一些技巧来加速计算,以便测量很少的触发器? 我在张量流中的模型

d=56
s=12

inp = Input((750,750,1))
x = Conv2D(d,(5,5),padding='same')(inp)
x = PReLU()(x)

x = Conv2D(s,(1,1),padding='valid')(x)
x = PReLU()(x)

x = Conv2D(s,(3,3),padding='same')(x)
x = PReLU()(x)
x = Conv2D(s,padding='same')(x)
x = PReLU()(x)

x = Conv2D(d,padding='same')(x)
x = PReLU()(x)
out = Conv2DTranspose(1,(9,9),strides=(4,4),padding='same',output_padding = 3)(x)

tensorflow 中的 Flops 输出为: 简介:

node name | # float_ops
Conv2D                   8.92b float_ops (100.00%,61.95%)
Conv2DBackpropInput      5.10b float_ops (38.05%,35.44%)
Neg                      180.00m float_ops (2.61%,1.25%)
BiasAdd                  105.75m float_ops (1.36%,0.73%)
Mul                      90.00m float_ops (0.63%,0.63%)

======================End of Report==========================
The FLOPs is:14.3 GFlops

然而,pytorch 中的 FLops 是

Model_1(
  0.013 M,100.000% Params,45.486 GMac,100.000% MACs,(begin): Sequential(
    0.002 M,11.804% Params,0.851 GMac,1.870% MACs,(0): Conv2d(0.001 M,11.367% Params,0.819 GMac,1.801% MACs,1,56,kernel_size=(5,stride=(1,padding=(2,2))
    (1): PReLU(0.0 M,0.437% Params,0.032 GMac,0.069% MACs,num_parameters=56)
  )
  (middle): Sequential(
    0.007 M,52.775% Params,3.803 GMac,8.360% MACs,5.340% Params,0.385 GMac,0.846% MACs,12,kernel_size=(1,1))
    (1): PReLU(0.0 M,0.094% Params,0.007 GMac,0.015% MACs,num_parameters=12)
    (2): Conv2d(0.001 M,10.212% Params,0.736 GMac,1.618% MACs,kernel_size=(3,padding=(1,1))
    (3): PReLU(0.0 M,num_parameters=12)
    (4): Conv2d(0.001 M,1))
    (5): PReLU(0.0 M,num_parameters=12)
    (6): Conv2d(0.001 M,1))
    (7): PReLU(0.0 M,num_parameters=12)
    (8): Conv2d(0.001 M,1))
    (9): PReLU(0.0 M,num_parameters=12)
    (10): Conv2d(0.001 M,5.684% Params,0.409 GMac,0.900% MACs,1))
    (11): PReLU(0.0 M,num_parameters=56)
  )
  (final): ConvTranspose2d(0.005 M,35.420% Params,40.833 GMac,89.770% MACs,kernel_size=(9,stride=(4,padding=(4,output_padding=(3,3))
)
computational complexity:       45.49 GMac

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...