From 1b82983909becd9d6151b2aae8a8a5dc110d74d3 Mon Sep 17 00:00:00 2001 From: fanlinghuo <942441893@qq.com> Date: Tue, 28 Feb 2023 00:19:47 +0800 Subject: [PATCH 1/2] first commit of isclose and allclose --- paddle/phi/kernels/gpu/allclose_kernel.cu | 20 +++++++++++++------ paddle/phi/kernels/gpu/isclose_kernel.cu | 9 +++++++-- paddle/phi/kernels/impl/isclose_kernel_impl.h | 11 ++++++---- .../fluid/tests/unittests/test_allclose_op.py | 15 +++++++++++++- .../fluid/tests/unittests/test_isclose_op.py | 19 +++++++++++++++++- 5 files changed, 60 insertions(+), 14 deletions(-) diff --git a/paddle/phi/kernels/gpu/allclose_kernel.cu b/paddle/phi/kernels/gpu/allclose_kernel.cu index fa6a8fce0bf861..a0492ca46bf34c 100644 --- a/paddle/phi/kernels/gpu/allclose_kernel.cu +++ b/paddle/phi/kernels/gpu/allclose_kernel.cu @@ -16,6 +16,7 @@ #include "glog/logging.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" @@ -31,14 +32,16 @@ __global__ void AllcloseCUDAKernel(const T* in_data, bool* out_data) { unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x; bool val; + using MPType = typename phi::dtype::MPTypeTrait<T>::Type; for (int i = idx; i < num; i += blockDim.x * gridDim.x) { - const T a = in_data[i], b = other_data[i]; + const MPType a = static_cast<MPType>(in_data[i]), + b = static_cast<MPType>(other_data[i]); if (isnan(a) || isnan(b)) { val = equal_nan && isnan(a) == isnan(b); } else { - T left = (a > b ? a - b : b - a); - T right = atol + (b > 0 ? rtol * b : (-rtol) * b); - T diff = (left > right ? left - right : right - left); + MPType left = (a > b ? a - b : b - a); + MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b); + MPType diff = (left > right ? 
left - right : right - left); val = a == b || left <= right || diff <= 1e-15; } if (!val) *out_data = false; @@ -92,7 +95,12 @@ void AllCloseKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) { +PD_REGISTER_KERNEL(allclose, + GPU, + ALL_LAYOUT, + phi::AllCloseKernel, + phi::dtype::float16, + float, + double) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } diff --git a/paddle/phi/kernels/gpu/isclose_kernel.cu b/paddle/phi/kernels/gpu/isclose_kernel.cu index 34774ec715c48d..4cf827856e22a2 100644 --- a/paddle/phi/kernels/gpu/isclose_kernel.cu +++ b/paddle/phi/kernels/gpu/isclose_kernel.cu @@ -18,5 +18,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/isclose_kernel_impl.h" -PD_REGISTER_KERNEL( - isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {} +PD_REGISTER_KERNEL(isclose, + GPU, + ALL_LAYOUT, + phi::IscloseKernel, + phi::dtype::float16, + float, + double) {} diff --git a/paddle/phi/kernels/impl/isclose_kernel_impl.h b/paddle/phi/kernels/impl/isclose_kernel_impl.h index cf7171656486c1..0105e6b4e8a0bc 100644 --- a/paddle/phi/kernels/impl/isclose_kernel_impl.h +++ b/paddle/phi/kernels/impl/isclose_kernel_impl.h @@ -18,6 +18,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" @@ -109,14 +110,16 @@ __global__ void IscloseCUDAKernel(const T* in_data, bool* out_data) { unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x; bool val; + using MPType = typename phi::dtype::MPTypeTrait<T>::Type; for (int i = idx; i < num; i += blockDim.x * gridDim.x) { - const T a = in_data[i], b = other_data[i]; + const MPType a = static_cast<MPType>(in_data[i]), + b = static_cast<MPType>(other_data[i]); if (isnan(a) || isnan(b)) { val = equal_nan && 
isnan(a) == isnan(b); } else { - T left = (a > b ? a - b : b - a); - T right = atol + (b > 0 ? rtol * b : (-rtol) * b); - T diff = (left > right ? left - right : right - left); + MPType left = (a > b ? a - b : b - a); + MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b); + MPType diff = (left > right ? left - right : right - left); val = a == b || left <= right || diff <= 1e-15; } out_data[i] = val; diff --git a/python/paddle/fluid/tests/unittests/test_allclose_op.py b/python/paddle/fluid/tests/unittests/test_allclose_op.py index c4cde0ec49ee99..f0d1065c41ef1f 100644 --- a/python/paddle/fluid/tests/unittests/test_allclose_op.py +++ b/python/paddle/fluid/tests/unittests/test_allclose_op.py @@ -128,13 +128,26 @@ def test_api_case(self): paddle.enable_static() +class TestAllcloseDygraphFp16(unittest.TestCase): + def test_api_case(self): + paddle.disable_static() + x_data = np.random.rand(10, 10).astype("float16") + y_data = np.random.rand(10, 10).astype("float16") + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08) + expected_out = np.allclose(x_data, y_data, rtol=1e-05, atol=1e-08) + self.assertTrue((out.numpy() == expected_out).all(), True) + paddle.enable_static() + + class TestAllcloseError(unittest.TestCase): def test_input_dtype(self): def test_x_dtype(): with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() ): - x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16') + x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32') y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64') result = paddle.allclose(x, y) diff --git a/python/paddle/fluid/tests/unittests/test_isclose_op.py b/python/paddle/fluid/tests/unittests/test_isclose_op.py index fc2a5cd5ebef2d..931ddd065ecae3 100644 --- a/python/paddle/fluid/tests/unittests/test_isclose_op.py +++ b/python/paddle/fluid/tests/unittests/test_isclose_op.py @@ -158,6 +158,23 @@ def test_api_case(self): 
paddle.enable_static() +class TestIscloseDygraphFp16(unittest.TestCase): + def test_api_case(self): + places = [] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + x_data = np.random.rand(10, 10).astype(np.float16) + y_data = np.random.rand(10, 10).astype(np.float16) + x = paddle.to_tensor(x_data, place=place) + y = paddle.to_tensor(y_data, place=place) + out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08) + expected_out = np.isclose(x_data, y_data, rtol=1e-05, atol=1e-08) + self.assertTrue((out.numpy() == expected_out).all(), True) + paddle.enable_static() + + class TestIscloseError(unittest.TestCase): def test_input_dtype(self): paddle.enable_static() @@ -166,7 +183,7 @@ def test_x_dtype(): with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() ): - x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16') + x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32') y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64') result = paddle.isclose(x, y) From 368adcab2a3b8abe5a47f3b437c3ca6b43213545 Mon Sep 17 00:00:00 2001 From: fanlinghuo <942441893@qq.com> Date: Tue, 28 Feb 2023 23:47:50 +0800 Subject: [PATCH 2/2] Empty-Commit