Skip to content

Commit

Permalink
Merge pull request #1981 from lucbv/half_cleanup
Browse files Browse the repository at this point in the history
Common: remove half and bhalf implementations
  • Loading branch information
lucbv authored Oct 17, 2023
2 parents c9093bb + caef00b commit 71be381
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 25 deletions.
2 changes: 2 additions & 0 deletions common/src/KokkosKernels_Half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//
//@HEADER

#if KOKKOS_VERSION < 40199
#ifndef KOKKOSKERNELS_HALF_HPP
#define KOKKOSKERNELS_HALF_HPP

Expand Down Expand Up @@ -61,3 +62,4 @@ namespace Experimental {
} // namespace Experimental
} // namespace KokkosKernels
#endif // KOKKOSKERNELS_HALF_HPP
#endif // KOKKOS_VERSION < 40199
133 changes: 115 additions & 18 deletions common/src/Kokkos_ArithTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include <Kokkos_MathematicalFunctions.hpp>
#include <Kokkos_Complex.hpp>
#include <Kokkos_Macros.hpp>
#if KOKKOS_VERSION < 40199
#include <KokkosKernels_Half.hpp>
#endif

#include <impl/Kokkos_QuadPrecisionMath.hpp>

Expand Down Expand Up @@ -197,8 +199,6 @@ KOKKOS_FORCEINLINE_FUNCTION IntType intPowUnsigned(const IntType x,
namespace Kokkos {

// Macro to automate the wrapping of Kokkos Mathematical Functions
// in the ArithTraits struct for real floating point types, hopefully
// this can be expanded to Kokkos::half_t and Kokkos::bhalf_t
#define KOKKOSKERNELS_ARITHTRAITS_REAL_FP(FUNC_QUAL) \
static FUNC_QUAL val_type zero() { return static_cast<val_type>(0); } \
static FUNC_QUAL val_type one() { return static_cast<val_type>(1); } \
Expand Down Expand Up @@ -279,6 +279,83 @@ namespace Kokkos {
static FUNC_QUAL val_type squareroot(const val_type x) { return sqrt(x); } \
static FUNC_QUAL mag_type eps() { return epsilon(); }

// Macro to automate the wrapping of Kokkos Mathematical Functions
#define KOKKOSKERNELS_ARITHTRAITS_HALF_FP(FUNC_QUAL) \
static FUNC_QUAL val_type zero() { return static_cast<val_type>(0); } \
static FUNC_QUAL val_type one() { return static_cast<val_type>(1); } \
static FUNC_QUAL val_type min() { \
return Kokkos::Experimental::finite_min<val_type>::value; \
} \
static FUNC_QUAL val_type max() { \
return Kokkos::Experimental::finite_max<val_type>::value; \
} \
static FUNC_QUAL val_type infinity() { \
return Kokkos::Experimental::infinity<val_type>::value; \
} \
static FUNC_QUAL val_type nan() { \
return Kokkos::Experimental::quiet_NaN<val_type>::value; \
} \
static FUNC_QUAL mag_type epsilon() { \
return Kokkos::Experimental::epsilon<val_type>::value; \
} \
static FUNC_QUAL mag_type sfmin() { \
return Kokkos::Experimental::norm_min<val_type>::value; \
} \
static FUNC_QUAL int base() { \
return Kokkos::Experimental::radix<val_type>::value; \
} \
static FUNC_QUAL mag_type prec() { \
return epsilon() * static_cast<mag_type>(base()); \
} \
static FUNC_QUAL int t() { \
return Kokkos::Experimental::digits<val_type>::value; \
} \
static FUNC_QUAL mag_type rnd() { return one(); } \
static FUNC_QUAL int emin() { \
return Kokkos::Experimental::min_exponent<val_type>::value; \
} \
static FUNC_QUAL mag_type rmin() { \
return Kokkos::Experimental::norm_min<val_type>::value; \
} \
static FUNC_QUAL int emax() { \
return Kokkos::Experimental::max_exponent<val_type>::value; \
} \
static FUNC_QUAL mag_type rmax() { \
return Kokkos::Experimental::finite_max<val_type>::value; \
} \
\
static FUNC_QUAL bool isInf(const val_type x) { return Kokkos::isinf(x); } \
static FUNC_QUAL mag_type abs(const val_type x) { return Kokkos::abs(x); } \
static FUNC_QUAL mag_type real(const val_type x) { return Kokkos::real(x); } \
static FUNC_QUAL mag_type imag(const val_type x) { return Kokkos::imag(x); } \
static FUNC_QUAL val_type conj(const val_type x) { return x; } \
static FUNC_QUAL val_type pow(const val_type x, const val_type y) { \
return Kokkos::pow(x, y); \
} \
static FUNC_QUAL val_type sqrt(const val_type x) { return Kokkos::sqrt(x); } \
static FUNC_QUAL val_type cbrt(const val_type x) { return Kokkos::cbrt(x); } \
static FUNC_QUAL val_type exp(const val_type x) { return Kokkos::exp(x); } \
static FUNC_QUAL val_type log(const val_type x) { return Kokkos::log(x); } \
static FUNC_QUAL val_type log10(const val_type x) { \
return Kokkos::log10(x); \
} \
static FUNC_QUAL val_type sin(const val_type x) { return Kokkos::sin(x); } \
static FUNC_QUAL val_type cos(const val_type x) { return Kokkos::cos(x); } \
static FUNC_QUAL val_type tan(const val_type x) { return Kokkos::tan(x); } \
static FUNC_QUAL val_type sinh(const val_type x) { return Kokkos::sinh(x); } \
static FUNC_QUAL val_type cosh(const val_type x) { return Kokkos::cosh(x); } \
static FUNC_QUAL val_type tanh(const val_type x) { return Kokkos::tanh(x); } \
static FUNC_QUAL val_type asin(const val_type x) { return Kokkos::asin(x); } \
static FUNC_QUAL val_type acos(const val_type x) { return Kokkos::acos(x); } \
static FUNC_QUAL val_type atan(const val_type x) { return Kokkos::atan(x); } \
\
static FUNC_QUAL magnitudeType magnitude(const val_type x) { \
return abs(x); \
} \
static FUNC_QUAL val_type conjugate(const val_type x) { return conj(x); } \
static FUNC_QUAL val_type squareroot(const val_type x) { return sqrt(x); } \
static FUNC_QUAL mag_type eps() { return epsilon(); }

#define KOKKOSKERNELS_ARITHTRAITS_CMPLX_FP(FUNC_QUAL) \
\
static constexpr bool is_specialized = true; \
Expand Down Expand Up @@ -912,8 +989,6 @@ class ArithTraits {
//@}
};

// Since Kokkos::Experimental::half_t falls back to float, only define
// ArithTraits if half_t is a backend specialization
#if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT
template <>
class ArithTraits<Kokkos::Experimental::half_t> {
Expand All @@ -926,8 +1001,9 @@ class ArithTraits<Kokkos::Experimental::half_t> {
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool is_complex = false;
static constexpr bool has_infinity = true;

static constexpr bool has_infinity = true;
#if KOKKOS_VERSION < 40199
static KOKKOS_FUNCTION val_type infinity() {
return Kokkos::Experimental::cast_to_half(
Kokkos::Experimental::infinity<float>::value);
Expand Down Expand Up @@ -1028,16 +1104,21 @@ class ArithTraits<Kokkos::Experimental::half_t> {
static KOKKOS_FUNCTION mag_type epsilon() {
return Kokkos::Experimental::cast_to_half(KOKKOSKERNELS_IMPL_FP16_EPSILON);
}
#endif

// Backwards compatibility with Teuchos::ScalarTraits.
using magnitudeType = mag_type;
// C++ doesn't have a standard "half-float" type.
using halfPrecision = val_type;
using doublePrecision = double;
using magnitudeType = mag_type;
using halfPrecision = Kokkos::Experimental::half_t;
using doublePrecision = float;

static std::string name() { return "half_t"; }

static constexpr bool isComplex = false;
static constexpr bool isOrdinal = false;
static constexpr bool isComparable = true;
static constexpr bool hasMachineParameters = true;

#if KOKKOS_VERSION < 40199
static KOKKOS_FUNCTION bool isnaninf(const val_type x) {
return isNan(x) || isInf(x);
}
Expand All @@ -1047,7 +1128,6 @@ class ArithTraits<Kokkos::Experimental::half_t> {
static KOKKOS_FUNCTION val_type conjugate(const val_type x) {
return conj(x);
}
static std::string name() { return "half"; }
static KOKKOS_FUNCTION val_type squareroot(const val_type x) {
return sqrt(x);
}
Expand Down Expand Up @@ -1077,8 +1157,15 @@ class ArithTraits<Kokkos::Experimental::half_t> {
static KOKKOS_FUNCTION mag_type rmax() {
return Kokkos::Experimental::cast_to_half(KOKKOSKERNELS_IMPL_FP16_MAX);
}
#else
#if defined(KOKKOS_ENABLE_SYCL) || defined(KOKKOS_ENABLE_HIP)
KOKKOSKERNELS_ARITHTRAITS_HALF_FP(KOKKOS_FUNCTION)
#else
KOKKOSKERNELS_ARITHTRAITS_REAL_FP(KOKKOS_FUNCTION)
#endif
#endif
};
#endif // KOKKOS_HALF_T_IS_FLOAT && KOKKOS_ENABLE_CUDA_HALF
#endif // #if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT

// Since Kokkos::Experimental::bhalf_t falls back to float, only define
// ArithTraits if bhalf_t is a backend specialization
Expand All @@ -1094,8 +1181,9 @@ class ArithTraits<Kokkos::Experimental::bhalf_t> {
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool is_complex = false;
static constexpr bool has_infinity = true;

static constexpr bool has_infinity = true;
#if KOKKOS_VERSION < 40199
static KOKKOS_FUNCTION val_type infinity() {
return Kokkos::Experimental::cast_to_bhalf(
Kokkos::Experimental::infinity<float>::value);
Expand Down Expand Up @@ -1193,16 +1281,23 @@ class ArithTraits<Kokkos::Experimental::bhalf_t> {
// return ::pow(2, -KOKKOSKERNELS_IMPL_BF16_SIGNIFICAND_BITS);
return Kokkos::Experimental::cast_to_bhalf(KOKKOSKERNELS_IMPL_BF16_EPSILON);
}
#endif

// Backwards compatibility with Teuchos::ScalarTraits.
using magnitudeType = mag_type;
// C++ doesn't have a standard "bhalf-float" type.
using bhalfPrecision = val_type;
using doublePrecision = double;
using magnitudeType = mag_type;
using bhalfPrecision = Kokkos::Experimental::bhalf_t;
// There is no type that has twice the precision as bhalf_t.
// The closest type would be float.
using doublePrecision = void;

static constexpr bool isComplex = false;
static constexpr bool isOrdinal = false;
static constexpr bool isComparable = true;
static constexpr bool hasMachineParameters = true;

static std::string name() { return "bhalf_t"; }

#if KOKKOS_VERSION < 40199
static KOKKOS_FUNCTION bool isnaninf(const val_type x) {
return isNan(x) || isInf(x);
}
Expand All @@ -1212,7 +1307,6 @@ class ArithTraits<Kokkos::Experimental::bhalf_t> {
static KOKKOS_FUNCTION val_type conjugate(const val_type x) {
return conj(x);
}
static std::string name() { return "bhalf"; }
static KOKKOS_FUNCTION val_type squareroot(const val_type x) {
return sqrt(x);
}
Expand Down Expand Up @@ -1242,8 +1336,11 @@ class ArithTraits<Kokkos::Experimental::bhalf_t> {
static KOKKOS_FUNCTION mag_type rmax() {
return Kokkos::Experimental::cast_to_bhalf(KOKKOSKERNELS_IMPL_BF16_MAX);
}
#else
KOKKOSKERNELS_ARITHTRAITS_REAL_FP(KOKKOS_FUNCTION)
#endif
};
#endif // KOKKOS_BHALF_T_IS_FLOAT
#endif // #if defined(KOKKOS_BHALF_T_IS_FLOAT) && !KOKKOS_BHALF_T_IS_FLOAT

template <>
class ArithTraits<float> {
Expand Down
80 changes: 73 additions & 7 deletions common/unit_test/Test_Common_ArithTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,20 @@ class ArithTraitsTesterBase {
}

if (AT::has_infinity) {
if (!AT::isInf(AT::infinity())) {
out << "AT::isInf (inf) != true" << endl;
FAILURE();
// Compiler intrinsic casts from inf of type half_t / bhalf_t to inf
// of type float in CUDA, SYCL and HIP do not work yet.
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_SYCL) || \
defined(KOKKOS_ENABLE_HIP)
namespace KE = Kokkos::Experimental;
if constexpr (!std::is_same<ScalarType, KE::half_t>::value &&
!std::is_same<ScalarType, KE::bhalf_t>::value) {
#else
{
#endif // KOKKOS_ENABLE_CUDA || KOKKOS_ENABLE_SYCL || KOKKOS_ENABLE_HIP
if (!AT::isInf(AT::infinity())) {
out << "AT::isInf (inf) != true" << endl;
FAILURE();
}
}
}
if (!std::is_same<ScalarType, decltype(AT::infinity())>::value) {
Expand Down Expand Up @@ -1495,13 +1506,24 @@ class ArithTraitsTesterFloatingPointBase<ScalarType, DeviceType, 0>
FAILURE();
}

if (!AT::isNan(AT::nan())) {
// Compiler intrinsic casts from nan of type half_t / bhalf_t to nan
// of type float in CUDA, SYCL and HIP do not work yet.
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_SYCL) || \
defined(KOKKOS_ENABLE_HIP)
namespace KE = Kokkos::Experimental;
if constexpr (!std::is_same<ScalarType, KE::half_t>::value &&
!std::is_same<ScalarType, KE::bhalf_t>::value) {
#else
{
#endif // KOKKOS_ENABLE_CUDA || KOKKOS_ENABLE_SYCL || KOKKOS_ENABLE_HIP
if (!AT::isNan(AT::nan())) {
#if KOKKOS_VERSION < 40199
KOKKOS_IMPL_DO_NOT_USE_PRINTF("NaN is not NaN\n");
KOKKOS_IMPL_DO_NOT_USE_PRINTF("NaN is not NaN\n");
#else
Kokkos::printf("NaN is not NaN\n");
Kokkos::printf("NaN is not NaN\n");
#endif
FAILURE();
FAILURE();
}
}

const ScalarType zero = AT::zero();
Expand All @@ -1523,6 +1545,27 @@ class ArithTraitsTesterFloatingPointBase<ScalarType, DeviceType, 0>
#endif
FAILURE();
}
#if defined(KOKKOS_ENABLE_SYCL) || \
defined(KOKKOS_ENABLE_HIP) // FIXME_SYCL, FIXME_HIP
if constexpr (!std::is_same_v<ScalarType, Kokkos::Experimental::half_t>) {
if (AT::isNan(zero)) {
#if KOKKOS_VERSION < 40199
KOKKOS_IMPL_DO_NOT_USE_PRINTF("0 is NaN\n");
#else
Kokkos::printf("0 is NaN\n");
#endif
FAILURE();
}
if (AT::isNan(one)) {
#if KOKKOS_VERSION < 40199
KOKKOS_IMPL_DO_NOT_USE_PRINTF("1 is NaN\n");
#else
Kokkos::printf("1 is NaN\n");
#endif
FAILURE();
}
}
#else
if (AT::isNan(zero)) {
#if KOKKOS_VERSION < 40199
KOKKOS_IMPL_DO_NOT_USE_PRINTF("0 is NaN\n");
Expand All @@ -1539,6 +1582,7 @@ class ArithTraitsTesterFloatingPointBase<ScalarType, DeviceType, 0>
#endif
FAILURE();
}
#endif

// Call the base class' implementation. Every subclass'
// implementation of operator() must do this, in order to include
Expand All @@ -1563,10 +1607,19 @@ class ArithTraitsTesterFloatingPointBase<ScalarType, DeviceType, 0>

// if (std::numeric_limits<ScalarType>::is_iec559) {
// success = success && AT::isInf (AT::inf ());
#if defined(KOKKOS_ENABLE_SYCL) || defined(KOKKOS_ENABLE_HIP)
if constexpr (!std::is_same_v<ScalarType, Kokkos::Experimental::half_t>) {
if (!AT::isNan(AT::nan())) {
out << "isNan or nan failed" << endl;
FAILURE();
}
}
#else
if (!AT::isNan(AT::nan())) {
out << "isNan or nan failed" << endl;
FAILURE();
}
#endif
//}

const ScalarType zero = AT::zero();
Expand All @@ -1580,6 +1633,18 @@ class ArithTraitsTesterFloatingPointBase<ScalarType, DeviceType, 0>
out << "isInf(one) is 1" << endl;
FAILURE();
}
#if defined(KOKKOS_ENABLE_SYCL) || defined(KOKKOS_ENABLE_HIP)
if constexpr (!std::is_same_v<ScalarType, Kokkos::Experimental::half_t>) {
if (AT::isNan(zero)) {
out << "isNan(zero) is 1" << endl;
FAILURE();
}
if (AT::isNan(one)) {
out << "isNan(one) is 1" << endl;
FAILURE();
}
}
#else
if (AT::isNan(zero)) {
out << "isNan(zero) is 1" << endl;
FAILURE();
Expand All @@ -1588,6 +1653,7 @@ class ArithTraitsTesterFloatingPointBase<ScalarType, DeviceType, 0>
out << "isNan(one) is 1" << endl;
FAILURE();
}
#endif

// Call the base class' implementation. Every subclass'
// implementation of testHostImpl() should (must) do this, in
Expand Down
2 changes: 2 additions & 0 deletions test_common/KokkosKernels_TestUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ class epsilon {
constexpr static double value = std::numeric_limits<T>::epsilon();
};

#if KOKKOS_VERSION < 40199
// explicit epsilon specializations
#if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT
template <>
Expand All @@ -428,6 +429,7 @@ class epsilon<Kokkos::Experimental::bhalf_t> {
constexpr static double value = 0.0078125F;
};
#endif // KOKKOS_HALF_T_IS_FLOAT
#endif // KOKKOS_VERSION < 40199

using KokkosKernels::Impl::getRandomBounds;

Expand Down

0 comments on commit 71be381

Please sign in to comment.