From 5ea8c8a2ca0103657064d4e84f8820830af2af1b Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sun, 3 Mar 2019 10:24:20 -0800 Subject: [PATCH] [Relay][Frontend] Add a few mxnet ops in relay frontend (#2704) --- python/tvm/relay/frontend/mxnet.py | 79 +++++++++++++------- tests/python/frontend/mxnet/test_forward.py | 83 +++++++++++++++++++++ 2 files changed, 136 insertions(+), 26 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3d3bb8e4fd84e..1f1d18e240cd1 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -64,6 +64,13 @@ def _stable_softrelu(x): raise RuntimeError("Do not support act_type: {}".format(act_type)) +def _mx_compare(new_op, wrapper): + def impl(inputs, attrs): + dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype + return wrapper(new_op)(inputs, attrs).astype(dtype) + return impl + + def _mx_conv2d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: @@ -333,32 +340,52 @@ def _mx_roi_align(inputs, attrs): ] _convert_map = { - "_copy" : _rename(_op.copy), - "relu" : _rename(_op.nn.relu), - "broadcast_add" : _rename(_op.add), - "broadcast_sub" : _rename(_op.subtract), - "broadcast_mul" : _rename(_op.multiply), - "broadcast_div" : _rename(_op.divide), - "elemwise_add" : _rename(_op.add), - "elemwise_sub" : _rename(_op.subtract), - "elemwise_mul" : _rename(_op.multiply), - "elemwise_div" : _rename(_op.divide), - "flatten" : _rename(_op.nn.batch_flatten), - "Flatten" : _rename(_op.nn.batch_flatten), - "_plus_scalar" : _binop_scalar(_op.add), - "__add_scalar__": _binop_scalar(_op.add), - "__sub_scalar__": _binop_scalar(_op.subtract), - "_minus_scalar" : _binop_scalar(_op.subtract), - "__mul_scalar__": _binop_scalar(_op.multiply), - "_mul_scalar" : _binop_scalar(_op.multiply), - "__div_scalar__": _binop_scalar(_op.divide), - "_div_scalar" : _binop_scalar(_op.divide), - "__pow_scalar__": _binop_scalar(_op.power), - "_rminus_scalar": _rbinop_scalar(_op.subtract), - "__rsub_scalar__": _rbinop_scalar(_op.subtract), - "_rdiv_scalar" : _rbinop_scalar(_op.divide), - "__rdiv_scalar__" : _rbinop_scalar(_op.divide), - "__rpow_scalar__": _rbinop_scalar(_op.power), + "_copy" : _rename(_op.copy), + "relu" : _rename(_op.nn.relu), + "broadcast_add" : _rename(_op.add), + "broadcast_sub" : _rename(_op.subtract), + "broadcast_mul" : _rename(_op.multiply), + "broadcast_div" : _rename(_op.divide), + "broadcast_mod" : _rename(_op.mod), + "broadcast_maximum" : _rename(_op.maximum), + "broadcast_minimum" : _rename(_op.minimum), + "broadcast_equal" : _mx_compare(_op.equal, _rename), + "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename), + "broadcast_greater" : _mx_compare(_op.greater, _rename), + "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), + "broadcast_lesser" : _mx_compare(_op.less, _rename), + "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), + "elemwise_add" : _rename(_op.add), + "elemwise_sub" : _rename(_op.subtract), + "elemwise_mul" : _rename(_op.multiply), + "elemwise_div" : _rename(_op.divide), + "_maximum" : _rename(_op.maximum), + "_minimum" : _rename(_op.minimum), + "flatten" : _rename(_op.nn.batch_flatten), + "Flatten" : _rename(_op.nn.batch_flatten), + "__add_scalar__" : _binop_scalar(_op.add), + "_plus_scalar" : _binop_scalar(_op.add), + "__sub_scalar__" : _binop_scalar(_op.subtract), + "_minus_scalar" : _binop_scalar(_op.subtract), + "__mul_scalar__" : _binop_scalar(_op.multiply), + "_mul_scalar" : _binop_scalar(_op.multiply), + "__div_scalar__" : _binop_scalar(_op.divide), + "_div_scalar" : _binop_scalar(_op.divide), + "__pow_scalar__" : _binop_scalar(_op.power), + "_power_scalar" : _binop_scalar(_op.power), + "__rsub_scalar__" : _rbinop_scalar(_op.subtract), + "_rminus_scalar" : _rbinop_scalar(_op.subtract), + "__rdiv_scalar__" : _rbinop_scalar(_op.divide), + "_rdiv_scalar" : _rbinop_scalar(_op.divide), + "__rpow_scalar__" : _rbinop_scalar(_op.power), + "_equal_scalar" : _mx_compare(_op.equal, _binop_scalar), + "_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar), + "_greater_scalar" : _mx_compare(_op.greater, _binop_scalar), + "_greater_equal_scalar" : _mx_compare(_op.greater_equal, _binop_scalar), + "_lesser_scalar" : _mx_compare(_op.less, _binop_scalar), + "_lesser_equal_scalar" : _mx_compare(_op.less_equal, _binop_scalar), + "_maximum_scalar" : _binop_scalar(_op.maximum), + "_minimum_scalar" : _binop_scalar(_op.minimum), # reduction ops "max" : _reduce(_op.max), "min" : _reduce(_op.min), diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 671316079308d..ee47d72046ed4 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1,4 +1,5 @@ import numpy as np +import operator import tvm from tvm.contrib import graph_runtime @@ -256,6 +257,85 @@ def verify(start, stop, step): verify(20, 1, -1) verify(20, 1, -1.5) +def _mx_symbol(F, op_name, inputs): + op = getattr(F, op_name) + return op(*inputs) + +def test_forward_broadcast_ops(): + for op in ["broadcast_add", "broadcast_sub", "broadcast_mul", + "broadcast_div", "broadcast_mod", "broadcast_maximum", + "broadcast_minimum", "broadcast_equal", "broadcast_not_equal", + "broadcast_greater", "broadcast_greater_equal", + "broadcast_lesser", "broadcast_lesser_equal"]: + a_shape = (3, 4, 5) + b_shape = (4, 5) + if op == "broadcast_mod": + dtype = 'int32' + a_np = np.random.randint(1, 100, size=a_shape).astype(dtype) + b_np = np.random.randint(1, 100, size=b_shape).astype(dtype) + else: + dtype = 'float32' + a_np = np.random.uniform(size=a_shape).astype(dtype) + b_np = np.random.uniform(size=b_shape).astype(dtype) + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + shapes = {'a': a_shape, 'b': b_shape} + new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(a_np, b_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + +def test_forward_elemwise_ops(): + for op in ["elemwise_add", "elemwise_sub", "elemwise_mul", + "elemwise_div", "maximum", "minimum"]: + shape = (3, 4, 5) + dtype = 'float32' + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = np.random.uniform(size=shape).astype(dtype) + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + shapes = {'a': shape, 'b': shape} + new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(a_np, b_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + +def test_forward_scalar_ops(): + for op in [operator.add, operator.sub, operator.mul, operator.truediv, + operator.pow, operator.lt, operator.le, operator.eq, + operator.ne, operator.gt, operator.ge]: + dtype='float32' + a_shape = (3, 4, 5) + a_np = np.random.uniform(size=a_shape).astype(dtype) + b_scalar = 2.3 + mx_sym = op(mx.sym.var('a'), b_scalar) + ref_res = op(mx.nd.array(a_np), b_scalar) + shapes = {'a': a_shape} + new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + for op in ["maximum", "minimum"]: + dtype='float32' + a_shape = (3, 4, 5) + a_np = np.random.uniform(size=a_shape).astype(dtype) + b_scalar = 2.3 + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar]) + shapes = {'a': a_shape} + new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + if __name__ == '__main__': test_forward_mlp() @@ -280,3 +360,6 @@ def verify(start, stop, step): test_forward_argmin() test_forward_where() test_forward_arange() + test_forward_broadcast_ops() + test_forward_elemwise_ops() + test_forward_scalar_ops()