Skip to content

Commit

Permalink
add annotations and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CjhHa1 committed Dec 1, 2022
1 parent 4e8e653 commit 0f512ab
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
54 changes: 50 additions & 4 deletions python/paddle/fluid/tests/unittests/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,63 @@ def test_flops(self):
)
== 3 * 12 * 12 * 12 * 2 * 8
)
self.assertTrue(
flops('relu', {'X': [[12, 12, 12]]}, {}) == 12 * 12 * 12
)
self.assertTrue(
flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
)
self.assertTrue(
flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {})
== 0
)

self.assertTrue(
flops(
'elu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'leaky_relu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'prelu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'relu6',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'silu',
{
'X': [[12, 12]],
},
{},
)
== 144
)

if __name__ == '__main__':
paddle.enable_static()
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/utils/flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def _matmul_v2_flops(input_shapes, attrs):


def _relu_class_flops(input_shapes, attrs):
"""FLOPs computation for relu_like ops.
For elu/leaky_relu/prelu/relu/relu6/silu (input):
equation: flops = (numel)total number of elements in the input tensor.
"""
input = input_shapes.get('X')[0]
return prod(input)

Expand All @@ -237,10 +241,6 @@ def _prelu_flops(input_shapes, attrs):

@register_flops("relu")
def _relu_flops(input_shapes, attrs):
"""FLOPs computation for relu op.
For relu(input):
equation: flops = (numel)total number of elements in the input tensor.
"""
return _relu_class_flops(input_shapes, attrs)


Expand Down

0 comments on commit 0f512ab

Please sign in to comment.