Skip to content

Commit

Permalink
[SYCL] Fix the mul_hi built-in on host device
Browse files Browse the repository at this point in the history
This patch fixes incorrect handling of negative arguments in the host
implementation of the mul_hi built-in.

Signed-off-by: Sergey Semenov <sergey.semenov@intel.com>
  • Loading branch information
sergey-semenov authored and bader committed Sep 26, 2019
1 parent ab3e71e commit 8a3b7d9
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 21 deletions.
80 changes: 59 additions & 21 deletions sycl/source/detail/builtins_integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,20 @@ template <typename T> T __mul_hi(T a, T b) {
return (mul >> (sizeof(T) * 8));
}

// T is minimum of 64 bits- long or longlong
template <typename T> inline T __long_mul_hi(T a, T b) {
int halfsize = (sizeof(T) * 8) / 2;
// Helper for the wide-integer mul_hi built-in: folds four half-word partial
// products into the upper half of the double-width product.
//
// The arguments are expected to be the partial products filled in by
// __get_half_products (a0b0 = low*low, a0b1 / a1b0 = the two cross terms,
// a1b1 = high*high). With W = number of bits in T:
//   - a1b1 contributes to the upper half in full;
//   - the cross terms plus the upper half of a0b0 contribute their own upper
//     half, i.e. (a1b0 + a0b1 + (a0b0 >> W/2)) >> W/2.
// That middle sum may need one bit more than T provides, so it is formed with
// __hadd (a halving add) and shifted by one bit less to compensate.
template <typename T> inline T __get_high_half(T a0b0, T a0b1, T a1b0, T a1b1) {
  constexpr int halfBits = (sizeof(T) * 8) / 2;
  // Upper half of the low product; it lines up with the cross terms.
  const auto lowCarry = a0b0 >> halfBits;
  // Halved sum of the middle column, keeping the would-be carry bit.
  const auto middleHalved = __hadd(a1b0, (a0b1 + lowCarry));
  return a1b1 + (middleHalved >> (halfBits - 1));
}

// A helper function for mul_hi built-in for long
template <typename T>
inline void __get_half_products(T a, T b, T &a0b0, T &a0b1, T &a1b0, T &a1b1) {
constexpr int halfsize = (sizeof(T) * 8) / 2;
T a1 = a >> halfsize;
T a0 = (a << halfsize) >> halfsize;
T b1 = b >> halfsize;
Expand All @@ -90,26 +101,53 @@ template <typename T> inline T __long_mul_hi(T a, T b) {
// a1b1 - for bits - [64-128)
// a1b0 a0b1 for bits - [32-96)
// a0b0 for bits - [0-64)
T a1b1 = a1 * b1;
T a0b1 = a0 * b1;
T a1b0 = a1 * b0;
T a0b0 = a0 * b0;
a1b1 = a1 * b1;
a0b1 = a0 * b1;
a1b0 = a1 * b0;
a0b0 = a0 * b0;
}

// Unsigned mul_hi for wide integers: returns the upper half of the
// double-width product a * b. T is at least 64 bits wide (long or long long).
template <typename T> inline T __u_long_mul_hi(T a, T b) {
  // Partial products of the operands' half-words, filled in by reference.
  T lowLow, lowHigh, highLow, highHigh;
  __get_half_products(a, b, lowLow, lowHigh, highLow, highHigh);
  return __get_high_half(lowLow, lowHigh, highLow, highHigh);
}

// Signed mul_hi for wide integers: returns the upper half of the
// double-width signed product a * b.
//
// The operands' magnitudes are multiplied as unsigned values, and the sign is
// applied afterwards by negating the double-width result in two's complement.
template <typename T> inline T __s_long_mul_hi(T a, T b) {
  using UT = typename std::make_unsigned<T>::type;

  // Magnitude computed in the unsigned domain. Converting first and negating
  // afterwards is well defined for every input, including the most negative
  // value of T, whose magnitude is not representable in T (std::abs on that
  // value is undefined behaviour).
  const auto magnitude = [](T v) -> UT {
    const UT u = static_cast<UT>(v);
    return v < 0 ? static_cast<UT>(UT{0} - u) : u;
  };
  const UT absA = magnitude(a);
  const UT absB = magnitude(b);

  UT a0b0, a0b1, a1b0, a1b1;
  __get_half_products(absA, absB, a0b0, a0b1, a1b0, a1b1);
  UT high = __get_high_half(a0b0, a0b1, a1b0, a1b1);

  const bool isResultNegative = (a < 0) != (b < 0);
  if (isResultNegative) {
    // Two's complement negation of the double-width value {high, low}:
    // invert both halves and add one. The +1 only carries into the upper half
    // when the lower half of the magnitude is zero.
    high = ~high;

    // Lower half of the magnitude product (wraps modulo 2^bits by design).
    constexpr int halfsize = (sizeof(T) * 8) / 2;
    const UT low = a0b0 + ((a0b1 + a1b0) << halfsize);
    if (low == 0)
      ++high;
  }

  return static_cast<T>(high);
}

// mad_hi: the high half of a * b (as computed by __mul_hi) plus c.
template <typename T> inline T __mad_hi(T a, T b, T c) {
  const T productHigh = __mul_hi(a, b);
  return productHigh + c;
}

template <typename T> inline T __long_mad_hi(T a, T b, T c) {
return __long_mul_hi(a, b) + c;
// Unsigned wide-integer mad_hi: the high half of a * b (as computed by
// __u_long_mul_hi) plus c.
template <typename T> inline T __u_long_mad_hi(T a, T b, T c) {
  const T productHigh = __u_long_mul_hi(a, b);
  return productHigh + c;
}

// Signed wide-integer mad_hi: the high half of a * b (as computed by
// __s_long_mul_hi) plus c.
template <typename T> inline T __s_long_mad_hi(T a, T b, T c) {
  const T productHigh = __s_long_mul_hi(a, b);
  return productHigh + c;
}

template <typename T> inline T __s_mad_sat(T a, T b, T c) {
Expand All @@ -123,7 +161,7 @@ template <typename T> inline T __s_mad_sat(T a, T b, T c) {

template <typename T> inline T __s_long_mad_sat(T a, T b, T c) {
bool neg_prod = (a < 0) ^ (b < 0);
T mulhi = __long_mul_hi(a, b);
T mulhi = __s_long_mul_hi(a, b);

// check mul_hi. If it is any value != 0.
// if prod is +ve, any value in mulhi means we need to saturate.
Expand All @@ -145,7 +183,7 @@ template <typename T> inline T __u_mad_sat(T a, T b, T c) {
}

template <typename T> inline T __u_long_mad_sat(T a, T b, T c) {
T mulhi = __long_mul_hi(a, b);
T mulhi = __u_long_mul_hi(a, b);
// check mul_hi. If it is any value != 0.
if (mulhi != 0)
return d::max_v<T>();
Expand Down Expand Up @@ -421,7 +459,7 @@ cl_char s_mul_hi(cl_char a, cl_char b) { return __mul_hi(a, b); }
// Scalar s_mul_hi overloads for cl_short and cl_int; both forward to the
// generic __mul_hi template.
cl_short s_mul_hi(cl_short a, cl_short b) { return __mul_hi(a, b); }
cl_int s_mul_hi(cl_int a, cl_int b) { return __mul_hi(a, b); }
// Scalar s_mul_hi overload for cl_long. Uses the sign-aware wide-integer
// helper so that negative operands are handled correctly.
cl_long s_mul_hi(s::cl_long x, s::cl_long y) __NOEXC {
  return __s_long_mul_hi(x, y);
}
MAKE_1V_2V(s_mul_hi, s::cl_char, s::cl_char, s::cl_char)
MAKE_1V_2V(s_mul_hi, s::cl_short, s::cl_short, s::cl_short)
Expand All @@ -433,7 +471,7 @@ cl_uchar u_mul_hi(cl_uchar a, cl_uchar b) { return __mul_hi(a, b); }
// Scalar u_mul_hi overloads for cl_ushort and cl_uint; both forward to the
// generic __mul_hi template.
cl_ushort u_mul_hi(cl_ushort a, cl_ushort b) { return __mul_hi(a, b); }
cl_uint u_mul_hi(cl_uint a, cl_uint b) { return __mul_hi(a, b); }
// Scalar u_mul_hi overload for cl_ulong. Delegates to the unsigned
// wide-integer helper.
cl_ulong u_mul_hi(s::cl_ulong x, s::cl_ulong y) __NOEXC {
  return __u_long_mul_hi(x, y);
}
MAKE_1V_2V(u_mul_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar)
MAKE_1V_2V(u_mul_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort)
Expand All @@ -452,7 +490,7 @@ cl_int s_mad_hi(s::cl_int x, s::cl_int minval, s::cl_int maxval) __NOEXC {
return __mad_hi(x, minval, maxval);
}
// Scalar s_mad_hi overload for cl_long: high half of x * minval, plus maxval.
// Despite their names, `minval` is the second multiplicand and `maxval` is the
// addend (see __s_long_mad_hi). Uses the sign-aware wide-integer helper.
cl_long s_mad_hi(s::cl_long x, s::cl_long minval, s::cl_long maxval) __NOEXC {
  return __s_long_mad_hi(x, minval, maxval);
}
MAKE_1V_2V_3V(s_mad_hi, s::cl_char, s::cl_char, s::cl_char, s::cl_char)
MAKE_1V_2V_3V(s_mad_hi, s::cl_short, s::cl_short, s::cl_short, s::cl_short)
Expand All @@ -473,7 +511,7 @@ cl_uint u_mad_hi(s::cl_uint x, s::cl_uint minval, s::cl_uint maxval) __NOEXC {
}
// Scalar u_mad_hi overload for cl_ulong: high half of x * minval, plus maxval.
// Despite their names, `minval` is the second multiplicand and `maxval` is the
// addend (see __u_long_mad_hi). Delegates to the unsigned wide-integer helper.
cl_ulong u_mad_hi(s::cl_ulong x, s::cl_ulong minval,
                  s::cl_ulong maxval) __NOEXC {
  return __u_long_mad_hi(x, minval, maxval);
}
MAKE_1V_2V_3V(u_mad_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar, s::cl_uchar)
MAKE_1V_2V_3V(u_mad_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort, s::cl_ushort)
Expand Down
33 changes: 33 additions & 0 deletions sycl/test/built-ins/scalar_integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,39 @@ int main() {
assert(r == 0x10);
}

// mul_hi with negative result w/ carry
{
s::cl_int r{0};
{
s::buffer<s::cl_int, 1> BufR(&r, s::range<1>(1));
s::queue myQueue;
myQueue.submit([&](s::handler &cgh) {
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
cgh.single_task<class mul_hiSI1SI2>([=]() {
AccR[0] = s::mul_hi(s::cl_int{-0x10000000}, s::cl_int{0x00000100});
}); // -2^28 * 2^8 = -2^36 -> -0x10 (FFFFFFF0) 00000000.
});
}
assert(r == -0x10);
}

// mul_hi with negative result w/o carry
{
s::cl_int r{0};
{
s::buffer<s::cl_int, 1> BufR(&r, s::range<1>(1));
s::queue myQueue;
myQueue.submit([&](s::handler &cgh) {
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
cgh.single_task<class mul_hiSI1SI3>([=]() {
AccR[0] = s::mul_hi(s::cl_int{-0x10000000}, s::cl_int{0x00000101});
}); // -2^28 * (2^8 + 1) = -2^36 - 2^28 -> -0x11 (FFFFFFEF) -0x10000000
// (F0000000).
});
}
assert(r == -0x11);
}

// rotate
{
s::cl_int r{ 0 };
Expand Down

0 comments on commit 8a3b7d9

Please sign in to comment.