Make bfloat16 implicitly convert to float/double #48238

Merged (3 commits) on Nov 23, 2022

paddle/fluid/platform/bfloat16_test.cu (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ TEST(bfloat16, convert_float32_to_bfloat16_on_gpu) {
 TEST(bfloat16, assignment_operator_on_gpu) {
   // Assignment operator
   bfloat16 v_assign;
-  v_assign = nv_bfloat16(bfloat16(1.0f));
+  v_assign = bfloat16(1.0f).to_nv_bfloat16();
   EXPECT_EQ(v_assign.x, 0x3f80);
   v_assign = 0.33333;
   EXPECT_EQ(v_assign.x, 0x3eab);
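Because the explicit operator __nv_bfloat16() is replaced by a named to_nv_bfloat16() member in this PR (see the bfloat16.h hunk further down), the old cast-style spelling nv_bfloat16(bfloat16(1.0f)) no longer compiles, hence the test change above. A minimal device-side sketch of the round trip, assuming a PADDLE_CUDA_BF16 build (the function name RoundTripBf16 is illustrative, not part of the PR):

#include "paddle/phi/common/bfloat16.h"

// Illustrative only: convert to the CUDA-native type via the new named member
// and back through the existing bfloat16(__nv_bfloat16) constructor.
__device__ void RoundTripBf16() {
  phi::dtype::bfloat16 a(1.0f);
  __nv_bfloat16 raw = a.to_nv_bfloat16();  // named conversion added by this PR
  phi::dtype::bfloat16 b(raw);             // construction from __nv_bfloat16 still works
  (void)b;
}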

paddle/phi/backends/gpu/cuda/cuda_device_function.h (8 changes: 3 additions & 5 deletions)
@@ -67,10 +67,8 @@ template <>
 __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
     unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
 #if defined(PADDLE_CUDA_BF16)
-  return phi::dtype::bfloat16(__shfl_down_sync(mask,
-                                               static_cast<nv_bfloat16>(val),
-                                               static_cast<unsigned>(delta),
-                                               width));
+  return phi::dtype::bfloat16(__shfl_down_sync(
+      mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
 #else
   PADDLE_ENFORCE(
       false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11.");
@@ -114,7 +112,7 @@ __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync(
     unsigned mask, phi::dtype::bfloat16 val, int width) {
 #if defined(PADDLE_CUDA_BF16)
   return phi::dtype::bfloat16(
-      __shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width));
+      __shfl_xor_sync(mask, val.to_nv_bfloat16(), width));
 #else
   PADDLE_ENFORCE(
       false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
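These shuffle specializations are what warp-level reductions over bfloat16 build on. A minimal sketch of the usual shuffle-down warp sum, assuming the phi::backends::gpu namespace of this header (the helper name WarpReduceSumBf16 is illustrative, not part of the PR):

#include "paddle/phi/backends/gpu/cuda/cuda_device_function.h"

// Each lane contributes `val`; after the loop, lane 0 holds the sum of the warp.
__device__ phi::dtype::bfloat16 WarpReduceSumBf16(phi::dtype::bfloat16 val) {
  const unsigned full_mask = 0xffffffffu;
  for (int offset = 16; offset > 0; offset >>= 1) {
    // bfloat16 operator+ is defined in paddle/phi/common/bfloat16.h
    val = val + phi::backends::gpu::CudaShuffleDownSync(full_mask, val, offset, 32);
  }
  return val;
}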

paddle/phi/common/bfloat16.h (6 changes: 3 additions & 3 deletions)
@@ -145,7 +145,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }

   // Conversion opertors
-  HOSTDEVICE inline explicit operator float() const {
+  HOSTDEVICE inline operator float() const {
 #ifdef PADDLE_WITH_HIP
     uint32_t res = 0;
     // We should be using memcpy in order to respect the strict aliasing rule
@@ -168,7 +168,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }

 #ifdef PADDLE_CUDA_BF16
-  HOSTDEVICE inline explicit operator __nv_bfloat16() const {
+  HOSTDEVICE inline __nv_bfloat16 to_nv_bfloat16() const {
     return *reinterpret_cast<const __nv_bfloat16*>(&x);
   }
 #endif
@@ -207,7 +207,7 @@ struct PADDLE_ALIGN(2) bfloat16 {
     return static_cast<uint64_t>(static_cast<float>(*this));
   }

-  HOSTDEVICE inline explicit operator double() const {
+  HOSTDEVICE inline operator double() const {
     return static_cast<double>(static_cast<float>(*this));
   }
 };
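The change that gives the PR its title is the removal of explicit from operator float() and operator double(): bfloat16 values can now feed float/double expressions directly. A minimal host-side sketch (the function name SumAsDouble is illustrative):

#include "paddle/phi/common/bfloat16.h"

// Before this PR, both initializations below required an explicit static_cast.
double SumAsDouble(phi::dtype::bfloat16 b) {
  float f = b;   // implicit bfloat16 -> float conversion
  double d = b;  // implicit bfloat16 -> double conversion
  return f + d;
}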