Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.54】为 Paddle allclose、isclose 算子实现 float16 数据类型支持 #50988

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions paddle/phi/kernels/gpu/allclose_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "glog/logging.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"

Expand All @@ -31,14 +32,16 @@ __global__ void AllcloseCUDAKernel(const T* in_data,
bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
const T a = in_data[i], b = other_data[i];
const MPType a = static_cast<MPType>(in_data[i]),
b = static_cast<MPType>(other_data[i]);
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T diff = (left > right ? left - right : right - left);
MPType left = (a > b ? a - b : b - a);
MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
MPType diff = (left > right ? left - right : right - left);
val = a == b || left <= right || diff <= 1e-15;
}
if (!val) *out_data = false;
Expand Down Expand Up @@ -92,7 +95,12 @@ void AllCloseKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
PD_REGISTER_KERNEL(allclose,
GPU,
ALL_LAYOUT,
phi::AllCloseKernel,
phi::dtype::float16,
float,
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
9 changes: 7 additions & 2 deletions paddle/phi/kernels/gpu/isclose_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"

PD_REGISTER_KERNEL(
isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
PD_REGISTER_KERNEL(isclose,
GPU,
ALL_LAYOUT,
phi::IscloseKernel,
phi::dtype::float16,
float,
double) {}
11 changes: 7 additions & 4 deletions paddle/phi/kernels/impl/isclose_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
Expand Down Expand Up @@ -109,14 +110,16 @@ __global__ void IscloseCUDAKernel(const T* in_data,
bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
const T a = in_data[i], b = other_data[i];
const MPType a = static_cast<MPType>(in_data[i]),
b = static_cast<MPType>(other_data[i]);
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T diff = (left > right ? left - right : right - left);
MPType left = (a > b ? a - b : b - a);
MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
MPType diff = (left > right ? left - right : right - left);
val = a == b || left <= right || diff <= 1e-15;
}
out_data[i] = val;
Expand Down
15 changes: 14 additions & 1 deletion python/paddle/fluid/tests/unittests/test_allclose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,26 @@ def test_api_case(self):
paddle.enable_static()


class TestAllcloseDygraphFp16(unittest.TestCase):
def test_api_case(self):
paddle.disable_static()
x_data = np.random.rand(10, 10).astype("float16")
y_data = np.random.rand(10, 10).astype("float16")
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
expected_out = np.allclose(x_data, y_data, rtol=1e-05, atol=1e-08)
self.assertTrue((out.numpy() == expected_out).all(), True)
paddle.enable_static()


class TestAllcloseError(unittest.TestCase):
def test_input_dtype(self):
def test_x_dtype():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
result = paddle.allclose(x, y)

Expand Down
19 changes: 18 additions & 1 deletion python/paddle/fluid/tests/unittests/test_isclose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,23 @@ def test_api_case(self):
paddle.enable_static()


class TestIscloseDygraphFp16(unittest.TestCase):
def test_api_case(self):
places = []
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
x_data = np.random.rand(10, 10).astype(np.float16)
y_data = np.random.rand(10, 10).astype(np.float16)
x = paddle.to_tensor(x_data, place=place)
y = paddle.to_tensor(y_data, place=place)
out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
expected_out = np.isclose(x_data, y_data, rtol=1e-05, atol=1e-08)
self.assertTrue((out.numpy() == expected_out).all(), True)
paddle.enable_static()


class TestIscloseError(unittest.TestCase):
def test_input_dtype(self):
paddle.enable_static()
Expand All @@ -166,7 +183,7 @@ def test_x_dtype():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
result = paddle.isclose(x, y)

Expand Down