Fix rfmod function
hanke580 committed Mar 3, 2020
1 parent 87b4f0b commit 7ff5708
Showing 6 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -1187,7 +1187,7 @@ def fmod(x1, x2, out=None, **kwargs):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
"""
-    return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, out)
+    return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out)


@set_module('mxnet.ndarray.numpy')
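
Before this change the call left no reverse-scalar op for _ufunc_helper, so fmod with the scalar on the left-hand side had nothing to dispatch to; the newly registered _npi.rfmod_scalar fills that slot. A minimal user-level sketch of what the fix enables (hypothetical session, assuming an MXNet build that includes this commit):

from mxnet import np, npx
npx.set_np()

x = np.array([3.0, 4.0, 7.0])
print(np.fmod(x, 3.0))   # tensor fmod scalar -> _npi.fmod_scalar, expected [0. 1. 1.]
print(np.fmod(10.0, x))  # scalar fmod tensor -> _npi.rfmod_scalar (the new path), expected [1. 2. 3.]
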
2 changes: 1 addition & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -1585,7 +1585,7 @@ def mod(x1, x2, out=None, **kwargs):
@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def fmod(x1, x2, out=None, **kwargs):
-    return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, out)
+    return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out)


@set_module('mxnet.symbol.numpy')
11 changes: 11 additions & 0 deletions src/operator/mshadow_op.h
@@ -818,6 +818,17 @@ struct fmod : public mxnet_op::tunable {
}
};

struct rfmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
    }
  }
};

template<>
MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
(mshadow::half::half2_t a,
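
The new rfmod kernel is fmod with its operands swapped: Map(a, b) computes fmod(b, a) in double precision, and a zero first operand short-circuits to 0 instead of triggering a division by zero. A rough, illustrative Python equivalent of the per-element behaviour:

import math

def rfmod(a, b):
    # Mirrors mshadow_op::rfmod::Map: operands are swapped relative to fmod,
    # and a zero first operand returns 0 instead of dividing by zero.
    return 0.0 if a == 0 else math.fmod(b, a)

print(rfmod(3.0, 10.0))  # fmod(10.0, 3.0) == 1.0
print(rfmod(0.0, 10.0))  # 0.0 (zero guard)
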
9 changes: 9 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cc
@@ -455,5 +455,14 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmod_scalar)
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::mod_grad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rfmod>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rfmod_scalar"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rfmod_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rmod_grad>);

} // namespace op
} // namespace mxnet
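
With the forward kernel (_npi_rfmod_scalar) and its gradient node (_backward_npi_rfmod_scalar, reusing the existing mshadow_op::rmod_grad kernel) registered on CPU, a scalar-on-the-left fmod can also participate in autograd. A hypothetical end-to-end check, assuming an MXNet build that includes this commit:

from mxnet import autograd, np, npx
npx.set_np()

x = np.array([1.5, 2.5, 4.0])
x.attach_grad()
with autograd.record():
    y = np.fmod(6.0, x)   # forward: _npi_rfmod_scalar
y.backward()              # backward: _backward_npi_rfmod_scalar
print(x.grad)
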
6 changes: 6 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cu
@@ -155,5 +155,11 @@ NNVM_REGISTER_OP(_npi_fmod_scalar)
NNVM_REGISTER_OP(_backward_npi_fmod_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::mod_grad>);

NNVM_REGISTER_OP(_npi_rfmod_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rfmod>);

NNVM_REGISTER_OP(_backward_npi_rfmod_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rmod_grad>);

} // namespace op
} // namespace mxnet
1 change: 1 addition & 0 deletions src/operator/operator_tune.cc
@@ -346,6 +346,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmod); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rfmod); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod); // NOLINT()
