ggml : ggml_rope now takes a vector with positions instead of n_past
ggerganov committed Sep 17, 2023
1 parent 3b4bab6 commit d4cd263
Showing 9 changed files with 270 additions and 131 deletions.
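
In short: ggml_rope (and ggml_rope_custom) no longer take a scalar n_past — they take a 1-D GGML_TYPE_I32 tensor holding one position per token, and callers build that tensor themselves. A minimal before/after sketch, mirroring the call sites in the diffs below (ctx0, qcur, N, n_past and n_rot stand in for whatever the surrounding graph uses; the exact prototypes live in ggml.h, which is not among the hunks shown on this page):

// before this commit: positions were implied by a scalar n_past (token i sits at n_past + i)
//   struct ggml_tensor * q = ggml_rope(ctx0, qcur, n_past, n_rot, 0, 0);

// after this commit: positions are passed explicitly, one int32 entry per token
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
    int * data = (int *) pos->data;
    for (int i = 0; i < N; ++i) {
        data[i] = n_past + i; // reproduces the old behaviour exactly
    }
}
struct ggml_tensor * q = ggml_rope(ctx0, qcur, pos, n_rot, 0, 0);

Passing positions as a tensor keeps the old behaviour trivially reproducible while allowing positions that are not a contiguous n_past..n_past+N-1 range.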
37 changes: 31 additions & 6 deletions examples/baby-llama/baby-llama.cpp
@@ -556,6 +556,14 @@ struct ggml_tensor * forward(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_rope->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) {
@@ -583,8 +591,8 @@ struct ggml_tensor * forward(
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, 1]
// Kcur shape [n_embd/n_head, n_head, N, 1]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_rope, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_rope, n_rot, 0, 0);

// store key and value to memory
{
@@ -810,9 +818,18 @@ struct ggml_tensor * forward_batch(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_rope->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch);

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

@@ -840,8 +857,8 @@ struct ggml_tensor * forward_batch(
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_rope, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_rope, n_rot, 0, 0);
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);

@@ -1100,6 +1117,14 @@ struct ggml_tensor * forward_lora(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_rope->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) {
@@ -1133,7 +1158,7 @@ struct ggml_tensor * forward_lora(
model->layers[il].wqb,
cur)),
n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0);
KQ_rope, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
ggml_mul_mat(ctx0,
@@ -1142,7 +1167,7 @@
model->layers[il].wkb,
cur)),
n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0);
KQ_rope, n_rot, 0, 0);

// store key and value to memory
{
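The position tensor above is built with identical code in forward, forward_batch and forward_lora. When porting other ggml graphs to the new API, a small helper keeps the call sites short — make_rope_positions below is a hypothetical name for illustration, not something this commit adds:

// hypothetical helper (not part of this commit): builds the I32 positions
// tensor [n_past, n_past+1, ..., n_past+N-1] that ggml_rope now expects
static struct ggml_tensor * make_rope_positions(struct ggml_context * ctx, int n_past, int N) {
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
    int * data = (int *) pos->data;
    for (int i = 0; i < N; ++i) {
        data[i] = n_past + i;
    }
    return pos;
}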
14 changes: 11 additions & 3 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
}
};

// KQ_rope - contains the positions
struct ggml_tensor * KQ_rope = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_rope->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// rope has so many parameters that we make a custom function for it
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
auto rope = [ctx, KQ_rope, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
(struct ggml_tensor * t) -> struct ggml_tensor * {
// not capturing these, to silence warnings
const int n_past = 0;
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, n_past, n_rot, rope_mode, n_ctx,
t, KQ_rope, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
};

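For ggml_rope_custom the change has the same shape: the scalar n_past argument is replaced by the positions tensor and every later argument keeps its place. Using the names from the lambda above:

// before: ggml_rope_custom(ctx, t, n_past,  n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale)
// after:  ggml_rope_custom(ctx, t, KQ_rope, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale)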
49 changes: 26 additions & 23 deletions ggml-metal.m
@@ -1210,7 +1210,9 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ROPE:
{
const int n_past = ((int32_t *) dst->op_params)[0];
GGML_ASSERT(ne10 == ne02);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];

@@ -1221,28 +1223,29 @@

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
//[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
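On the Metal host side the positions tensor arrives as the op's second source (src1) and is bound at buffer index 1, which pushes every later argument up by one index; the host no longer uploads n_past, but the kernel keeps the now-unused n_past parameter in its signature, so the remaining indices still line up. The new assert reads naturally under the usual rope input layout [head_dim, n_head, n_tokens, n_batch]:

// ne10 = length of the positions tensor (src1), ne02 = number of tokens (rows) in src0,
// so this simply enforces "one position per roped row" (assuming the layout above)
GGML_ASSERT(ne10 == ne02);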
55 changes: 29 additions & 26 deletions ggml-metal.metal
@@ -854,29 +854,30 @@ kernel void kernel_alibi_f32(
}

kernel void kernel_rope(
device const void * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
device const void * src0,
device const int32_t * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -886,7 +887,9 @@ kernel void kernel_rope(

const bool is_neox = mode & 2;

const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
device const int32_t * pos = src1;

const int64_t p = pos[i2];

const float theta_0 = freq_scale * (float)p;
const float inv_ndims = -1.f/n_dims;
@@ -1320,8 +1323,8 @@ kernel void kernel_mul_mat_q3_K_f32(

float yl[32];

const uint16_t kmask1 = 0x3030;
const uint16_t kmask2 = 0x0f0f;
//const uint16_t kmask1 = 0x3030;
//const uint16_t kmask2 = 0x0f0f;

const int tid = tiisg/4;
const int ix = tiisg%4;
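Inside the kernel the row's position is now read from src1 instead of being derived from n_past, and mode & 1 no longer changes how p is computed. The C equivalent of the new lookup, for reference (src1_data stands for a pointer to the positions tensor's data):

// old: const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
// new: the position of row i2 comes straight from the positions tensor
const int32_t * pos = (const int32_t *) src1_data;
const int64_t   p   = pos[i2];
const float theta_0 = freq_scale * (float) p;   // starting angle, unchanged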