Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gamma: Align data type in computation with the declaration of the helper #837

Merged
merged 7 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions src/ATen/native/xpu/sycl/Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ namespace at::native::xpu {
* For licensing information, please refer to the cpu implementation located in
* "ATen/native/Math.h".
*/
template <typename scalar_t>
template <typename scalar_t, typename pi_t = double>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function]
// https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
static const double PI_f64 = 3.14159265358979323846;
static const pi_t PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
8.33333333333333333333E-2,
Expand All @@ -27,15 +27,15 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
};

accscalar_t x = static_cast<accscalar_t>(in);
if (x == 0) {
if (x == accscalar_t(0)) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(static_cast<scalar_t>(INFINITY), -x);
}

bool x_is_integer = x == std::trunc(x);
accscalar_t result = 0;
if (x < 0) {
if (x < accscalar_t(0)) {
if (x_is_integer) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
Expand All @@ -46,23 +46,23 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// mathematically equivalent since both x and r are in radians and tan() has
// a periodicity of pi, in practice the computation of pi * x is a source of
// error (when |x| > 1).
double q, r;
r = std::modf(static_cast<double>(x), &q);
pi_t q, r;
r = std::modf(static_cast<pi_t>(x), &q);
result = static_cast<accscalar_t>(-PI_f64 / std::tan(PI_f64 * r));
x = 1 - x;
}

while (x < 10) {
while (x < accscalar_t(10)) {
result -= 1 / x;
x += 1;
}
if (x == 10) {
if (x == accscalar_t(10)) {
return static_cast<scalar_t>(result + PSI_10);
}

accscalar_t y = 0;
if (x < 1.0e17) {
accscalar_t z = 1 / (x * x);
if (x < accscalar_t(1.0e17)) {
accscalar_t z = accscalar_t(1) / (x * x);

accscalar_t polevl_result = 0;
for (int i = 0; i <= 6; i++) {
Expand All @@ -82,20 +82,23 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
accscalar_t x = static_cast<accscalar_t>(in);
accscalar_t sign = +1;
accscalar_t result = 0;
if (x < 0.5f) {
if (x < accscalar_t(0.5)) {
sign = -1;
accscalar_t sin_pi_x = std::sin(PI * x);
result -= (PI * PI) / (sin_pi_x * sin_pi_x);
x = 1 - x;
x = accscalar_t(1) - x;
}
for (int i = 0; i < 6; ++i) {
result += 1 / (x * x);
x += 1;
result += accscalar_t(1) / (x * x);
x += accscalar_t(1);
}
const accscalar_t one = static_cast<scalar_t>(1);
const accscalar_t ixx = 1 / (x * x);
result += (1 + 1 / (2 * x) +
ixx * (one / 6 - ixx * (one / 30 - ixx * (one / 42)))) /
const accscalar_t one = accscalar_t(1);
const accscalar_t ixx = accscalar_t(1) / (x * x);
result +=
(accscalar_t(1) + accscalar_t(1) / (accscalar_t(2) * x) +
ixx *
(one / accscalar_t(6) -
ixx * (one / accscalar_t(30) - ixx * (one / accscalar_t(42))))) /
x;
return static_cast<scalar_t>(sign * result);
}
Expand All @@ -122,7 +125,7 @@ chbevl(scalar_t _x, const scalar_t array[], size_t len) {
b0 = _x * b1 - b2 + array[i];
}

return (0.5 * (b0 - b2));
return (scalar_t(0.5) * (b0 - b2));
}

/*
Expand Down
17 changes: 14 additions & 3 deletions src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@

namespace at::native::xpu {

template <typename scalar_t>
template <typename scalar_t, bool USE_FP64_PI>
struct DigammaFunctor {
scalar_t operator()(scalar_t a) const {
return calc_digamma(a);
if constexpr (USE_FP64_PI) {
return calc_digamma<scalar_t, double>(a);
} else {
using pi_t = at::acc_type_device<scalar_t, kXPU>;
return calc_digamma<scalar_t, pi_t>(a);
}
}
};

Expand All @@ -24,7 +29,13 @@ void digamma_kernel(TensorIteratorBase& iter) {
at::ScalarType::BFloat16,
iter.common_dtype(),
"digamma_xpu",
[&]() { gpu_kernel(iter, DigammaFunctor<scalar_t>()); });
[&]() {
if (syclHasFloat64()) {
gpu_kernel(iter, DigammaFunctor<scalar_t, true>());
} else {
gpu_kernel(iter, DigammaFunctor<scalar_t, false>());
}
});
}

template <typename scalar_t>
Expand Down
6 changes: 6 additions & 0 deletions src/comm/DeviceProperties.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,11 @@ uint32_t syclNativeVectorWidth(
"Invalid data type to fetch native vector width!");
}

// Reports the `has_fp64` field of the device properties looked up for
// `dev_id`. When no index is passed, the device index of the current queue
// is used.
static inline bool syclHasFloat64(
    at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
  return at::xpu::getDeviceProperties(dev_id)->has_fp64;
}

} // namespace sycl
} // namespace xpu