[PaddlePaddle Hackathon 4 No.53]: Support the float16 data type for Paddle label_smooth #50921

Closed
wants to merge 10 commits
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu (8 changes: 6 additions, 2 deletions)

@@ -15,6 +15,8 @@
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

@@ -28,7 +30,8 @@ struct LabelSmoothGradFunctor {
}

__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(1 - epsilon) * x;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
Contributor: Line 34 can be removed, since it already exists above.

Contributor (author): Fixed.

return static_cast<T>((1 - static_cast<MT>(epsilon)) * static_cast<MT>(x));
Contributor: Taken together with the code above, epsilon has already been cast to FP16 at line 29, and here it is cast back to FP32.

Contributor (author): Adjusted.

}
};
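To make the two comments above concrete, here is a minimal sketch of a gradient functor that avoids both the duplicated MT alias and the FP16-to-FP32 round trip: epsilon is widened to the high-precision type once in the constructor, the arithmetic runs in MT, and the result is narrowed back to T a single time. This is an illustration of the reviewer's suggestion under that assumption, not the code that was finally merged, and the functor name is hypothetical.

template <typename T>
struct LabelSmoothGradFunctorSketch {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  MT epsilon;  // kept in MT, so operator() never has to widen it again

  __forceinline__ explicit LabelSmoothGradFunctorSketch(float epsilon_data) {
    epsilon = static_cast<MT>(epsilon_data);
  }

  __device__ __forceinline__ T operator()(const T x) const {
    // dL/dx = (1 - epsilon) * dL/dy, accumulated in MT, narrowed to T once
    return static_cast<T>((static_cast<MT>(1) - epsilon) * static_cast<MT>(x));
  }
};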

@@ -52,4 +55,5 @@ PD_REGISTER_KERNEL(label_smooth_grad,
ALL_LAYOUT,
phi::LabelSmoothGradKernel,
float,
double) {}
double,
phi::dtype::float16) {}
paddle/phi/kernels/gpu/label_smooth_kernel.cu (23 changes: 17 additions, 6 deletions)

@@ -17,6 +17,8 @@
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

@@ -33,8 +35,10 @@ struct LabelSmoothFunctor {
}

__device__ __forceinline__ T operator()(const T x) const {
return (static_cast<T>(1 - epsilon) * x +
static_cast<T>(epsilon / label_dim));
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
return static_cast<T>((1 - static_cast<MT>(epsilon)) * static_cast<MT>(x) +
static_cast<MT>(epsilon) /
static_cast<MT>(label_dim));
Contributor: Same epsilon issue as above.

Contributor (author): Fixed.

}
};

@@ -46,9 +50,11 @@ __global__ void LabelSmoothRunDistKernel(const int N,
const T* dist_data,
T* dst) {
CUDA_KERNEL_LOOP(idx, N) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
int dist_idx = idx % dist_numel;
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon) * dist_data[dist_idx];
dst[idx] = static_cast<T>(
static_cast<MT>((1 - epsilon) * static_cast<MT>(src[idx])) +
static_cast<MT>(epsilon) * static_cast<MT>(dist_data[dist_idx]));
}
}
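The same pattern would carry over to the forward path. As a rough sketch (an illustration, not the merged code, with a hypothetical kernel name), the distribution branch can widen each FP16 operand to MT once, evaluate y = (1 - epsilon) * x + epsilon * dist in MT, and store the result with a single narrowing cast; LabelSmoothFunctor above would follow the same shape, with epsilon / label_dim as the second term.

template <typename T>
__global__ void LabelSmoothRunDistKernelSketch(const int N,
                                               const float epsilon,
                                               const int dist_numel,
                                               const T* src,
                                               const T* dist_data,
                                               T* dst) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  const MT eps = static_cast<MT>(epsilon);  // widen epsilon once per thread
  CUDA_KERNEL_LOOP(idx, N) {
    const int dist_idx = idx % dist_numel;
    const MT x = static_cast<MT>(src[idx]);                // widen input to MT
    const MT dist = static_cast<MT>(dist_data[dist_idx]);  // widen prior to MT
    // y = (1 - eps) * x + eps * dist, accumulated in MT, stored as T
    dst[idx] = static_cast<T>((static_cast<MT>(1) - eps) * x + eps * dist);
  }
}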

@@ -83,5 +89,10 @@ void LabelSmoothKernel(const Context& ctx,

} // namespace phi

PD_REGISTER_KERNEL(
label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}
PD_REGISTER_KERNEL(label_smooth,
GPU,
ALL_LAYOUT,
phi::LabelSmoothKernel,
float,
double,
phi::dtype::float16) {}
paddle/phi/kernels/gpu/prelu_funcs.h (3 changes: 3 additions, 0 deletions)

@@ -137,6 +137,9 @@ void PreluChannelWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
stream>>>(
input, alpha, output, channel, numel);
} else {
printf("debug: spatial: %d, ch_num: %d\n",
static_cast<int>(numel / batch_size / channel),
static_cast<int>(channel));
Contributor: This seems unrelated to the functionality of the PR?

Contributor (author): Removed.

PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
python/paddle/fluid/tests/unittests/test_label_smooth_op.py (25 changes: 25 additions, 0 deletions)

@@ -78,6 +78,31 @@ def setUp(self):
)


class TestLabelSmoothFP16(unittest.TestCase):
def check_main(self, x_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
x.stop_gradient = False
y = paddle.nn.functional.label_smooth(x, epsilon=0.1)
x_g = paddle.grad(y, [x])
y_np = y.numpy().astype('float32')
x_g_np = x_g[0].numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np

def test_main(self):
if not paddle.is_compiled_with_cuda():
return

np.random.seed(20)
x_np = np.random.random([10, 12])
y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')

np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=1e-03)


Contributor: Please add a unit test for the op following the low-precision unit-test specification. Also, once the OpTest is added, this case can be removed. https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/amp_precision/amp_test_dev_guide_cn.html#step2

Contributor (author): Fixed.

if __name__ == '__main__':
paddle.enable_static()
unittest.main()
python/paddle/nn/functional/common.py (5 changes: 3 additions, 2 deletions)

@@ -1922,7 +1922,8 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
label(Tensor): The input variable containing the label data. The
label data should use one-hot representation. It's
a multidimensional tensor with a shape of
:math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64".
:math:`[N_1, ..., Depth]`, where Depth is class number.
The dtype can be "float16", "float32" and "float64".
prior_dist(Tensor, optional): The prior distribution to be used to smooth
labels. If not provided, an uniform distribution
is used. It's a multidimensional tensor with a shape of
@@ -1964,7 +1965,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
)

check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'label_smooth'
label, 'label', ['float16', 'float32', 'float64'], 'label_smooth'
)

helper = LayerHelper("label_smooth", **locals())