diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index d08efd387c7e..1076f122c3ae 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -645,6 +645,25 @@ MSHADOW_XINLINE int64_t MinValue(void) { return LLONG_MIN; } +/*! + * \brief negative infinity of certain types + * \tparam DType data type + */ +template +MSHADOW_XINLINE DType NegInfValue(void) { + return MinValue(); +} +/*! \brief negative infinity value of float */ +template<> +MSHADOW_XINLINE float NegInfValue(void) { + return -HUGE_VALF; +} +/*! \brief negative infinity value of double */ +template<> +MSHADOW_XINLINE double NegInfValue(void) { + return -HUGE_VAL; +} + /*! * \brief maximum value of certain types * \tparam DType data type @@ -686,6 +705,26 @@ template<> MSHADOW_XINLINE int64_t MaxValue(void) { return LLONG_MAX; } + +/*! + * \brief positive infinity of certain types + * \tparam DType data type + */ +template +MSHADOW_XINLINE DType PosInfValue(void) { + return MaxValue(); +} +/*! \brief positive infinity value of float */ +template<> +MSHADOW_XINLINE float PosInfValue(void) { + return HUGE_VALF; +} +/*! \brief positive infinity value of double */ +template<> +MSHADOW_XINLINE double PosInfValue(void) { + return HUGE_VAL; +} + } // namespace limits /*! \brief sum reducer */ @@ -793,7 +832,7 @@ struct maximum { */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = limits::MinValue(); + initv = limits::NegInfValue(); } /*! *\brief set the initial value during reduction @@ -849,7 +888,7 @@ struct minimum { */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = limits::MaxValue(); + initv = limits::PosInfValue(); } /*! *\brief set the initial value during reduction diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2e7cc3ce7504..068aab939104 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9252,6 +9252,23 @@ def test_sample_normal_default_shape(): assert s.shape == (1, 1) +def test_min_max_inf(): + dtypes = [np.float32, np.double] + elem_list = [-1, 1, 0, np.inf, -np.inf] + + for dtype in dtypes: + for a in elem_list: + for b in elem_list: + data_np = np.array([a, b], dtype=dtype) + data_mx = mx.nd.array(data_np, dtype=dtype) + + min_data_np, max_data_np = data_np.min(), data_np.max() + min_data_mx, max_data_mx = data_mx.min(), data_mx.max() + + assert_array_equal(min_data_np, min_data_mx.asnumpy()) + assert_array_equal(max_data_np, max_data_mx.asnumpy()) + + if __name__ == '__main__': import nose nose.runmodule()