From 0f512ab1eb7850f0a2304ff435a88fa96cdf28ee Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Mon, 28 Nov 2022 11:50:44 +0800
Subject: [PATCH] add annotations and tests

---
 .../fluid/tests/unittests/test_profiler.py | 54 +++++++++++++++++--
 python/paddle/utils/flops.py               |  8 +--
 2 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py
index 5fbedfaaa7ff0c..f5432b0585dc36 100644
--- a/python/paddle/fluid/tests/unittests/test_profiler.py
+++ b/python/paddle/fluid/tests/unittests/test_profiler.py
@@ -291,9 +291,6 @@ def test_flops(self):
             )
             == 3 * 12 * 12 * 12 * 2 * 8
         )
-        self.assertTrue(
-            flops('relu', {'X': [[12, 12, 12]]}, {}) == 12 * 12 * 12
-        )
         self.assertTrue(
             flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
         )
@@ -301,7 +298,56 @@ def test_flops(self):
             flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {})
             == 0
         )
-
+        self.assertTrue(
+            flops(
+                'elu',
+                {
+                    'X': [[12, 12]],
+                },
+                {},
+            )
+            == 144
+        )
+        self.assertTrue(
+            flops(
+                'leaky_relu',
+                {
+                    'X': [[12, 12]],
+                },
+                {},
+            )
+            == 144
+        )
+        self.assertTrue(
+            flops(
+                'prelu',
+                {
+                    'X': [[12, 12]],
+                },
+                {},
+            )
+            == 144
+        )
+        self.assertTrue(
+            flops(
+                'relu6',
+                {
+                    'X': [[12, 12]],
+                },
+                {},
+            )
+            == 144
+        )
+        self.assertTrue(
+            flops(
+                'silu',
+                {
+                    'X': [[12, 12]],
+                },
+                {},
+            )
+            == 144
+        )
 
 if __name__ == '__main__':
     paddle.enable_static()
diff --git a/python/paddle/utils/flops.py b/python/paddle/utils/flops.py
index a29de61c5d4794..ff4e576d08060d 100644
--- a/python/paddle/utils/flops.py
+++ b/python/paddle/utils/flops.py
@@ -216,6 +216,10 @@ def _matmul_v2_flops(input_shapes, attrs):
 
 
 def _relu_class_flops(input_shapes, attrs):
+    """FLOPs computation for relu-like ops.
+    For elu/leaky_relu/prelu/relu/relu6/silu(input):
+    equation: flops = numel, the total number of elements in the input tensor.
+    """
     input = input_shapes.get('X')[0]
     return prod(input)
 
@@ -237,10 +241,6 @@ def _prelu_flops(input_shapes, attrs):
 
 @register_flops("relu")
 def _relu_flops(input_shapes, attrs):
-    """FLOPs computation for relu op.
-    For relu(input):
-    equation: flops = (numel)total number of elements in the input tensor.
-    """
    return _relu_class_flops(input_shapes, attrs)
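
Reviewer note: the rule the new docstring documents is simple enough to state in a few
lines. Below is a minimal standalone sketch of it, not Paddle's implementation; the
helper name numel_flops is hypothetical, but the shape-dictionary convention
({'X': [[12, 12]]}) and the expected value 144 come straight from the tests above.

    from functools import reduce
    from operator import mul

    def numel_flops(input_shapes, attrs):
        # For elu/leaky_relu/prelu/relu/relu6/silu, FLOPs equal the number
        # of elements (numel) in the input tensor X; attrs is unused.
        shape = input_shapes.get('X')[0]  # e.g. {'X': [[12, 12]]} -> [12, 12]
        return reduce(mul, shape, 1)      # product of the dimensions = numel

    assert numel_flops({'X': [[12, 12]]}, {}) == 144  # matches the new unit tests

In the patched code this computation lives once in _relu_class_flops; each op's
handler (e.g. _relu_flops, registered via the @register_flops decorator) simply
delegates to it, which is why the docstring on _relu_flops could move to the
shared helper.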