diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c0c9edd56dbc2..17154ec1a8ebd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4596,14 +4596,12 @@ static __global__ void rope_neox( const int i = row*ncols + col/2; const int i2 = row/p_delta_rows; - // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero - const float cur_rot = -float(col)/ncols; - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*powf(freq_base, cur_rot); + const float theta_base = p*powf(freq_base, -float(col)/ncols); + // rotation amount is `ib * ncols + col`, but ib is assumed to be zero float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + ncols/2]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5d1357cd72d45..e8032dd5b4861 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1277,10 +1277,9 @@ kernel void kernel_rope( for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot = inv_ndims*ic - ib; + const int64_t cur_rot = ib * n_dims + ic; - const float theta = theta_0 * pow(freq_base, cur_rot); + const float theta = theta_0 * pow(freq_base, inv_ndims*cur_rot); float cos_theta, sin_theta; rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); diff --git a/ggml.c b/ggml.c index 3202a517b7868..012b326baf0e1 100644 --- a/ggml.c +++ b/ggml.c @@ -11084,8 +11084,7 @@ static void ggml_compute_forward_rope_f32( theta_base *= freq_scale; for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; + int64_t cur_rot = ib * n_dims + ic; float cos_theta, sin_theta; rope_yarn( @@ -11237,8 +11236,7 @@ static void ggml_compute_forward_rope_f16( theta_base *= freq_scale; for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; + int64_t cur_rot = ib * n_dims + ic; float cos_theta, sin_theta; rope_yarn(