From 206b9aaebed16edae685b23ddee659dff316b6e8 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Wed, 5 Jul 2023 11:11:31 +0800 Subject: [PATCH] fix round type --- paddle/phi/kernels/gpu/rms_norm_kernel.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/rms_norm_kernel.cu b/paddle/phi/kernels/gpu/rms_norm_kernel.cu index 5b6944ec6bc55..3b81cddc95f6d 100644 --- a/paddle/phi/kernels/gpu/rms_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/rms_norm_kernel.cu @@ -870,7 +870,12 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input, const float min_bound) { float quant_value = max_bound * scale * input; - quant_value = static_cast(round(quant_value)); + if (round_type == 0) { + quant_value = static_cast(rint(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + return static_cast( ClipFunc(quant_value, min_bound, max_bound)); }