diff --git a/src/flit/flitHelpers.h b/src/flit/flitHelpers.h index 1fc01091..5e1f9e23 100644 --- a/src/flit/flitHelpers.h +++ b/src/flit/flitHelpers.h @@ -93,6 +93,7 @@ #include #include #include +#include #include #include #include @@ -102,6 +103,7 @@ #include #include +#include #ifndef FLIT_UNUSED #define FLIT_UNUSED(x) (void)x @@ -254,6 +256,24 @@ as_int(long double val) { return temp & (~zero >> 48); } +template +bool equal_with_nan_inf(T a, T b) { + if (std::fpclassify(a) == std::fpclassify(b)) { + switch (std::fpclassify(a)) { + case FP_INFINITE: + case FP_NAN: + return std::signbit(a) == std::signbit(b); + + case FP_NORMAL: + case FP_SUBNORMAL: + case FP_ZERO: + default: + return a == b; + } + } + return false; +} + /** * Default comparison used by FLiT. Similar to * @@ -266,7 +286,12 @@ as_int(long double val) { */ template T abs_compare(T expected, T actual) { - // TODO: implement all other cases + if (equal_with_nan_inf(expected, actual)) { + return T(0.0); + } + if (std::isnan(expected) && std::isinf(actual)) { + return std::numeric_limits::infinity(); + } return std::abs(actual - expected); } diff --git a/tests/flit_src/tst_flitHelpers_h.cpp b/tests/flit_src/tst_flitHelpers_h.cpp index d90b0ee6..e9c0b964 100644 --- a/tests/flit_src/tst_flitHelpers_h.cpp +++ b/tests/flit_src/tst_flitHelpers_h.cpp @@ -280,24 +280,6 @@ TH_TEST(tst_as_int_128bit) { namespace tst_abs_compare { -template -bool equal_with_nan_inf(T a, T b) { - if (std::fpclassify(a) == std::fpclassify(b)) { - switch (std::fpclassify(a)) { - case FP_INFINITE: - case FP_NAN: - return std::signbit(a) == std::signbit(b); - - case FP_NORMAL: - case FP_SUBNORMAL: - case FP_ZERO: - default: - return a == b; - } - } - return false; -} - template void tst_equal_with_nan_inf_impl() { using lim = std::numeric_limits; @@ -305,7 +287,7 @@ void tst_equal_with_nan_inf_impl() { static_assert(lim::has_quiet_NaN); static_assert(lim::has_infinity); - auto eq = equal_with_nan_inf; + auto &eq = flit::equal_with_nan_inf; T my_nan = lim::quiet_NaN(); T my_inf = lim::infinity(); T normal = -3.2; @@ -370,8 +352,8 @@ void tst_abs_compare_impl() { T normal = -3.2; T zero = 0.0; - auto eq = equal_with_nan_inf; - auto comp = flit::abs_compare; + auto &eq = flit::equal_with_nan_inf; + auto &comp = flit::abs_compare; // we have 25 cases TH_VERIFY(eq(comp( my_nan, my_nan), zero ));