Skip to content

Commit

Permalink
[FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazum authored Apr 1, 2020
1 parent 302e8ee commit 2f41a39
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def _mx_pad(inputs, attrs):
pad_mode=pad_mode)

def _mx_leaky_relu(inputs, attrs):
act_type = attrs.get_str("act_type")
act_type = attrs.get_str("act_type", "leaky")
if act_type == "leaky":
return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25))
if act_type == "prelu":
Expand Down
11 changes: 10 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def test_forward_resnet():
mx_sym = model_zoo.mx_resnet(18)
verify_mxnet_frontend_impl(mx_sym)

def test_forward_leaky_relu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
mx_sym = mx.sym.LeakyReLU(data, act_type='leaky')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
Expand Down Expand Up @@ -979,6 +987,7 @@ def verify(x, shape, dtype):
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
test_forward_leaky_relu()
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
Expand Down Expand Up @@ -1030,4 +1039,4 @@ def verify(x, shape, dtype):
test_forward_deconvolution()
test_forward_cond()
test_forward_make_loss()
test_forward_unravel_index()
test_forward_unravel_index()

0 comments on commit 2f41a39

Please sign in to comment.