
Commit 73d8838: rectify

Yiyan66 committed Jul 23, 2020 (1 parent: c6ec8d8)
Showing 2 changed files with 48 additions and 13 deletions.
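The one-word commit message is terse, so a summary for orientation: the TVM-only functor TVMBinaryBroadcastCompute is replaced by a templated GetBinaryBroadcastCompute<xpu, OP> which, when the two operands have different dtypes, skips the TVM kernel and dispatches a native mshadow broadcast kernel (or a plain elementwise logic compute when no broadcasting is needed); the unit test then lists the nine comparison/logical ops unconditionally instead of only when TVM ops are unavailable.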
42 changes: 39 additions & 3 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -79,7 +79,9 @@ TBlob PrependAxes(const TBlob& src, const int dst_ndim) {
   return src.reshape(dst_shape);
 }
 
-struct TVMBinaryBroadcastCompute {
+
+template<typename xpu, typename OP>
+struct GetBinaryBroadcastCompute {
   const char* func;
   void operator()(const nnvm::NodeAttrs& attrs,
                   const mxnet::OpContext& ctx,
@@ -96,6 +98,38 @@ struct TVMBinaryBroadcastCompute {
     std::vector<int> type_codes;
     std::vector<TVMValue> values;
 
+    const TBlob& a = inputs[0];
+    const TBlob& b = inputs[1];
+    if (a.type_flag_ != b.type_flag_) {
+      if (outputs[0].shape_.Size() == 0U) return;
+      mxnet::TShape new_lshape, new_rshape, new_oshape;
+      const TBlob& lhs = inputs[0];
+      const TBlob& rhs = inputs[1];
+      const TBlob& out = outputs[0];
+      int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
+                                             &new_lshape, &new_rshape, &new_oshape);
+      if (!ndim) {
+        ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
+      } else {
+        if (req[0] == kNullOp) return;
+        mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
+          MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
+            BROADCAST_NDIM_SWITCH(ndim, NDim, {
+              mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+              mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+              mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+                template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                                  lhs.dptr<DType>(), rhs.dptr<EType>(),
+                                  out.dptr<bool>());
+            });
+          });
+        });
+      }
+      return;
+    }
+
     const int ondim = outputs[0].shape_.ndim();
     const size_t num_args = inputs.size() + outputs.size();
     type_codes.resize(num_args);
@@ -146,13 +180,15 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(logical_xor);
 
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name)                    \
   NNVM_REGISTER_OP(_npi_##name)                                              \
-  .set_attr<FCompute>("FCompute<cpu>", TVMBinaryBroadcastCompute{func_##name##_cpu})
+  .set_attr<FCompute>("FCompute<cpu>", GetBinaryBroadcastCompute<cpu,        \
+                      mshadow_op::np_##name>{func_##name##_cpu})
 
 #if MXNET_USE_CUDA
 
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name)                    \
   NNVM_REGISTER_OP(_npi_##name)                                              \
-  .set_attr<FCompute>("FCompute<gpu>", TVMBinaryBroadcastCompute{func_##name##_gpu})
+  .set_attr<FCompute>("FCompute<gpu>", GetBinaryBroadcastCompute<gpu,        \
+                      mshadow_op::np_##name>{func_##name##_gpu})
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal);
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal);
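In outline, the new branch covers the case the TVM kernels did not: when the operand dtypes differ, it computes compact broadcast shapes, then either runs the elementwise logic compute (no broadcasting needed) or launches binary_broadcast_kernel with per-operand strides, always writing a bool output. A minimal sketch of the user-visible behavior this enables, assuming the mxnet.numpy front end on a build that includes this change (the array values are illustrative):

from mxnet import np

a = np.array([1, 2, 3], dtype='int32')          # int32 operand
b = np.array([1.0, 0.0, 3.0], dtype='float32')  # float32 operand: dtypes differ
print(np.equal(a, b))        # boolean result via the native fallback path
print(np.logical_and(a, b))  # mixed-dtype logical ops take the same path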
19 changes: 9 additions & 10 deletions tests/python/unittest/test_numpy_op.py
@@ -3143,17 +3143,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
         'mod': (1.0, 5.0, None, None),
         'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                   lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
+        'equal': (0.0, 2.0, None, None),
+        'not_equal': (0.0, 2.0, None, None),
+        'greater': (0.0, 2.0, None, None),
+        'less': (0.0, 2.0, None, None),
+        'greater_equal': (0.0, 2.0, None, None),
+        'less_equal': (0.0, 2.0, None, None),
+        'logical_and': (0.0, 2.0, None, None),
+        'logical_or': (0.0, 2.0, None, None),
+        'logical_xor': (0.0, 2.0, None, None),
     }
-    if not has_tvm_ops():
-        funcs['equal'] = (0.0, 2.0, None, None)
-        funcs['not_equal'] = (0.0, 2.0, None, None)
-        funcs['greater'] = (0.0, 2.0, None, None)
-        funcs['less'] = (0.0, 2.0, None, None)
-        funcs['greater_equal'] = (0.0, 2.0, None, None)
-        funcs['less_equal'] = (0.0, 2.0, None, None)
-        funcs['logical_and'] = (0.0, 2.0, None, None)
-        funcs['logical_or'] = (0.0, 2.0, None, None)
-        funcs['logical_xor'] = (0.0, 2.0, None, None)
 
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
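With a native compute path now registered for every logic op, the test no longer gates these entries on the absence of TVM; each (low, high, lgrad, rgrad) tuple joins the dict unconditionally. For reference, a sketch of the feature probe the removed has_tvm_ops() gate performs; the real helper lives in MXNet's test utilities, and this reconstruction via mxnet.runtime and the "TVM_OP" flag name is an assumption:

from mxnet.runtime import Features

def has_tvm_ops():
    # True when MXNet was compiled with TVM-generated operators (assumed flag name)
    return Features().is_enabled("TVM_OP")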
