diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 984945f71868..a543f78bd949 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -790,6 +790,16 @@ def _mx_dot(inputs, attrs): def _mx_batch_dot(inputs, attrs): assert len(inputs) == 2 a, b = inputs + a_shape = _infer_type(a).checked_type.shape + batch_shapes = None + if len(a_shape) > 3: + batch_shapes = a_shape[:-2] + a = _op.reverse_reshape(a, newshape=(-1, 0, 0)) + b_shape = _infer_type(b).checked_type.shape + if len(b_shape) > 3: + if batch_shapes is None: + batch_shapes = b_shape[:-2] + b = _op.reverse_reshape(b, newshape=(-1, 0, 0)) transpose_a = attrs.get_bool("transpose_a", False) transpose_b = attrs.get_bool("transpose_b", False) if transpose_a is True: @@ -797,7 +807,10 @@ def _mx_batch_dot(inputs, attrs): raise tvm.error.OpAttributeInvalid(msg.format(transpose_a)) if transpose_b is False: b = _op.transpose(b, axes=[0, 2, 1]) - return _op.nn.batch_matmul(a, b) + out = _op.nn.batch_matmul(a, b) + if batch_shapes is not None: + out = _op.reverse_reshape(out, newshape=tuple(batch_shapes) + (0, 0)) + return out def _mx_arange(inputs, attrs): @@ -2294,18 +2307,16 @@ def _mx_npi_pad(inputs, attrs): raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.') if pad_mode not in ["constant", "edge", "reflect"]: raise tvm.error.OpAttributeInvalid("Value " + mode + ' in attribute "mode" is not valid') - pad_width = attrs.get_int_tuple("pad_width", None) - if pad_width is None: + if "pad_width" not in attrs.attrs: raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.') - if None in pad_width: - raise tvm.error.OpAttributeInvalid( - 'Value None in attribute "pad_width" of operator Slice is not valid.' - ) + # Begin to parse tuple of tuple, we cannot use get_int_tuple here because it's a tuple of tuple. + pad_width = attrs.attrs["pad_width"] + pad_width = pad_width.replace("(", "[") + pad_width = pad_width.replace(")", "]") + pad_width = json.loads(pad_width) constant_values = attrs.get_float("constant_values", 0.0) - padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2])) - return _op.nn.pad( - data=inputs[0], pad_width=padding, pad_value=constant_values, pad_mode=pad_mode + data=inputs[0], pad_width=pad_width, pad_value=constant_values, pad_mode=pad_mode ) @@ -2321,24 +2332,74 @@ def _mx_npx_reshape(inputs, attrs): shape = attrs.get_int_tuple("newshape") reverse = attrs.get_bool("reverse", False) shape_list = list(shape) - new_shape_list = [] - for num in shape_list: - if num > 0 or num == -1: - new_shape_list.append(num) - elif num == -2: - new_shape_list.append(0) - elif num == -4: - new_shape_list.append(-2) - elif num == -5: - new_shape_list.append(-3) - elif num == -6: - new_shape_list.append(-4) + old_shape = get_const_tuple(_infer_type(inputs[0]).checked_type.shape) + new_shape = [] + if reverse: + old_shape = old_shape[::-1] + shape_list = shape_list[::-1] + ptr = 0 + unknown_axis = None + src_ptr = 0 + while src_ptr < len(shape_list): + ele = shape_list[src_ptr] + src_ptr += 1 + if ele > 0: + new_shape.append(ele) + ptr += 1 + elif ele == -1: + new_shape.append(-1) + if unknown_axis is not None: + raise tvm.error.OpAttributeInvalid("Can only have one -1 in the input shape.") + unknown_axis = len(new_shape) + ptr += 1 + elif ele == -2: + new_shape.append(old_shape[ptr]) + ptr += 1 + elif ele == -3: + if old_shape[ptr] != 1: + raise tvm.error.OpAttributeInvalid( + "Dimension of the original shape " + "that corresponds to -3 must be 1. Received" + " {}".format(old_shape[ptr]) + ) + ptr += 1 + elif ele == -4: + new_shape += old_shape[ptr:] + break + elif ele == -5: + new_shape.append(old_shape[ptr] * old_shape[ptr + 1]) + ptr += 2 + elif ele == -6: + # Split axis + lhs = shape_list[src_ptr] + rhs = shape_list[src_ptr + 1] + src_ptr += 2 + if lhs == -1 and rhs == -1: + raise tvm.error.OpAttributeInvalid("The lhs and rhs can not both be -1.") + if lhs == -1: + if old_shape[ptr] % rhs != 0: + raise tvm.error.OpAttributeInvalid( + "When splitting the axis, " + "the dimension of the split axis must " + "be divisible by the splitted values." + ) + lhs = old_shape[ptr] // rhs + if rhs == -1: + if old_shape[ptr] % lhs != 0: + raise tvm.error.OpAttributeInvalid( + "When splitting the axis, " + "the dimension of the split axis must " + "be divisible by the splitted values." + ) + rhs = old_shape[ptr] // lhs + new_shape.append(lhs) + new_shape.append(rhs) + ptr += 1 else: - raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num) - shape = tuple(new_shape_list) + raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % ele) if reverse: - return _op.reverse_reshape(inputs[0], newshape=shape) - return _op.reshape(inputs[0], newshape=shape) + new_shape = new_shape[::-1] + return _op.reshape(inputs[0], newshape=new_shape) def _mx_split_v2(inputs, attrs): @@ -2356,12 +2417,21 @@ def _mx_split_v2(inputs, attrs): def _mx_npi_where_rscalar(inputs, attrs): + cond, dat = inputs scalar = attrs.get_float("scalar") - dtype = _infer_type(inputs[1]).checked_type.dtype + cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape) + dat_shape = get_const_tuple(_infer_type(dat).checked_type.shape) + dtype = _infer_type(dat).checked_type.dtype + # Check for broadcasting + out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape + if out_shape != cond_shape: + cond = _op.broadcast_to(cond, out_shape) + if out_shape != dat_shape: + dat = _op.broadcast_to(dat, out_shape) scalar = _expr.const(scalar, dtype=dtype) - ones = _op.ones_like(inputs[1]) + ones = _op.ones_like(dat) scalar = _op.multiply(ones, scalar) - return _op.where(inputs[0], inputs[1], scalar) + return _op.where(cond, dat, scalar) # Note: due to attribute conversion constraint @@ -2382,13 +2452,13 @@ def _mx_npi_where_rscalar(inputs, attrs): "reshape_like", "zeros_like", "ones_like", - "where", "cos", "cosh", "sin", "sinh", "tan", "tanh", + "where", ] _convert_map = { @@ -2609,6 +2679,7 @@ def _mx_npi_where_rscalar(inputs, attrs): "_npi_concatenate": _mx_npi_concatenate, "_npx_reshape": _mx_npx_reshape, "_np_copy": _rename(_op.copy), + "_npi_copy": _rename(_op.copy), "_npi_power": _rename(_op.power), "_npi_power_scalar": _binop_scalar(_op.power), "_npi_multiply": _rename(_op.multiply), @@ -2617,6 +2688,7 @@ def _mx_npi_where_rscalar(inputs, attrs): "_npi_add_scalar": _binop_scalar(_op.add), "_npi_where_rscalar": _mx_npi_where_rscalar, "_npi_less": _rename(_op.less), + "_npi_less_equal": _mx_compare(_op.less_equal, _rename), "_npi_tanh": _rename(_op.tanh), "_npi_true_divide_scalar": _binop_scalar(_op.divide), } @@ -2728,7 +2800,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): else: raise RuntimeError("unexpected type %s" % type(res)) node_map[nid] = res - outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _function.Function(analysis.free_vars(outputs), outputs) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index e3f08160509e..4e5f6efc815a 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -37,6 +37,9 @@ def batch_matmul(cfg, x, y, out_shape=None): 3-D with shape [batch, M, K] y : tvm.te.Tensor 3-D with shape [batch, N, K] + out_shape : tuple or None + Shape of the outputs + Returns ------- output : tvm.te.Tensor @@ -135,7 +138,7 @@ def _default_batch_matmul_config(cfg, M, N, K): @autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y): +def batch_matmul_cblas(cfg, x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -147,6 +150,9 @@ def batch_matmul_cblas(cfg, x, y): 3-D with shape [batch, M, K] y : tvm.te.Tensor 3-D with shape [batch, N, K] + out_shape : tuple or None + Shape of the output + Returns ------- output : tvm.te.Tensor @@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y): YB, N, YK = get_const_tuple(y.shape) assert XB == YB, "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistant" + if out_shape is not None: + assert out_shape[0] == XB, "got invalid output shape" + assert out_shape[1] == M, "got invalid output shape" + assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) return cblas.batch_matmul(x, y, False, True) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 44307f4e60fe..79c587fc7f9e 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1932,7 +1932,10 @@ def verify(data_shape, axis, use_length, length): @pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been publish yet") @pytest.mark.parametrize( "data_shape, pad_width", - [((1, 1, 3, 5), (0, 0, 0, 0, 1, 2, 3, 4)), ((1, 1, 3, 5, 7), (0, 0, 0, 0, 1, 2, 3, 4, 5, 6))], + [ + ((1, 1, 3, 5), ((0, 0), (0, 0), (1, 2), (3, 4))), + ((1, 1, 3, 5, 7), ((0, 0), (0, 0), (1, 2), (3, 4), (5, 6))), + ], ) @pytest.mark.parametrize("mode", ["constant", "edge", "reflect"]) @pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"]) @@ -1943,19 +1946,17 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar data_np = np.random.uniform(size=data_shape).astype(dtype) data = mx.sym.var("data") if mode == "constant": - ref_res = mx.ndarray.pad( - mx.nd.array(data_np), mode=mode, pad_width=pad_width, constant_value=constant_value - ) + ref_res = np.pad(data_np, mode=mode, pad_width=pad_width, constant_values=constant_value) mx_sym = mx.sym.np.pad( data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value ) else: - ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode, pad_width=pad_width) + ref_res = np.pad(data_np, mode=mode, pad_width=pad_width) mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data_np) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) @pytest.mark.skipif( @@ -2029,8 +2030,12 @@ def test_forward_np_copy(data_shape, dtype, target, ctx, kind): ((2, 3, 8), (-2, -2, 2, -1), False), ((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False), ((8, 3, 3, 3, 4, 4), (-5, -4), False), + ((1, 8, 3, 3, 3, 4, 4), (-3, -5, -4), False), + ((8, 1, 3, 4), (-2, -3, -1), False), ((8, 3, 3, 3, 3, 8), (-4, -5), True), ((8, 3, 2, 4, 8), (-4, -1, 2, -6), True), + ((3, 2, 4, 8, 1, 1), (-4, -1, 2, -6, -5, -3), True), + ((2, 4, 1, 8), (-4, -3, -1, 2, -6), True), ], ) def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind): @@ -2117,16 +2122,21 @@ def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind): @pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been publish yet") -@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)]) +@pytest.mark.parametrize( + "data_shape,cond_shape", + [[(2, 2, 2), (2, 2, 2)], [(2, 7, 2), (7, 2)], [(2, 2), (1, 2)], [(1, 3), (3, 3)]], +) @pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"]) @pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"]) @pytest.mark.parametrize("scalar", [1.0, 2.0]) @tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) -def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind): +def test_forward_npi_where_rscalar( + data_shape, cond_shape, data_dtype, cond_dtype, scalar, target, ctx, kind +): if data_dtype == "bool": scalar = scalar == 0.0 - cond_np = np.random.uniform(size=data_shape).astype(cond_dtype) + cond_np = np.random.uniform(size=cond_shape).astype(cond_dtype) data_np = np.random.uniform(size=data_shape).astype(data_dtype) cond = mx.sym.var("condition") data = mx.sym.var("x") @@ -2136,7 +2146,7 @@ def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, t dtypeDic["condition"] = cond_dtype dtypeDic["x"] = data_dtype mod, _ = relay.frontend.from_mxnet( - mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic + mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic ) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(cond_np, data_np) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 8c724daaa9d0..37a59c30f410 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -134,7 +134,7 @@ def check_binary_op(opfunc, ref, dtype): continue intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) - np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01, atol=1e-3) for opfunc, ref in [ (relay.add, np.add),