Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement customizable RoPE #2054

Merged
merged 10 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
Expand Down Expand Up @@ -493,6 +505,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
Expand Down Expand Up @@ -573,6 +587,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;

return lparams;
}
Expand Down
2 changes: 2 additions & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
Expand Down
12 changes: 10 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
return 0;
}

if (params.rope_freq_base != 10000.0) {
fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
}

if (params.rope_freq_scale != 1.0) {
fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
}

if (params.n_ctx > 2048) {
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
"expect poor results\n", __func__, params.n_ctx);
fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
" you are on your own\n", __func__, params.n_ctx);
} else if (params.n_ctx < 8) {
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
Expand Down
1 change: 1 addition & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the
```sh
curl --request POST \
--url http://localhost:8080/completion \
--header "Content-Type: application/json" \
--data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
```

Expand Down
2 changes: 2 additions & 0 deletions examples/server/chat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ tokenize() {
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
Expand Down Expand Up @@ -64,6 +65,7 @@ chat_completion() {
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")

printf "\n"
Expand Down
18 changes: 18 additions & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
Expand Down Expand Up @@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--rope-freq-base")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_base = std::stof(argv[i]);
}
else if (arg == "--rope-freq-scale")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32")
{
params.memory_f16 = false;
Expand Down
45 changes: 26 additions & 19 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -885,28 +885,35 @@ void ggml_metal_graph_compute(

const int n_past = ((int32_t *)(src1->data))[0];

float freq_base;
float freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
Expand Down
6 changes: 4 additions & 2 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -624,17 +624,19 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];

const bool is_neox = mode & 2;
const float theta_scale = pow(10000.0, -2.0f/n_dims);
const float theta_scale = pow(freq_base, -2.0f/n_dims);

const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

float theta = (float)p;
float theta = freq_scale * (float)p;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
Expand Down
50 changes: 38 additions & 12 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -6954,6 +6954,8 @@ struct ggml_tensor * ggml_rope_impl(
int n_past,
int n_dims,
int mode,
float freq_base,
float freq_scale,
int n_ctx,
bool inplace) {
GGML_ASSERT(n_past >= 0);
Expand All @@ -6967,12 +6969,14 @@ struct ggml_tensor * ggml_rope_impl(

ggml_scratch_save(ctx);

struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);

((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = n_dims;
((int32_t *) b->data)[2] = mode;
((int32_t *) b->data)[3] = n_ctx;
memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float));
memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));

ggml_scratch_load(ctx);

Expand All @@ -6991,7 +6995,7 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
}

struct ggml_tensor * ggml_rope_inplace(
Expand All @@ -7001,7 +7005,19 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
}

struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
float freq_base,
float freq_scale,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
}

// ggml_rope_back
Expand Down Expand Up @@ -12062,16 +12078,21 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_nelements(src1) == 4);
GGML_ASSERT(ggml_nelements(src1) == 6);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

float freq_base;
float freq_scale;

const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

assert(n_past >= 0);

Expand Down Expand Up @@ -12100,7 +12121,7 @@ static void ggml_compute_forward_rope_f32(
// row index used to determine which thread to use
int ir = 0;

const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f/n_dims);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
Expand All @@ -12112,7 +12133,7 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = (float)p;
float theta = freq_scale * (float)p;

if (is_glm) {
theta = MIN(p, n_ctx - 2);
Expand Down Expand Up @@ -12189,16 +12210,21 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_nelements(src1) == 4);
GGML_ASSERT(ggml_nelements(src1) == 6);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

float freq_base;
float freq_scale;

const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

assert(n_past >= 0);

Expand Down Expand Up @@ -12227,7 +12253,7 @@ static void ggml_compute_forward_rope_f16(
// row index used to determine which thread to use
int ir = 0;

const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f/n_dims);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
Expand All @@ -12239,7 +12265,7 @@ static void ggml_compute_forward_rope_f16(
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = (float)p;
float theta = freq_scale * (float)p;

if (is_glm) {
theta = MIN(p, n_ctx - 2);
Expand Down Expand Up @@ -12300,7 +12326,7 @@ static void ggml_compute_forward_rope_f16(
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);

dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
Expand Down Expand Up @@ -15712,7 +15738,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 4);
assert(ggml_nelements(src1) == 6);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
Expand All @@ -15733,7 +15759,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 4);
assert(ggml_nelements(src1) == 3);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be 6? Based on the code immediately after it should be at least 4, I think, not 3.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went through the code an I also can't see why it's 3 when the lines just below it show it clearly taking 4 elements and looks like it designed to fail the assertion

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed in 513f861

const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
Expand Down
11 changes: 11 additions & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,17 @@ extern "C" {
int mode,
int n_ctx);

// custom RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
float freq_base,
float freq_scale,
int n_ctx);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
Expand Down
Loading