From f9140213981b982b635f617cbd529e3b9090405a Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 20 Sep 2019 20:22:01 +0800 Subject: [PATCH 1/3] fix min max infinity value --- 3rdparty/mshadow/mshadow/base.h | 43 ++++++++++++++++++++++++-- tests/python/unittest/test_operator.py | 16 ++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index d08efd387c7e..3ce102c61ff5 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -645,6 +645,25 @@ MSHADOW_XINLINE int64_t MinValue(void) { return LLONG_MIN; } +/*! + * \brief negative infinity of certain types + * \tparam DType data type + */ +template +MSHADOW_XINLINE DType NegInfValue(void) { + return MinValue(); +} +/*! \brief negative infinity value of float */ +template<> +MSHADOW_XINLINE float NegInfValue(void) { + return -std::numeric_limits::infinity(); +} +/*! \brief negative infinity value of double */ +template<> +MSHADOW_XINLINE double NegInfValue(void) { + return -std::numeric_limits::infinity(); +} + /*! * \brief maximum value of certain types * \tparam DType data type @@ -686,6 +705,26 @@ template<> MSHADOW_XINLINE int64_t MaxValue(void) { return LLONG_MAX; } + +/*! + * \brief positive infinity of certain types + * \tparam DType data type + */ +template +MSHADOW_XINLINE DType PosInfValue(void) { + return MaxValue(); +} +/*! \brief positive infinity value of float */ +template<> +MSHADOW_XINLINE float PosInfValue(void) { + return std::numeric_limits::infinity(); +} +/*! \brief positive infinity value of double */ +template<> +MSHADOW_XINLINE double PosInfValue(void) { + return std::numeric_limits::infinity(); +} + } // namespace limits /*! \brief sum reducer */ @@ -793,7 +832,7 @@ struct maximum { */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = limits::MinValue(); + initv = limits::NegInfValue(); } /*! *\brief set the initial value during reduction @@ -849,7 +888,7 @@ struct minimum { */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = limits::MaxValue(); + initv = limits::PosInfValue(); } /*! *\brief set the initial value during reduction diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2e7cc3ce7504..51725dbde8c2 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9252,6 +9252,22 @@ def test_sample_normal_default_shape(): assert s.shape == (1, 1) +def test_min_max_inf(): + dtypes = [np.float32, np.double] + elem_list = [-1, 1, 0, np.inf, -np.inf] + + for dtype in dtypes: + for a in elem_list: + for b in elem_list: + data_np = np.array([a, b], dtype=dtype) + data_mx = mx.nd.array(data_np, dtype=dtype) + + min_data_np, max_data_np = data_np.min(), data_np.max() + min_data_mx, max_data_mx = data_mx.min(), data_mx.max() + + assert_array_equal(min_data_np, min_data_mx.asnumpy()) + + if __name__ == '__main__': import nose nose.runmodule() From 752d400139471df49418976cd6a06d173aa2f782 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 20 Sep 2019 20:30:35 +0800 Subject: [PATCH 2/3] add test maximum --- tests/python/unittest/test_operator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 51725dbde8c2..068aab939104 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9266,6 +9266,7 @@ def test_min_max_inf(): min_data_mx, max_data_mx = data_mx.min(), data_mx.max() assert_array_equal(min_data_np, min_data_mx.asnumpy()) + assert_array_equal(max_data_np, max_data_mx.asnumpy()) if __name__ == '__main__': From 521adf2270d7d876cc4e7a2f45f7b78c345d17e3 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 20 Sep 2019 20:52:16 +0800 Subject: [PATCH 3/3] use HUGE_VAL macro for nvcc --- 3rdparty/mshadow/mshadow/base.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index 3ce102c61ff5..1076f122c3ae 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -656,12 +656,12 @@ MSHADOW_XINLINE DType NegInfValue(void) { /*! \brief negative infinity value of float */ template<> MSHADOW_XINLINE float NegInfValue(void) { - return -std::numeric_limits::infinity(); + return -HUGE_VALF; } /*! \brief negative infinity value of double */ template<> MSHADOW_XINLINE double NegInfValue(void) { - return -std::numeric_limits::infinity(); + return -HUGE_VAL; } /*! @@ -717,12 +717,12 @@ MSHADOW_XINLINE DType PosInfValue(void) { /*! \brief positive infinity value of float */ template<> MSHADOW_XINLINE float PosInfValue(void) { - return std::numeric_limits::infinity(); + return HUGE_VALF; } /*! \brief positive infinity value of double */ template<> MSHADOW_XINLINE double PosInfValue(void) { - return std::numeric_limits::infinity(); + return HUGE_VAL; } } // namespace limits