From 2f41a39688bf5fe2f18d8481f9ae012fb6a05614 Mon Sep 17 00:00:00 2001 From: MORITA Kazutaka Date: Thu, 2 Apr 2020 07:49:37 +0900 Subject: [PATCH] [FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192) --- python/tvm/relay/frontend/mxnet.py | 2 +- tests/python/frontend/mxnet/test_forward.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b918f9b1adc5..5c8e726f7be3 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -510,7 +510,7 @@ def _mx_pad(inputs, attrs): pad_mode=pad_mode) def _mx_leaky_relu(inputs, attrs): - act_type = attrs.get_str("act_type") + act_type = attrs.get_str("act_type", "leaky") if act_type == "leaky": return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25)) if act_type == "prelu": diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 102905a408db..f01544715e0d 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -107,6 +107,14 @@ def test_forward_resnet(): mx_sym = model_zoo.mx_resnet(18) verify_mxnet_frontend_impl(mx_sym) +def test_forward_leaky_relu(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly + mx_sym = mx.sym.LeakyReLU(data) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + mx_sym = mx.sym.LeakyReLU(data, act_type='leaky') + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + def test_forward_elu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly @@ -979,6 +987,7 @@ def verify(x, shape, dtype): test_forward_mlp() test_forward_vgg() test_forward_resnet() + test_forward_leaky_relu() test_forward_elu() test_forward_rrelu() test_forward_prelu() @@ -1030,4 +1039,4 @@ def verify(x, shape, dtype): test_forward_deconvolution() test_forward_cond() test_forward_make_loss() - test_forward_unravel_index() \ No newline at end of file + test_forward_unravel_index()