diff --git a/src/operator/pad-inl.h b/src/operator/pad-inl.h
index 520cd124c49a..0b43e2d0cfd2 100644
--- a/src/operator/pad-inl.h
+++ b/src/operator/pad-inl.h
@@ -189,6 +189,17 @@ class PadProp : public OperatorProperty {
     return param_.__DICT__();
   }
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    int dtype = (*in_type)[0];
+    type_assign(&dtype, (*out_type)[0]);
+
+    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+    return dtype != -1;
+  }
+
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
     using namespace mshadow;
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 11180ebbc5d4..ee6997f61dc4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2905,16 +2905,16 @@ def test_roipooling():
                            numeric_eps=1e-4, rtol=1e-1, atol=1E-4)
 
 
-def check_pad_with_shape(shape, xpu, pad_width, mode):
+def check_pad_with_shape(shape, xpu, pad_width, mode, dtype="float64"):
     # bind with label
-    X = mx.symbol.Variable('X')
+    X = mx.symbol.Variable('X', dtype=dtype)
     Y = mx.symbol.Pad(data=X, mode=mode, pad_width=pad_width)
-    x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
+    x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu(), dtype=dtype).copyto(xpu)
     # numpy result
     pad_grouped = list(zip(*[iter(list(pad_width))] * 2))
     np_out = np.pad(x.asnumpy(), pad_grouped, mode)
     # mxnet result
-    grad = mx.nd.empty(shape, ctx = xpu)
+    grad = mx.nd.empty(shape, ctx = xpu, dtype=dtype)
     exec1 = Y.bind(xpu, args = [x], args_grad = {'X': grad})
     exec1.forward(is_train=True)
     out = exec1.outputs[0].asnumpy()
@@ -2926,16 +2926,20 @@ def check_pad_with_shape(shape, xpu, pad_width, mode):
 
 @with_seed()
 def test_pad():
+    ctx = default_context()
     shape1 = (2, 3, 3, 5)
     pad1 = (0, 0, 0, 0, 1, 2, 3, 4)
     shape2 = (2, 3, 3, 5, 4)
     pad2 = (0, 0, 0, 0, 1, 2, 3, 4, 3, 1)
-    check_pad_with_shape(shape1, default_context(), pad1, 'constant')
-    check_pad_with_shape(shape1, default_context(), pad1, 'edge')
-    check_pad_with_shape(shape2, default_context(), pad2, 'constant')
-    check_pad_with_shape(shape2, default_context(), pad2, 'edge')
-    check_pad_with_shape(shape1, default_context(), pad1, 'reflect')
-    check_pad_with_shape(shape2, default_context(), pad2, 'reflect')
+    # note: this op doesn't support ints yet. Add tests when supported
+    dtypes = ["float16", "float32", "float64"]
+    for dtype in dtypes:
+        check_pad_with_shape(shape1, ctx, pad1, 'constant', dtype)
+        check_pad_with_shape(shape1, ctx, pad1, 'edge', dtype)
+        check_pad_with_shape(shape2, ctx, pad2, 'constant', dtype)
+        check_pad_with_shape(shape2, ctx, pad2, 'edge', dtype)
+        check_pad_with_shape(shape1, ctx, pad1, 'reflect', dtype)
+        check_pad_with_shape(shape2, ctx, pad2, 'reflect', dtype)
 
 
 def np_instance_norm(data, weight, bias, eps):