Skip to content

Commit

Permalink
only 2 outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Jul 30, 2024
1 parent fd5f3dc commit 03cbf57
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 24 deletions.
11 changes: 5 additions & 6 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
const lstm_weights_order offset_order = lstm_weights_order::iofz,
const uint32_t direction = 0,
const padding& output_padding = padding())
: primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B}, {output_padding}, {}, 3),
out1_prim_id(out1_prim_id),
: primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B}, {output_padding}, {}, 2),
out2_prim_id(out2_prim_id),
cell(cell),
clip(clip),
Expand All @@ -195,7 +194,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
direction(direction) {}

/// @brief Primitive id containing the initial value of the cell state data.
primitive_id out1_prim_id;
//primitive_id out1_prim_id;
primitive_id out2_prim_id;
primitive_id cell;
/// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
Expand Down Expand Up @@ -251,7 +250,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {

void save(BinaryOutputBuffer& ob) const override {
primitive_base<lstm_seq>::save(ob);
ob << out1_prim_id;
//ob << out1_prim_id;
ob << out2_prim_id;
ob << cell;
ob << clip;
Expand All @@ -264,7 +263,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {

void load(BinaryInputBuffer& ib) override {
primitive_base<lstm_seq>::load(ib);
ib >> out1_prim_id;
//ib >> out1_prim_id;
ib >> out2_prim_id;
ib >> cell;
ib >> clip;
Expand All @@ -282,7 +281,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
if (!cell.empty())
ret.push_back(cell);
*/
ret.push_back(out1_prim_id);
//ret.push_back(out1_prim_id);
ret.push_back(out2_prim_id);
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ KERNEL(lstm_seq)(
const __global INPUT5_TYPE* R,
const __global INPUT6_TYPE* B,
__global OUTPUT_TYPE* hidden_history,
__global OUTPUT1_TYPE* hidden_state,
__global OUTPUT2_TYPE* cell_state
)
{
Expand Down Expand Up @@ -135,7 +134,7 @@ KERNEL(lstm_seq)(
//hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]
local_hidden_state = gate_output[3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
//printf("hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] is %f on b %d\n", hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], b);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, i, 0, 0)] = local_hidden_state;
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, 0)] = local_hidden_state;
printf("hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] is %f\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)]);
}
//printf("cell state for %d is %f \n", OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ KernelsData LSTMSeqKernelBase::GetCommonKernelsData(const Params& params) const
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 6});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 1});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 2});
//kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 2});
auto cldnnJit = GetJitConstants(orgParams);
auto entryPoint = GetEntryPoint(kernelName, orgParams.layerID, params);
auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
Expand Down
21 changes: 6 additions & 15 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,6 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op

cldnn::primitive_id lstm_seq_id = layerName;// + "_lstm_seq";

auto mutable_precision_first = op->get_output_element_type(1);
cldnn::layout out1Layout = cldnn::layout(
cldnn::element_type_to_data_type(mutable_precision_first),
cldnn::format::bfyx,
tensor_from_dims(op->get_output_shape(1)));
cldnn::memory::ptr shared_memory0 = p.get_engine().allocate_memory(out1Layout);
const cldnn::primitive_id mutable_id_0 = layerName + "_md_write0";
const cldnn::mutable_data mutable_prim_0{mutable_id_0, shared_memory0};
p.add_primitive(*op, mutable_prim_0);

inputs.push_back(cldnn::input_info(mutable_id_0));
auto f_id = inputs.back().pid;

auto mutable_precision_second = op->get_output_element_type(2);
cldnn::layout out2Layout = cldnn::layout(
cldnn::element_type_to_data_type(mutable_precision_second),
Expand All @@ -288,11 +275,15 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
cldnn::lstm_seq prim(lstm_seq_id + ".out0", inputs[0], inputs[1], \
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), "", "", \
"", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, 0);
prim.out1_prim_id = f_id;
//prim.out1_prim_id = f_id;
prim.out2_prim_id = s_id;
p.add_primitive(*op, prim);
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory0));
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
int b = op->get_input_shape(0)[0];
int seqlen = op->get_input_shape(0)[1];
int hidden_size = op->get_input_shape(1)[2];
p.add_primitive(*op, cldnn::crop(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, \
cldnn::tensor{ b, 1, seqlen-1, hidden_size}, cldnn::tensor{ 0, 0, 0, 0 }));
}

REGISTER_FACTORY_IMPL(v4, LSTMCell);
Expand Down

0 comments on commit 03cbf57

Please sign in to comment.