From 75a9baeba263d05a99cef41890f1434f3647fea5 Mon Sep 17 00:00:00 2001 From: Ke Han Date: Mon, 24 Feb 2020 13:49:35 +0800 Subject: [PATCH] * Fix rfmod function --- python/mxnet/ndarray/numpy/_op.py | 2 +- python/mxnet/symbol/numpy/_symbol.py | 2 +- src/operator/mshadow_op.h | 11 +++++++++++ .../numpy/np_elemwise_broadcast_op_extended.cc | 9 +++++++++ .../numpy/np_elemwise_broadcast_op_extended.cu | 6 ++++++ src/operator/operator_tune.cc | 1 + 6 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 638c966dbb4f..7250a033bf80 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1189,7 +1189,7 @@ def fmod(x1, x2, out=None, **kwargs): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. """ - return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, out) + return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out) @set_module('mxnet.ndarray.numpy') diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index ee45847e6257..3e5ec36675f9 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -1589,7 +1589,7 @@ def mod(x1, x2, out=None, **kwargs): @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def fmod(x1, x2, out=None, **kwargs): - return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, out) + return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out) @set_module('mxnet.symbol.numpy') diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 947770cb8f45..9f732343d882 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -818,6 +818,17 @@ struct fmod : public mxnet_op::tunable { } }; +struct rfmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (a == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(b), static_cast(a))); + } + } +}; + template<> MSHADOW_XINLINE mshadow::half::half2_t mod::Map (mshadow::half::half2_t a, diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 308e254c2b4d..3d79e3a4105a 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -455,5 +455,14 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmod_scalar) .set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) .set_attr("FCompute", BinaryScalarOp::Backward); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rfmod_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rfmod_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu index 3bb6b2d0aee6..8e3ec3db8784 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -155,5 +155,11 @@ NNVM_REGISTER_OP(_npi_fmod_scalar) NNVM_REGISTER_OP(_backward_npi_fmod_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); +NNVM_REGISTER_OP(_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 85a14765680f..b76e341b9fc6 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -346,6 +346,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmod); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rfmod); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod); // NOLINT()