From ba048d39c89095d3c648829b48adfe16ebc54332 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 21 Dec 2023 15:14:16 +0800 Subject: [PATCH 1/2] fix fused_rope --- paddle/phi/kernels/fusion/gpu/fused_rope_utils.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 972f5ee633bbb..53eddc594bcc0 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -125,10 +125,11 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( MPType p0 = static_cast(input[pr_index]); MPType p1 = static_cast(input[ls_index]); - result[pr_index] = - cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1; - result[ls_index] = - cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0; + result[pr_index] = cos_value[pr_index] * p0; + result[pr_index] -= sign * sin_value[pr_index] * p1; + + result[ls_index] = sign * sin_value[ls_index] * p0; + result[ls_index] += cos_value[ls_index] * p1; store[pr_index] = static_cast(result[pr_index]); store[ls_index] = static_cast(result[ls_index]); From 17e3805728e0ce098c881f3689547eb4ee3b298b Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 21 Dec 2023 21:17:02 +0800 Subject: [PATCH 2/2] change cal result --- .../phi/kernels/fusion/gpu/fused_rope_utils.h | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 53eddc594bcc0..0db16ffb7e20b 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -125,11 +125,18 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( MPType p0 = static_cast(input[pr_index]); MPType p1 = static_cast(input[ls_index]); - result[pr_index] = cos_value[pr_index] * p0; - result[pr_index] -= sign * sin_value[pr_index] * p1; - - result[ls_index] = sign * sin_value[ls_index] * p0; - result[ls_index] += cos_value[ls_index] * p1; + if (sign == 1) { + result[pr_index] = cos_value[pr_index] * p0; + result[pr_index] -= sin_value[pr_index] * p1; + + result[ls_index] = sin_value[ls_index] * p0; + result[ls_index] += cos_value[ls_index] * p1; + } else if (sign == -1) { + result[pr_index] = + cos_value[pr_index] * p0 + sin_value[ls_index] * p1; + result[ls_index] = + cos_value[ls_index] * p1 - sin_value[pr_index] * p0; + } store[pr_index] = static_cast(result[pr_index]); store[ls_index] = static_cast(result[ls_index]);