diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
index b191553f16da..9aacbc02b061 100644
--- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -79,7 +79,9 @@ TBlob PrependAxes(const TBlob& src, const int dst_ndim) {
   return src.reshape(dst_shape);
 }
 
-struct TVMBinaryBroadcastCompute {
+
+template <typename xpu, typename OP>
+struct GetBinaryBroadcastCompute {
   const char* func;
   void operator()(const nnvm::NodeAttrs& attrs,
                   const mxnet::OpContext& ctx,
@@ -96,6 +98,38 @@ struct TVMBinaryBroadcastCompute {
     std::vector<int> type_codes;
     std::vector<TVMValue> values;
 
+    const TBlob& a = inputs[0];
+    const TBlob& b = inputs[1];
+    if (a.type_flag_ != b.type_flag_) {
+      if (outputs[0].shape_.Size() == 0U) return;
+      mxnet::TShape new_lshape, new_rshape, new_oshape;
+      const TBlob& lhs = inputs[0];
+      const TBlob& rhs = inputs[1];
+      const TBlob& out = outputs[0];
+      int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
+                                             &new_lshape, &new_rshape, &new_oshape);
+      if (!ndim) {
+        ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
+      } else {
+        if (req[0] == kNullOp) return;
+        mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
+          MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
+            BROADCAST_NDIM_SWITCH(ndim, NDim, {
+              mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+              mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+              mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+              template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                                lhs.dptr<DType>(), rhs.dptr<EType>(),
+                                out.dptr<bool>());
+            });
+          });
+        });
+      }
+      return;
+    }
+
     const int ondim = outputs[0].shape_.ndim();
     const size_t num_args = inputs.size() + outputs.size();
     type_codes.resize(num_args);
@@ -146,13 +180,15 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(logical_xor);
 
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name)                                    \
   NNVM_REGISTER_OP(_npi_##name)                                                              \
-  .set_attr<FCompute>("FCompute<cpu>", TVMBinaryBroadcastCompute{func_##name##_cpu})
+  .set_attr<FCompute>("FCompute<cpu>",                                                       \
+                      GetBinaryBroadcastCompute<cpu, mshadow_op::np_##name>{func_##name##_cpu})
 
 #if MXNET_USE_CUDA
 
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name)                                    \
   NNVM_REGISTER_OP(_npi_##name)                                                              \
-  .set_attr<FCompute>("FCompute<gpu>", TVMBinaryBroadcastCompute{func_##name##_gpu})
+  .set_attr<FCompute>("FCompute<gpu>",                                                       \
+                      GetBinaryBroadcastCompute<gpu, mshadow_op::np_##name>{func_##name##_gpu})
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal);
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 7ee46a0d315c..81449f358a70 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3143,17 +3143,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
         'mod': (1.0, 5.0, None, None),
         'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                             lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
+        'equal': (0.0, 2.0, None, None),
+        'not_equal': (0.0, 2.0, None, None),
+        'greater': (0.0, 2.0, None, None),
+        'less': (0.0, 2.0, None, None),
+        'greater_equal': (0.0, 2.0, None, None),
+        'less_equal': (0.0, 2.0, None, None),
+        'logical_and': (0.0, 2.0, None, None),
+        'logical_or': (0.0, 2.0, None, None),
+        'logical_xor': (0.0, 2.0, None, None),
     }
-    if not has_tvm_ops():
-        funcs['equal'] = (0.0, 2.0, None, None)
-        funcs['not_equal'] = (0.0, 2.0, None, None)
-        funcs['greater'] = (0.0, 2.0, None, None)
-        funcs['less'] = (0.0, 2.0, None, None)
-        funcs['greater_equal'] = (0.0, 2.0, None, None)
-        funcs['less_equal'] = (0.0, 2.0, None, None)
-        funcs['logical_and'] = (0.0, 2.0, None, None)
-        funcs['logical_or'] = (0.0, 2.0, None, None)
-        funcs['logical_xor'] = (0.0, 2.0, None, None)
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),