Skip to content

Commit

Permalink
metal : add rope_f16 kernel + optimize cpy kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Sep 17, 2023
1 parent 1fb033f commit 57cea73
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
28 changes: 18 additions & 10 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
Expand Down Expand Up @@ -261,7 +262,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
Expand Down Expand Up @@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DEL_KERNEL(rope);
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
Expand Down Expand Up @@ -870,7 +873,7 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
const int nth = 32;
const int nth = MIN(32, ne00);

if (ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
Expand Down Expand Up @@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

const int nth = 512;
const int nth = MIN(512, ne00);

[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
Expand All @@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

const int nth = 256;
const int nth = MIN(256, ne00);

[encoder setComputePipelineState:ctx->pipeline_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
Expand Down Expand Up @@ -1212,7 +1215,7 @@ void ggml_metal_graph_compute(
{
GGML_ASSERT(ne10 == ne02);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];

Expand All @@ -1221,7 +1224,12 @@ void ggml_metal_graph_compute(
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

[encoder setComputePipelineState:ctx->pipeline_rope];
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
default: GGML_ASSERT(false);
};

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
Expand All @@ -1241,7 +1249,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
//[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
Expand All @@ -1253,7 +1261,7 @@ void ggml_metal_graph_compute(
case GGML_OP_CPY:
case GGML_OP_CONT:
{
const int nth = 32;
const int nth = MIN(1024, ne00);

switch (src0t) {
case GGML_TYPE_F32:
Expand Down
45 changes: 39 additions & 6 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,36 @@ kernel void kernel_alibi_f32(
}
}

// Common signature for all RoPE (rotary position embedding) kernels.
// Declaring the signature once as a function type lets the f32 and f16
// variants below be produced from a single kernel_rope<T> template and
// exposed to the host via [[host_name(...)]] specializations.
// NOTE(review): dst is declared `device float *` even for the f16 variant;
// the template casts through (device char *) before use, so the declared
// pointee type here is not load-bearing — confirm against kernel_rope<T>.
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);

template<typename T>
kernel void kernel_rope(
device const void * src0,
device const int32_t * src1,
Expand Down Expand Up @@ -901,11 +931,11 @@ kernel void kernel_rope(
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);

device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

const float x0 = src[0];
const float x1 = src[1];
const T x0 = src[0];
const T x1 = src[1];

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
Expand All @@ -920,8 +950,8 @@ kernel void kernel_rope(

const int64_t i0 = ib*n_dims + ic/2;

device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

const float x0 = src[0];
const float x1 = src[n_dims/2];
Expand All @@ -933,6 +963,9 @@ kernel void kernel_rope(
}
}

// Explicit template specializations, exported to the host as named Metal
// kernels; these names must match the GGML_METAL_ADD_KERNEL(rope_f32/rope_f16)
// registrations on the Objective-C side.
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;

kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
Expand Down

0 comments on commit 57cea73

Please sign in to comment.