From ab1307cc05ea3c49027709ca271af486a7f1a7ff Mon Sep 17 00:00:00 2001 From: michal-miotk Date: Tue, 30 Jul 2024 22:00:50 +0000 Subject: [PATCH] test passed --- .../include/intel_gpu/primitives/lstm.hpp | 4 +- .../src/graph/impls/ocl/lstm_seq.cpp | 10 +-- .../cl_kernels/lstm_seq_gpu_bfyx_ref.cl | 72 ++----------------- src/plugins/intel_gpu/src/plugin/ops/rnn.cpp | 7 +- 4 files changed, 10 insertions(+), 83 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp index ef41e0fe9e4246..cf82f116254784 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp @@ -183,7 +183,7 @@ struct lstm_seq : public primitive_base { const lstm_weights_order offset_order = lstm_weights_order::iofz, const uint32_t direction = 0, const padding& output_padding = padding()) - : primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B}, {output_padding}, {}, 2), + : primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B, out2_prim_id}, {output_padding}, {}, 1), out2_prim_id(out2_prim_id), cell(cell), clip(clip), @@ -282,7 +282,7 @@ struct lstm_seq : public primitive_base { ret.push_back(cell); */ //ret.push_back(out1_prim_id); - ret.push_back(out2_prim_id); + //ret.push_back(out2_prim_id); return ret; } }; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp index 6f4c1fb9e4ff45..c6fc8d2ff2a1ee 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp @@ -26,15 +26,7 @@ struct lstm_seq_impl : typed_primitive_impl_ocl { protected: kernel_arguments_data get_arguments(const typed_primitive_inst& instance) const override { kernel_arguments_data args = parent::get_arguments(instance); - args.inputs = { instance.dep_memory_ptr(0), instance.dep_memory_ptr(1), instance.dep_memory_ptr(2), instance.dep_memory_ptr(3), - instance.dep_memory_ptr(4), instance.dep_memory_ptr(5), instance.dep_memory_ptr(6)}; - args.cell = instance.cell_term() ? instance.cell_memory() : nullptr; - // New API for mutiple outputs support - for (size_t i = 0; i < 1; i++) { - args.outputs.push_back(instance.output_memory_ptr(i)); - } - args.outputs.push_back(instance.second_output_mem()); - args.outputs.push_back(instance.third_output_mem()); + args.outputs.push_back(instance.dep_memory_ptr(instance.desc()->input_size() - 1)); return args; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl index d30f1fb384bbe8..8e630841c9802b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl @@ -13,50 +13,13 @@ KERNEL(lstm_seq)( const __global INPUT5_TYPE* R, const __global INPUT6_TYPE* B, __global OUTPUT_TYPE* hidden_history, - __global OUTPUT2_TYPE* cell_state + __global OUTPUT1_TYPE* cell_state ) { const uint hidden_idx = get_global_id(0); float local_hidden_state = 0; const uint b = get_global_id(1); - for(int i=0;i