Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMP OP&Test]Add fp16/bf16 support isnan/isfinite/isinf op #52259

Merged
merged 17 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions paddle/fluid/framework/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_( \
callback, ::paddle::platform::complex<double>, COMPLEX128);

#define _ForEachDataTypeNormal_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);
#define _ForEachDataTypeNormal_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16);

// For the use of thrust, as index-type elements can be only integers.
#define _ForEachDataTypeTiny_(callback) \
Expand Down
12 changes: 8 additions & 4 deletions paddle/fluid/operators/isfinite_op.cu
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
Expand All @@ -22,18 +23,21 @@ REGISTER_OP_CUDA_KERNEL(
ops::OverflowKernel<phi::GPUContext, int, ops::InfinityFunctor>,
ops::OverflowKernel<phi::GPUContext, float, ops::InfinityFunctor>,
ops::OverflowKernel<phi::GPUContext, double, ops::InfinityFunctor>,
ops::OverflowKernel<phi::GPUContext, plat::float16, ops::InfinityFunctor>);
ops::OverflowKernel<phi::GPUContext, plat::float16, ops::InfinityFunctor>,
ops::OverflowKernel<phi::GPUContext, plat::bfloat16, ops::InfinityFunctor>);

REGISTER_OP_CUDA_KERNEL(
isnan,
ops::OverflowKernel<phi::GPUContext, int, ops::NANFunctor>,
ops::OverflowKernel<phi::GPUContext, float, ops::NANFunctor>,
ops::OverflowKernel<phi::GPUContext, double, ops::NANFunctor>,
ops::OverflowKernel<phi::GPUContext, plat::float16, ops::NANFunctor>);
ops::OverflowKernel<phi::GPUContext, plat::float16, ops::NANFunctor>,
ops::OverflowKernel<phi::GPUContext, plat::bfloat16, ops::NANFunctor>);

REGISTER_OP_CUDA_KERNEL(
isfinite,
ops::OverflowKernel<phi::GPUContext, int, ops::IsfiniteFunctor>,
ops::OverflowKernel<phi::GPUContext, float, ops::IsfiniteFunctor>,
ops::OverflowKernel<phi::GPUContext, double, ops::IsfiniteFunctor>,
ops::OverflowKernel<phi::GPUContext, plat::float16, ops::IsfiniteFunctor>);
ops::OverflowKernel<phi::GPUContext, plat::float16, ops::IsfiniteFunctor>,
ops::OverflowKernel<phi::GPUContext, plat::bfloat16, ops::IsfiniteFunctor>);
3 changes: 3 additions & 0 deletions paddle/phi/kernels/cpu/isfinite_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ PD_REGISTER_KERNEL(isinf,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
Expand All @@ -37,6 +38,7 @@ PD_REGISTER_KERNEL(isnan,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
Expand All @@ -49,6 +51,7 @@ PD_REGISTER_KERNEL(isfinite,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
Expand Down
21 changes: 21 additions & 0 deletions paddle/phi/kernels/funcs/isfinite_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ struct IsNanFunctor<phi::dtype::float16, void> {
}
};

template <>
struct IsNanFunctor<phi::dtype::bfloat16, void> {
HOSTDEVICE bool operator()(const phi::dtype::bfloat16& a) const {
return phi::dtype::isnan(a);
}
};

template <typename T, class Enable = void>
struct IsInfFunctor {
HOSTDEVICE bool operator()(const T& a) const {
Expand All @@ -69,6 +76,13 @@ struct IsInfFunctor<phi::dtype::float16, void> {
}
};

template <>
struct IsInfFunctor<phi::dtype::bfloat16, void> {
HOSTDEVICE bool operator()(const phi::dtype::bfloat16& a) const {
return phi::dtype::isinf(a);
}
};

template <typename T, class Enable = void>
struct IsFiniteFunctor {
HOSTDEVICE bool operator()(const T& a) const {
Expand All @@ -94,5 +108,12 @@ struct IsFiniteFunctor<phi::dtype::float16, void> {
}
};

template <>
struct IsFiniteFunctor<phi::dtype::bfloat16, void> {
HOSTDEVICE bool operator()(const phi::dtype::bfloat16& a) const {
return phi::dtype::isfinite(a);
}
};

} // namespace funcs
} // namespace phi
3 changes: 3 additions & 0 deletions paddle/phi/kernels/gpu/isfinite_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ PD_REGISTER_KERNEL(isinf,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
Expand All @@ -37,6 +38,7 @@ PD_REGISTER_KERNEL(isnan,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
Expand All @@ -49,6 +51,7 @@ PD_REGISTER_KERNEL(isfinite,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
Expand Down
68 changes: 67 additions & 1 deletion python/paddle/fluid/tests/unittests/test_isfinite_op.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

from paddle.fluid import core

Expand Down Expand Up @@ -48,6 +48,28 @@ def init_dtype(self):
self.dtype = np.float16


# BFP16 isinf Test
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestInfBF16(OpTest):
def setUp(self):
self.op_type = "isinf"
self.dtype = np.uint16
x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
x[0] = np.inf
x[-1] = np.inf

out = np.array(True)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}

def test_output(self):
self.check_output_with_place(core.CUDAPlace(0))


class TestNAN(OpTest):
def setUp(self):
self.op_type = "isnan"
Expand Down Expand Up @@ -76,6 +98,28 @@ def init_dtype(self):
self.dtype = np.float16


# BFP16 isnan Test
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestNANBF16(OpTest):
def setUp(self):
self.op_type = "isnan"
self.dtype = np.uint16
x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
x[0] = np.nan
x[-1] = np.nan

out = np.array(True)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}

def test_output(self):
self.check_output_with_place(core.CUDAPlace(0))


class TestIsfinite(OpTest):
def setUp(self):
self.op_type = "isfinite"
Expand Down Expand Up @@ -105,5 +149,27 @@ def init_dtype(self):
self.dtype = np.float16


# BFP16 isfinite Test
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestIsfiniteBF16(OpTest):
def setUp(self):
self.op_type = "isfinite"
self.dtype = np.uint16
x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
x[0] = np.inf
x[-1] = np.nan

out = np.array(False)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}

def test_output(self):
self.check_output_with_place(core.CUDAPlace(0))


if __name__ == '__main__':
unittest.main()
33 changes: 30 additions & 3 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3466,7 +3466,14 @@ def isfinite(x, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'isfinite',
)
out = helper.create_variable_for_type_inference('bool')
Expand Down Expand Up @@ -3502,7 +3509,17 @@ def isinf(x, name=None):
else:
helper = LayerHelper("isinf_v2", **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isinf'
x,
'x',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'isinf',
)
out = helper.create_variable_for_type_inference(dtype='bool')
helper.append_op(type="isinf_v2", inputs={"X": x}, outputs={"Out": out})
Expand Down Expand Up @@ -3535,7 +3552,17 @@ def isnan(x, name=None):
else:
helper = LayerHelper("isnan_v2", **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isnan'
x,
'x',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'isnan',
)
out = helper.create_variable_for_type_inference(dtype='bool')
helper.append_op(type="isnan_v2", inputs={"X": x}, outputs={"Out": out})
Expand Down