Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix isnan usage in RTC (#18984)
Browse files Browse the repository at this point in the history
* Fix isnan usage

* Add test
  • Loading branch information
ptrendx authored Aug 25, 2020
1 parent 3c4ac19 commit 8be953f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
50 changes: 25 additions & 25 deletions src/common/cuda/rtc/forward_functions-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ __device__ inline void store_add_index(const vector::VectorizedStorage<DType, nv
const char function_definitions_binary[] = R"code(
namespace op {
template <typename DType>
__device__ inline bool isnan(const DType val) {
return util::isnan(val);
}
template <typename DType>
__device__ inline bool_t isinf(const DType val) {
return util::isinf(val);
}
template <typename DType>
__device__ inline bool_t isposinf(const DType val) {
return util::isinf(val) && (val > 0);
}
template <typename DType>
__device__ inline bool_t isneginf(const DType val) {
return util::isinf(val) && (val < 0);
}
template <typename DType>
__device__ inline bool_t isfinite(const DType val) {
return !op::isnan(val) && !op::isinf(val);
}
template <typename DType, typename DType2>
__device__ inline typename type_util::mixed_type<DType, DType2>::type
add(const DType a, const DType2 b) {
Expand Down Expand Up @@ -867,31 +892,6 @@ __device__ inline bool_t np_logical_not(const DType val) {
return !static_cast<bool>(val);
}
template <typename DType>
__device__ inline bool isnan(const DType val) {
return util::isnan(val);
}
template <typename DType>
__device__ inline bool_t isinf(const DType val) {
return util::isinf(val);
}
template <typename DType>
__device__ inline bool_t isposinf(const DType val) {
return util::isinf(val) && (val > 0);
}
template <typename DType>
__device__ inline bool_t isneginf(const DType val) {
return util::isinf(val) && (val < 0);
}
template <typename DType>
__device__ inline bool_t isfinite(const DType val) {
return !op::isnan(val) && !op::isinf(val);
}
#undef DEFINE_UNARY_MATH_FUNC
template <typename DType>
Expand Down
10 changes: 6 additions & 4 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3030,12 +3030,14 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
'bitwise_and': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_or': (-100, 100, [None], None, [[_np.int32]]),
'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]),
'maximum': (-10, 10, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)],
[[_np.int32, _np.float16, _np.float32, _np.float64]]),
'fmax': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]),
'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 > x2)]),
'minimum': (-10, 10, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 > x2)],
[[_np.int32, _np.float16, _np.float32, _np.float64]]),
'fmin': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 > x2)]),
'copysign': (-1, 1,
Expand Down

0 comments on commit 8be953f

Please sign in to comment.