From 4b7ec59103153439ac09fbc73a22b85242a45bcd Mon Sep 17 00:00:00 2001 From: wicky Date: Sun, 5 Apr 2020 22:29:41 +0800 Subject: [PATCH 1/2] Relaxing type requirements for broadcast_like --- src/operator/tensor/broadcast_reduce_op_value.cc | 11 ++++++++++- tests/python/unittest/test_operator.py | 9 +++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc index 0a14a2008557..71be8f814f3b 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cc +++ b/src/operator/tensor/broadcast_reduce_op_value.cc @@ -138,7 +138,16 @@ NNVM_REGISTER_OP(broadcast_like) [](const NodeAttrs& attrs) { return std::vector{"lhs", "rhs"}; }) -.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name; + std::vector checked_in_attrs = { (*in_attrs)[0] }; + bool ret = !type_is_none((*in_attrs)[1]) && + ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs); + (*in_attrs)[0] = checked_in_attrs[0]; + return ret; +}) .set_attr("FGradient", [](const nnvm::ObjectPtr& n, const std::vector& ograds) { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6cbbc5dd0509..6b4c381c3344 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3119,6 +3119,15 @@ def test_reshape_like_different_types(): z = mx.nd.reshape_like(x, y) assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]]) +@with_seed() +def test_broadcast_like_different_types(): + x = mx.nd.zeros((2, 1)) + y = mx.nd.ones((2, 2)) + + y = mx.nd.array(y).astype('int32') + z = mx.nd.broadcast_like(x, y) + assert_allclose(z.asnumpy(), [[0,0],[0,0]]) + @with_seed() def test_flip(): for ndim in range(1, 6): From feef46cb8c35685267c02c8be545be298a18e168 Mon Sep 17 00:00:00 2001 From: wicky Date: Mon, 6 Apr 2020 21:08:06 +0800 Subject: [PATCH 2/2] enhance unit test --- tests/python/unittest/test_operator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6b4c381c3344..cb835339bee5 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3127,6 +3127,7 @@ def test_broadcast_like_different_types(): y = mx.nd.array(y).astype('int32') z = mx.nd.broadcast_like(x, y) assert_allclose(z.asnumpy(), [[0,0],[0,0]]) + assert x.dtype == z.dtype @with_seed() def test_flip():