Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Jul 30, 2024
1 parent 2bb561f commit 939d23c
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ struct lstm_seq : public primitive_base<lstm_seq> {
if (!cell.empty())
ret.push_back(cell);
*/
//ret.push_back(second_output);
//ret.push_back(third_output);
ret.push_back(out1_prim_id);
ret.push_back(out2_prim_id);
return ret;
}
};
Expand Down
6 changes: 5 additions & 1 deletion src/plugins/intel_gpu/src/graph/include/lstm_seq_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ class typed_primitive_inst<lstm_seq> : public typed_primitive_inst_base<lstm_seq
size_t offset = 8;
return dep_memory_ptr(offset);
}

void update_output_memory() override;

private:
void on_execute() override;
};

using lstm_seq_inst = typed_primitive_inst<lstm_seq>;

} // namespace cldnn
20 changes: 20 additions & 0 deletions src/plugins/intel_gpu/src/graph/lstm_seq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,26 @@ std::vector<layout> lstm_seq_inst::calc_output_layouts(lstm_seq_node const& node

template std::vector<layout> lstm_seq_inst::calc_output_layouts<ov::PartialShape>(lstm_seq_node const& node, const kernel_impl_params& impl_param);

// Hook invoked right before this primitive instance executes: when the node was
// optimized out, re-bind the extra outputs to the shared input buffers
// (delegates all logic to update_output_memory()).
void lstm_seq_inst::on_execute() {
update_output_memory();
}

// Re-binds outputs 1 and 2 (hidden/cell state, per the primitive's output order)
// to the memory of the corresponding extra inputs (dependency indices 7 and 8,
// i.e. i + 6 for i in {1, 2}), so the optimized-out copy is skipped.
// No-op unless this instance was marked as optimizable.
void lstm_seq_inst::update_output_memory() {
    if (!can_be_optimized())
        return;

    for (size_t i = 1; i < 3; i++) {
        // With new shape infer the shared input memory may not be allocated yet;
        // bail out entirely and retry on a later execution.
        if (node->get_program().is_new_shape_infer() && input_memory_ptr(i + 6) == nullptr)
            return;

        // Output already aliases the shared buffer — nothing to do for this
        // output, but the remaining one must still be processed. (BUGFIX: this
        // was `return`, which skipped output 2 whenever output 1 was already
        // bound to its shared input buffer.)
        if (output_memory_ptr(i) != nullptr &&
            _network.get_engine().is_the_same_buffer(output_memory(i), input_memory(i + 6)))
            continue;

        // Alias the shared input buffer under this output's layout.
        _outputs[i] = {_network.get_engine().reinterpret_buffer(input_memory(i + 6), _impl_params->get_output_layout(i))};
    }
}


std::string lstm_seq_inst::to_string(lstm_seq_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,5 @@ KERNEL(lstm_seq)(
printf("hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] is %f\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)]);
}
//printf("cell state for %d is %f \n", OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]);
//printf("R is %p B is %p ; %p out0 %p add for out1 for out2 %p batch %d\n", &R, &B, &hidden_history, &hidden_state, &cell_state, b);
printf("R is %p B is %p ; hidden history %p hidden state %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &hidden_state[0], &cell_state[0], b);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ struct lstm_seq_params : public base_params {
};

lstm_seq_params() : base_params(KernelType::LSTM_SEQ) {}

DataTensor initial_hidden_state;
DataTensor initial_cell_state;
DataTensor sequence_lengths;
DataTensor WR;
DataTensor B;
bool has_cell = false;
order_type gate_order = offset_iofz;
float clip = 0;
Expand Down
12 changes: 6 additions & 6 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
cldnn::element_type_to_data_type(mutable_precision_first),
cldnn::format::bfyx,
tensor_from_dims(op->get_output_shape(1)));
cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out1Layout);
cldnn::memory::ptr shared_memory0 = p.get_engine().allocate_memory(out1Layout);
const cldnn::primitive_id mutable_id_0 = layerName + "_md_write0";
const cldnn::mutable_data mutable_prim_0{mutable_id_0, shared_memory1};
const cldnn::mutable_data mutable_prim_0{mutable_id_0, shared_memory0};
p.add_primitive(*op, mutable_prim_0);

inputs.push_back(cldnn::input_info(mutable_id_0));
Expand All @@ -278,9 +278,9 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
cldnn::element_type_to_data_type(mutable_precision_second),
cldnn::format::bfyx,
tensor_from_dims(op->get_output_shape(2)));
cldnn::memory::ptr shared_memory2 = p.get_engine().allocate_memory(out2Layout);
cldnn::memory::ptr shared_memory1 = p.get_engine().allocate_memory(out2Layout);
const cldnn::primitive_id mutable_id_1 = layerName + "_md_write1";
const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memory2};
const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memory1};
p.add_primitive(*op, mutable_prim_1);

inputs.push_back(cldnn::input_info(mutable_id_1));
Expand All @@ -289,8 +289,8 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), inputs[7].pid, inputs[8].pid, \
"", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, 0);
p.add_primitive(*op, prim);
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory2));
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory0));
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
}

REGISTER_FACTORY_IMPL(v4, LSTMCell);
Expand Down

0 comments on commit 939d23c

Please sign in to comment.