diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp index c43896c486078e..8e063d68a3e0a0 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp @@ -282,8 +282,8 @@ struct lstm_seq : public primitive_base { if (!cell.empty()) ret.push_back(cell); */ - //ret.push_back(second_output); - //ret.push_back(third_output); + ret.push_back(out1_prim_id); + ret.push_back(out2_prim_id); return ret; } }; diff --git a/src/plugins/intel_gpu/src/graph/include/lstm_seq_inst.h b/src/plugins/intel_gpu/src/graph/include/lstm_seq_inst.h index 110bafed0f92a6..1436923c6653ef 100644 --- a/src/plugins/intel_gpu/src/graph/include/lstm_seq_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/lstm_seq_inst.h @@ -65,8 +65,12 @@ class typed_primitive_inst : public typed_primitive_inst_base; - } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/lstm_seq.cpp b/src/plugins/intel_gpu/src/graph/lstm_seq.cpp index e854c4a02bd8e5..bb4a15ab5aa22e 100644 --- a/src/plugins/intel_gpu/src/graph/lstm_seq.cpp +++ b/src/plugins/intel_gpu/src/graph/lstm_seq.cpp @@ -52,6 +52,26 @@ std::vector lstm_seq_inst::calc_output_layouts(lstm_seq_node const& node template std::vector lstm_seq_inst::calc_output_layouts(lstm_seq_node const& node, const kernel_impl_params& impl_param); +void lstm_seq_inst::on_execute() { + update_output_memory(); +} + +void lstm_seq_inst::update_output_memory() { + if (!can_be_optimized()) + return; + + for (size_t i = 1; i < 3; i++) { + if (node->get_program().is_new_shape_infer() && input_memory_ptr(i+6) == nullptr) + return; + + if (output_memory_ptr(i) != nullptr && _network.get_engine().is_the_same_buffer(output_memory(i), input_memory(i+6))) + return; + + _outputs[i] = {_network.get_engine().reinterpret_buffer(input_memory(i+6), _impl_params->get_output_layout(i))}; + } +} + + std::string lstm_seq_inst::to_string(lstm_seq_node const& node) { auto desc = node.get_primitive(); auto node_info = node.desc_to_json(); diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl index 74c0b6f1dd042a..221aec1892cf88 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/lstm_seq_gpu_bfyx_ref.cl @@ -139,5 +139,5 @@ KERNEL(lstm_seq)( printf("hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] is %f\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)]); } //printf("cell state for %d is %f \n", OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]); - //printf("R is %p B is %p ; %p out0 %p add for out1 for out2 %p batch %d\n", &R, &B, &hidden_history, &hidden_state, &cell_state, b); + printf("R is %p B is %p ; hidden history %p hidden state %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &hidden_state[0], &cell_state[0], b); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/lstm/lstm_seq_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/lstm/lstm_seq_kernel_base.h index 29a62af159945a..c1a1d80c571c58 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/lstm/lstm_seq_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/lstm/lstm_seq_kernel_base.h @@ -22,12 +22,6 @@ struct lstm_seq_params : public base_params { }; lstm_seq_params() : base_params(KernelType::LSTM_SEQ) {} - - DataTensor initial_hidden_state; - DataTensor initial_cell_state; - DataTensor sequence_lengths; - DataTensor WR; - DataTensor B; bool has_cell = false; order_type gate_order = offset_iofz; float clip = 0; diff --git a/src/plugins/intel_gpu/src/plugin/ops/rnn.cpp b/src/plugins/intel_gpu/src/plugin/ops/rnn.cpp index 2295ff9261094f..d6ea31eba3267b 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/rnn.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/rnn.cpp @@ -265,9 +265,9 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptrget_output_shape(1))); - cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out1Layout); + cldnn::memory::ptr shared_memory0 = p.get_engine().allocate_memory(out1Layout); const cldnn::primitive_id mutable_id_0 = layerName + "_md_write0"; - const cldnn::mutable_data mutable_prim_0{mutable_id_0, shared_memory1}; + const cldnn::mutable_data mutable_prim_0{mutable_id_0, shared_memory0}; p.add_primitive(*op, mutable_prim_0); inputs.push_back(cldnn::input_info(mutable_id_0)); @@ -278,9 +278,9 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptrget_output_shape(2))); - cldnn::memory::ptr shared_memory2 = p.get_engine().allocate_memory(out2Layout); + cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out2Layout); const cldnn::primitive_id mutable_id_1 = layerName + "_md_write1"; - const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memory2}; + const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memory1}; p.add_primitive(*op, mutable_prim_1); inputs.push_back(cldnn::input_info(mutable_id_1)); @@ -289,8 +289,8 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr