diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h
index dc78bd7098411..af0ae3b6447be 100644
--- a/paddle/phi/kernels/funcs/elementwise_functor.h
+++ b/paddle/phi/kernels/funcs/elementwise_functor.h
@@ -229,6 +229,17 @@ struct FMaxFunctor {
   }
 };
 
+template <>
+struct FMaxFunctor<dtype::bfloat16> {
+  inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
+                                               const dtype::bfloat16 b) const {
+    float float_a = static_cast<float>(a);
+    float float_b = static_cast<float>(b);
+    auto result = std::fmax(float_a, float_b);
+    return static_cast<dtype::bfloat16>(result);
+  }
+};
+
 template <>
 struct FMaxFunctor<int> {
   inline HOSTDEVICE int operator()(const int a, const int b) const {
diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
index b3ad0dacae37c..30e222663da10 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -97,6 +97,7 @@ PD_REGISTER_KERNEL(fmax_grad,
                    double,
                    int,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(fmin_grad,
diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu
index 80a969c4fabb4..cbf811a9830b0 100644
--- a/paddle/phi/kernels/kps/elementwise_kernel.cu
+++ b/paddle/phi/kernels/kps/elementwise_kernel.cu
@@ -155,6 +155,7 @@ PD_REGISTER_KERNEL(fmax,
                    double,
                    int,
                    float16,
+                   bfloat16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(fmin,
diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py
index 19417919df248..0271854aebb72 100644
--- a/python/paddle/fluid/tests/unittests/test_fmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -241,5 +241,34 @@ def test_check_grad_normal(self):
         self.check_grad(['X', 'Y'], 'Out')
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestFmaxBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_fmax"
+        self.python_api = paddle.fmax
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+        sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
+        y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+        out = np.fmax(x, y)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+
 if __name__ == "__main__":
     unittest.main()
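
For context, a minimal usage sketch of what these registrations enable, assuming a CUDA build of Paddle that includes the bfloat16 fmax kernels above; the tensor values and the paddle.set_device call are illustrative only, not taken from this change:

    # Assumption: CUDA build with the bfloat16 fmax forward/grad kernels registered.
    import paddle

    paddle.set_device("gpu")

    x = paddle.to_tensor([1.0, 5.0, float("nan")]).astype("bfloat16")
    y = paddle.to_tensor([4.0, 2.0, 3.0]).astype("bfloat16")

    # fmax follows std::fmax / np.fmax semantics: if one operand is NaN,
    # the non-NaN operand is returned.
    out = paddle.fmax(x, y)  # expected: [4., 5., 3.] in bfloat16
    print(out)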