diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 819f31c8a300c..2f9c05abac6c4 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -98,6 +98,33 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t l
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+    // the pool_1d kernels index rows linearly, so the input must be contiguous, F32 -> F32 only
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+
+    const char * pool_str = "undefined";
+    switch (op_pool) {
+        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
+        case GGML_OP_POOL_MAX: pool_str = "max"; break;
+        default: GGML_ASSERT(false && "not implemented");
+    };
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s", base); // no specialization parameters - the name is the base name
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); // reuse a cached pipeline if present
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index f6ebf90a00ee8..e9b80358dda44 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -103,6 +103,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base          (ggml_metal_library_t lib, enum ggml_op op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy           (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_1d       (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d       (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows      (ggml_metal_library_t lib, enum ggml_type tsrc);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows      (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 523f9d71ba14e..7cfc107526755 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -673,7 +673,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_IM2COL:
             return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
         case GGML_OP_POOL_1D:
-            return false;
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type; // must match the asserts in ggml_metal_library_get_pipeline_pool_1d so unsupported cases fall back to the CPU
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 88c98423ebec0..38250bd0832df 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -714,6 +714,16 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+
+typedef struct {
+    int32_t k0; // kernel size
+    int32_t s0; // stride
+    int32_t p0; // padding
+    int64_t IW; // input row length
+    int64_t OW; // output row length
+    int64_t np; // total number of output elements
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
     int64_t ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e85a223c01dc3..dba689e6b7531 100644
+++ 
b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -390,6 +390,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_cpy(ctx, idx);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+            } break;
         case GGML_OP_POOL_2D:
             {
                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -1331,6 +1335,58 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+
+    const int32_t * opts = op->op_params;
+    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+    const int32_t k0 = opts[1]; // kernel size
+    const int32_t s0 = opts[2]; // stride
+    const int32_t p0 = opts[3]; // padding
+
+    const int64_t IW = op->src[0]->ne[0];
+
+    // pooling is applied along dim 0 only - every output row is independent and
+    // the number of rows is the product of ALL the remaining dims, incl. ne[1]
+    const int64_t NR = op->ne[3]*op->ne[2]*op->ne[1];
+    const int64_t OW = op->ne[0];
+
+    const int64_t np = NR*OW; // total output elements - one thread per element
+
+    ggml_metal_kargs_pool_1d args_pool_1d = {
+        /* .k0 = */ k0,
+        /* .s0 = */ s0,
+        /* .p0 = */ p0,
+        /* .IW = */ IW,
+        /* .OW = */ OW,
+        /* .np = */ np
+    };
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+    const int ntg = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
+
 int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 8df4c72e7c8cb..9baf8dd5ea18d 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -56,6 +56,7 @@ int ggml_metal_op_ssm_conv        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy             (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_pool_1d         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat_id      (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 96df6f0ce62de..3968eaf2d00b0 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -8720,3 +8720,68 @@ kernel void kernel_pool_2d_avg_f32(
 
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
+
+
+kernel void kernel_pool_1d_max_f32(
+        constant ggml_metal_kargs_pool_1d & args,
+        device const float * src0,
+        device float * dst,
+        uint gid[[thread_position_in_grid]]) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int idx = gid;
+    const int O_L = args.OW;
+    const int nc = idx / O_L;     // which output row
+    const int cur_o0 = idx % O_L; // position within the row
+
+    device const float * i_ptr = src0 + nc * args.IW;
+    device float * o_ptr = dst + nc * O_L;
+
+    const int start = cur_o0 * args.s0 - args.p0;
+    const int b = MAX(0, start);           // clamp window to the input row
+    const int e = MIN(args.IW, start + args.k0);
+
+    float res = -INFINITY; // identity for max
+
+    for (int j = b; j < e; ++j) {
+        res = MAX(res, i_ptr[j]);
+    }
+
+    o_ptr[cur_o0] = res;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+        constant ggml_metal_kargs_pool_1d & args,
+        device const float * src0,
+        device float * dst,
+        uint gid[[thread_position_in_grid]]) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int idx = gid;
+    const int O_L = args.OW;
+    const int nc = idx / O_L;     // which output row
+    const int cur_o0 = idx % O_L; // position within the row
+
+    device const float * i_ptr = src0 + nc * args.IW;
+    device float * o_ptr = dst + nc * O_L;
+
+    const int start = cur_o0 * args.s0 - args.p0;
+    const int b = MAX(0, start);           // clamp window to the input row
+    const int e = MIN(args.IW, start + args.k0);
+
+    const float scale = 1.0f / args.k0; // divide by the full kernel size - padded elements contribute zero
+
+    float res = 0.0f;
+
+    for (int j = b; j < e; ++j) {
+        res += i_ptr[j] * scale;
+    }
+
+    o_ptr[cur_o0] = res;
+}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c1e45972e54ca..c6e8076e92a8e 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4017,6 +4017,37 @@ struct test_pool2d : public test_case {
     }
 };
 
+// GGML_OP_POOL1D
+struct test_pool1d : public test_case {
+    enum ggml_op_pool pool_type;
+    const ggml_type type_input;
+    const std::array<int64_t, 4> ne_input;
+    const int k0; // kernel size
+    const int s0; // stride
+    const int p0; // padding
+
+    std::string vars() override {
+        return VARS_TO_STR6(pool_type, type_input, ne_input, k0, s0, p0);
+    }
+
+    test_pool1d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
+            ggml_type type_input = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_input = {10, 1, 1, 1},
+            int k0 = 3, int s0 = 3, int p0 = 0)
+        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), s0(s0), p0(p0) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(input);
+        ggml_set_name(input, "input");
+
+        ggml_tensor * out = ggml_pool_1d(ctx, input, pool_type, k0, s0, p0);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_CONV_TRANSPOSE_1D
 struct test_conv_transpose_1d : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -5832,6 +5863,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (ggml_type type_input : {GGML_TYPE_F32}) {
+        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+            for (int k0 : {1, 3}) {
+                for (int s0 : {1, 2, 3}) {
+                    for (int p0 : {0, 1}) {
+                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 10, 1, 1, 1 }, k0, s0, p0));
+                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 11, 1, 1, 2 }, k0, s0, p0));
+                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 128, 1, 1, 3 }, k0, s0, p0));
+                    }
+                }
+            }
+        }
+    }
+
 #if 0
     // >4GB im2col destination. Too slow to run by default.
     // Test cases taken from Wan2.1 T2V 1.3B.