Skip to content

Commit

Permalink
fix round type
Browse files Browse the repository at this point in the history
  • Loading branch information
MARD1NO committed Jul 5, 2023
1 parent f77ef8f commit 206b9aa
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion paddle/phi/kernels/gpu/rms_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,12 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float min_bound) {
float quant_value = max_bound * scale * input;

quant_value = static_cast<float>(round(quant_value));
if (round_type == 0) {
quant_value = static_cast<float>(rint(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}

return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}
Expand Down

0 comments on commit 206b9aa

Please sign in to comment.