Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YaRN : correction to GPT-NeoX implementation #4093

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4596,14 +4596,12 @@ static __global__ void rope_neox(
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;

// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
const float cur_rot = -float(col)/ncols;

const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(freq_base, cur_rot);
const float theta_base = p*powf(freq_base, -float(col)/ncols);

// rotation amount is `ib * ncols + col`, but ib is assumed to be zero
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + ncols/2];
Expand Down
5 changes: 2 additions & 3 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1277,10 +1277,9 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
const int64_t cur_rot = ib * n_dims + ic;

const float theta = theta_0 * pow(freq_base, cur_rot);
const float theta = theta_0 * pow(freq_base, inv_ndims*cur_rot);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

Expand Down
6 changes: 2 additions & 4 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -11084,8 +11084,7 @@ static void ggml_compute_forward_rope_f32(
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
int64_t cur_rot = ib * n_dims + ic;

float cos_theta, sin_theta;
rope_yarn(
Expand Down Expand Up @@ -11237,8 +11236,7 @@ static void ggml_compute_forward_rope_f16(
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
int64_t cur_rot = ib * n_dims + ic;

float cos_theta, sin_theta;
rope_yarn(
Expand Down
Loading