From d81e53cff2fb19516d405491eef57f23e8f36d48 Mon Sep 17 00:00:00 2001 From: Ke Han Date: Mon, 24 Feb 2020 13:49:35 +0800 Subject: [PATCH] * Fix rfmod function --- python/mxnet/ndarray/numpy/_op.py | 2 +- python/mxnet/symbol/numpy/_symbol.py | 2 +- src/operator/mshadow_op.h | 11 +++++++++++ .../numpy/np_elemwise_broadcast_op_extended.cc | 9 +++++++++ .../numpy/np_elemwise_broadcast_op_extended.cu | 6 ++++++ src/operator/operator_tune.cc | 1 + 6 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 91dfba7c7281..119b22959522 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1176,7 +1176,7 @@ def fmod(x1, x2, out=None, **kwargs): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. """ - return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, out) + return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out) @set_module('mxnet.ndarray.numpy') diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 71ced3844f90..c1e1104625ff 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -1580,7 +1580,7 @@ def mod(x1, x2, out=None, **kwargs): @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def fmod(x1, x2, out=None, **kwargs): - return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, out) + return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out) @set_module('mxnet.symbol.numpy') diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index b50256eecf9d..4d9de29ce709 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -806,6 +806,17 @@ struct fmod : public mxnet_op::tunable { } }; +struct rfmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (a == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(b), static_cast(a))); + } + } +}; + template<> MSHADOW_XINLINE mshadow::half::half2_t mod::Map (mshadow::half::half2_t a, diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 308e254c2b4d..3d79e3a4105a 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -455,5 +455,14 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmod_scalar) .set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) .set_attr("FCompute", BinaryScalarOp::Backward); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rfmod_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rfmod_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu index 3bb6b2d0aee6..8e3ec3db8784 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -155,5 +155,11 @@ NNVM_REGISTER_OP(_npi_fmod_scalar) NNVM_REGISTER_OP(_backward_npi_fmod_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); +NNVM_REGISTER_OP(_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 85a14765680f..b76e341b9fc6 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -346,6 +346,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmod); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rfmod); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod); // NOLINT()