Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 4 No.49】:为 Paddle bce_loss 支持 float16 数据类型 #50930

Merged
merged 36 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c8ae296
untracked files
thunder95 Feb 20, 2023
6aa02f0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Feb 23, 2023
d599110
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Feb 23, 2023
264894d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Feb 25, 2023
98d1e1c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Feb 25, 2023
b958122
Merge branch 'develop' of https://github.com/thunder95/Paddle into de…
thunder95 Feb 25, 2023
760e099
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Feb 26, 2023
e16076d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Feb 26, 2023
85169ee
bce_loss_fp16
thunder95 Feb 26, 2023
c7560fe
remove unused files
thunder95 Feb 26, 2023
085c7a6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Mar 2, 2023
b1edf68
Merge branch 'develop' of https://github.com/thunder95/Paddle into de…
thunder95 Mar 2, 2023
f2887e5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Mar 2, 2023
6a62308
Merge branch 'develop' of https://github.com/thunder95/Paddle into de…
thunder95 Mar 2, 2023
6620e88
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Mar 8, 2023
e4134d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Mar 10, 2023
29929ba
Merge branch 'bce_loss_fp16' of https://github.com/thunder95/Paddle i…
thunder95 Mar 11, 2023
812c917
back max_rel_erro still big
thunder95 Mar 11, 2023
4793749
simplify code
thunder95 Mar 13, 2023
b50ec23
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Mar 13, 2023
96b73f9
upd
thunder95 Mar 13, 2023
e598009
fix max_relative_error
thunder95 Mar 13, 2023
68e544f
fix max_relative_error
thunder95 Mar 13, 2023
82adc55
restart ci
thunder95 Mar 14, 2023
0f96491
Update test_bce_loss.py
Mar 14, 2023
895c3d5
Merge branch 'bce_loss_fp16' of https://github.com/thunder95/Paddle i…
thunder95 Mar 14, 2023
2df50b2
Update test_bce_loss.py
Mar 14, 2023
c9d5fc2
Update test_bce_loss.py
Mar 15, 2023
9afb5e6
Update test_bce_loss.py
Mar 15, 2023
b9fc6df
try to pass test
thunder95 Mar 22, 2023
f8a3c2b
merge
thunder95 Mar 29, 2023
a8616ca
restore file
thunder95 Mar 29, 2023
73071c0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 13, 2023
ec9249d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 13, 2023
796c276
remove error value
thunder95 Apr 14, 2023
72bd515
fix bug
thunder95 Apr 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
Expand All @@ -26,17 +28,15 @@ namespace phi {

template <typename T>
struct BCELossGradFunctor {
T one;
T eps;

HOSTDEVICE inline BCELossGradFunctor() {
one = static_cast<T>(1.0f);
eps = static_cast<T>(1e-12);
}
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT one = static_cast<MT>(1.0f);
MT eps = static_cast<MT>(1e-12);

HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
T term1 = max((one - x) * x, eps);
return (dout * (x - label) / term1);
MT x_mt = static_cast<MT>(x);
MT term1 = max((one - x_mt) * x_mt, eps);
return static_cast<T>(static_cast<MT>(dout) *
(x_mt - static_cast<MT>(label)) / term1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eps的问题,36行,1e-12在fp16表示下会下溢出为0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已做调整,不知道是否可以这样写。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以简化一下代码?one和eps作为成员变量,初始化为MT类型。原来的构造函数可以删掉了

}
};

Expand All @@ -55,5 +55,10 @@ void BCELossGradKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
bce_loss_grad, GPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {}
PD_REGISTER_KERNEL(bce_loss_grad,
GPU,
ALL_LAYOUT,
phi::BCELossGradKernel,
float,
double,
phi::dtype::float16) {}
36 changes: 22 additions & 14 deletions paddle/phi/kernels/gpu/bce_loss_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
Expand All @@ -27,22 +29,23 @@ namespace phi {

template <typename T>
struct BCELossFunctor {
T one;
T neg_100;

HOSTDEVICE inline BCELossFunctor() {
one = static_cast<T>(1.0f);
neg_100 = static_cast<T>(-100.);
}
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT zero = static_cast<MT>(0);
MT one = static_cast<MT>(1.0f);
MT neg_100 = static_cast<MT>(-100.);

HOSTDEVICE inline T operator()(const T x, const T label) const {
MT x_mt = static_cast<MT>(x);
MT label_mt = static_cast<MT>(label);

PADDLE_ENFORCE(
(x >= static_cast<T>(0)) && (x <= one),
(x_mt >= zero) && (x_mt <= one),
"Input is expected to be within the interval [0, 1], but received %f.",
x);
T term1 = max(phi::kps::details::Log(x), neg_100);
T term2 = max(phi::kps::details::Log(one - x), neg_100);
return (((label - one) * term2) - (label * term1));
x_mt);

MT term1 = max(phi::kps::details::Log(x_mt), neg_100);
MT term2 = max(phi::kps::details::Log(one - x_mt), neg_100);
return static_cast<T>((label_mt - one) * term2 - label_mt * term1);
}
};

Expand All @@ -60,5 +63,10 @@ void BCELossKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
bce_loss, GPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {}
PD_REGISTER_KERNEL(bce_loss,
GPU,
ALL_LAYOUT,
phi::BCELossKernel,
float,
double,
phi::dtype::float16) {}
42 changes: 40 additions & 2 deletions python/paddle/fluid/tests/unittests/test_bce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle import fluid
from paddle.fluid import core


def test_static_layer(
Expand Down Expand Up @@ -249,11 +250,12 @@ def bce_wrapper(x, label):

class TestBceLossOp(OpTest):
def setUp(self):
self.init_test_dtype()
self.init_test_case()
self.op_type = "bce_loss"
self.python_api = bce_wrapper
input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
label_np = np.random.randint(0, 2, self.shape).astype("float64")
input_np = np.random.uniform(0.1, 0.8, self.shape).astype(self.dtype)
label_np = np.random.randint(0, 2, self.shape).astype(self.dtype)
output_np = bce_loss(input_np, label_np)

self.inputs = {'X': input_np, 'Label': label_np}
Expand All @@ -268,6 +270,9 @@ def test_check_grad(self):
def init_test_case(self):
self.shape = [10, 10]

def init_test_dtype(self):
self.dtype = "float64"


class TestBceLossOpCase1(OpTest):
def init_test_cast(self):
Expand All @@ -279,6 +284,39 @@ def init_test_cast(self):
self.shape = [2, 3, 20]


class TestBceLossOpFP16(TestBceLossOp):
def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')

def init_test_dtype(self):
self.dtype = np.float16


class TestBceLossOpStaticFP16(unittest.TestCase):
def test_fp16(self):
paddle.enable_static()
shape = [2, 3, 20]
x_data = np.random.uniform(0.1, 0.8, shape).astype("float16")
y_data = np.random.randint(0, 2, shape).astype("float16")
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(shape=shape, name='x', dtype='float16')
y = paddle.static.data(shape=shape, name='y', dtype='float16')
out = paddle.nn.functional.binary_cross_entropy(
x, y, reduction="none"
)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
output_pd = exe.run(
feed={'x': x_data, 'y': y_data}, fetch_list=[out]
)[0]
paddle.disable_static()


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
14 changes: 10 additions & 4 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,10 @@ def binary_cross_entropy(
Parameters:
input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *],
N is batch_size, `*` means number of additional dimensions. The ``input``
should always be the output of sigmod. Available dtype is float32, float64.
should always be the output of sigmod. Available dtype is float16, float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``input``. The target labels which values should be numbers between 0 and 1.
Available dtype is float32, float64.
Available dtype is float16, float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``.
Expand Down Expand Up @@ -694,10 +694,16 @@ def binary_cross_entropy(
return out
else:
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'binary_cross_entropy'
input,
'input',
['float16', 'float32', 'float64'],
'binary_cross_entropy',
)
check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'binary_cross_entropy'
label,
'label',
['float16', 'float32', 'float64'],
'binary_cross_entropy',
)

sub_name = name if weight is None and reduction == 'none' else None
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/layer/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,8 @@ class BCELoss(Layer):
For more information, please refer to :ref:`api_guide_Name`.

Shape:
- input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmod. Available dtype is float32, float64.
- label (Tensor): 2-D tensor with the same shape as ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float32, float64.
- input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmod. Available dtype is float16, float32, float64.
- label (Tensor): 2-D tensor with the same shape as ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float16, float32, float64.
- output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar.

Returns:
Expand Down