-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 4 No.49】:为 Paddle bce_loss 支持 float16 数据类型 #50930
Changes from 18 commits
c8ae296
6aa02f0
d599110
264894d
98d1e1c
b958122
760e099
e16076d
85169ee
c7560fe
085c7a6
b1edf68
f2887e5
6a62308
6620e88
e4134d9
29929ba
812c917
4793749
b50ec23
96b73f9
e598009
68e544f
82adc55
0f96491
895c3d5
2df50b2
c9d5fc2
9afb5e6
b9fc6df
f8a3c2b
a8616ca
73071c0
ec9249d
796c276
72bd515
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
#include <vector> | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/common/amp_type_traits.h" | ||
#include "paddle/phi/common/float16.h" | ||
#include "paddle/phi/core/hostdevice.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/elementwise_base.h" | ||
|
@@ -40,9 +42,15 @@ struct BCELossFunctor { | |
(x >= static_cast<T>(0)) && (x <= one), | ||
"Input is expected to be within the interval [0, 1], but received %f.", | ||
x); | ||
T term1 = max(phi::kps::details::Log(x), neg_100); | ||
T term2 = max(phi::kps::details::Log(one - x), neg_100); | ||
return (((label - one) * term2) - (label * term1)); | ||
using MT = typename phi::dtype::MPTypeTrait<T>::Type; | ||
MT term1 = max(phi::kps::details::Log(static_cast<MT>(x)), | ||
static_cast<MT>(neg_100)); | ||
MT term2 = | ||
max(phi::kps::details::Log(static_cast<MT>(one) - static_cast<MT>(x)), | ||
static_cast<MT>(neg_100)); | ||
return static_cast<T>( | ||
((static_cast<MT>(label) - static_cast<MT>(one)) * term2) - | ||
(static_cast<MT>(label) * term1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里和上面也是类似的问题,我觉得可以修改下原始的实现。one和neg_100本来是成员变量,可以初始化就为MT 类型。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
} | ||
}; | ||
|
||
|
@@ -60,5 +68,10 @@ void BCELossKernel(const Context& dev_ctx, | |
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL( | ||
bce_loss, GPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {} | ||
PD_REGISTER_KERNEL(bce_loss, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::BCELossKernel, | ||
float, | ||
double, | ||
phi::dtype::float16) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
|
||
|
||
def test_static_layer( | ||
|
@@ -279,6 +280,68 @@ def init_test_cast(self): | |
self.shape = [2, 3, 20] | ||
|
||
|
||
class TestBceLossOpFP16(TestBceLossOp): | ||
def setUp(self): | ||
self.init_test_case() | ||
self.op_type = "bce_loss" | ||
self.python_api = bce_wrapper | ||
input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float16") | ||
label_np = np.random.randint(0, 2, self.shape).astype("float16") | ||
output_np = bce_loss(input_np, label_np) | ||
|
||
self.inputs = {'X': input_np, 'Label': label_np} | ||
self.outputs = {'Out': output_np} | ||
|
||
def test_check_output(self): | ||
if core.is_compiled_with_cuda(): | ||
place = core.CUDAPlace(0) | ||
if core.is_float16_supported(place): | ||
self.check_output_with_place(place, atol=1e-3) | ||
|
||
def test_check_grad(self): | ||
place = core.CUDAPlace(0) | ||
if core.is_float16_supported(place): | ||
self.check_grad_with_place( | ||
place, ['X'], 'Out', max_relative_error=1 | ||
) | ||
|
||
|
||
class TestBceLossOpFP16Case1(TestBceLossOpFP16): | ||
def init_test_case(self): | ||
self.shape = [2, 3, 4, 5] | ||
|
||
|
||
class TestBceLossOpFP16Case2(TestBceLossOpFP16): | ||
def init_test_case(self): | ||
self.shape = [2, 3, 20] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 上述单测可以再简化一下,TestBceLossOpFP16继承了TestBceLossOp,可以对TestBceLossOp做一些调整,比如初始化case的时候能够设置dtype,shape。这样可以去掉很多冗余的代码。 max_relative_error为什么会这么大? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 暂时为了测试ci, 反向的相对误差很大,一直没找到原因 |
||
|
||
|
||
class TestBceLossOpStaticFP16(unittest.TestCase): | ||
def test_fp16(self): | ||
paddle.enable_static() | ||
shape = [2, 3, 20] | ||
x_data = np.random.uniform(0.1, 0.8, shape).astype("float16") | ||
y_data = np.random.randint(0, 2, shape).astype("float16") | ||
output_np = bce_loss(x_data, y_data) | ||
with paddle.static.program_guard(paddle.static.Program()): | ||
x = paddle.static.data(shape=shape, name='x', dtype='float16') | ||
y = paddle.static.data(shape=shape, name='y', dtype='float16') | ||
out = paddle.nn.functional.binary_cross_entropy( | ||
x, y, reduction="none" | ||
) | ||
if core.is_compiled_with_cuda(): | ||
place = paddle.CUDAPlace(0) | ||
exe = paddle.static.Executor(place) | ||
exe.run(paddle.static.default_startup_program()) | ||
output_pd = exe.run( | ||
feed={'x': x_data, 'y': y_data}, fetch_list=[out] | ||
)[0] | ||
np.testing.assert_allclose( | ||
output_pd, output_np, rtol=1e-3, atol=1e-3 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. atol设置为0能通过吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zhangting2020 这里是没问题的, atol=1e-3也能通过。 |
||
) | ||
paddle.disable_static() | ||
|
||
|
||
if __name__ == "__main__": | ||
paddle.enable_static() | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eps 的问题:第 36 行的 1e-12 在 fp16 表示下会下溢出为 0。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已做调整,不知道是否可以这样写。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以简化一下代码:把 one 和 eps 作为成员变量,初始化为 MT 类型,这样原来的构造函数就可以删掉了。