Commit 8b2c049: wip
michal-miotk committed Aug 6, 2024
1 parent 886b412 commit 8b2c049
Showing 5 changed files with 11 additions and 24 deletions.
@@ -215,7 +215,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
               const uint32_t direction = 0,
               const padding& output_padding = padding(),
               const int num_outputs = 1)
-        : primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B, out1_prim_id, out2_prim_id}, {output_padding}, {}, \
+        : primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B}, {output_padding}, {}, \
               num_outputs),
           cell(cell),
           clip(clip),
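What the header change amounts to: out1_prim_id and out2_prim_id disappear from the dependency list passed to primitive_base, so the extra results are no longer wired in as pseudo-inputs and the primitive relies on num_outputs instead. A toy sketch under that reading (hypothetical ToyPrimitive type and input names, not the real cldnn API):

```cpp
#include <string>
#include <vector>

// Toy stand-in for a primitive with an explicit dependency list; the real
// type here is cldnn::primitive_base, whose constructor takes the inputs.
struct ToyPrimitive {
    std::string id;
    std::vector<std::string> inputs;
    int num_outputs;
};

int main() {
    // Old style: the two result buffers rode along as trailing "inputs".
    ToyPrimitive legacy{"lstm_seq",
                        {"x", "h0", "c0", "seq_len", "W", "R", "B", "out1", "out2"},
                        1};
    // New style: only real data dependencies remain; the extra hidden- and
    // cell-state results are expressed through num_outputs (3 for
    // LSTMSequence: Y, Ho, Co) instead.
    ToyPrimitive multi_output{"lstm_seq",
                              {"x", "h0", "c0", "seq_len", "W", "R", "B"},
                              3};
    (void)legacy;
    (void)multi_output;
    return 0;
}
```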
@@ -464,7 +464,7 @@ bool crop_in_place_optimization::match(const program_node& node,
         return false;
     // if the node is marked as network output, prevent optimizations which would affect a form of its output,
     // unless debug flag is set
-    if (node.is_output() || crop_params.fused_desc.size() > 0 || node.is_in_shape_of_subgraph() || node.is_type<mutable_data>())
+    if (node.is_output() || crop_params.fused_desc.size() > 0 || node.is_in_shape_of_subgraph())
         return false;

     const auto& crop_layout = crop_params.get_output_layout();
@@ -476,7 +476,7 @@ bool crop_in_place_optimization::match(const program_node& node,
         // do not optimize when next node is concatenation which is not output
         if (user->is_type<concatenation>() && !user->is_output())
             return false;
-        if (user->is_type<loop>() || user->is_type<non_max_suppression>() || user->is_type<mutable_data>())
+        if (user->is_type<loop>() || user->is_type<non_max_suppression>())
             return false;
         // If the input tensor of convolution includes dynamic padding, there is an issue
         // where the total size of tensor is not properly calculated and becomes 0
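Presumably because lstm_seq's extra outputs are no longer materialized as mutable_data nodes (see the header change above), the optimizer drops both mutable_data special cases. A condensed sketch of the resulting predicate, using a toy Node with boolean flags in place of the real virtual calls on cldnn::program_node:

```cpp
#include <vector>

// Toy node with boolean flags; the real checks are calls like
// node.is_output() and user->is_type<loop>() on cldnn::program_node.
struct Node {
    bool is_output = false;
    bool has_fused_ops = false;
    bool in_shape_of_subgraph = false;
    bool is_loop = false;
    bool is_nms = false;
    bool is_concat = false;
};

// Condensed shape of the updated predicate: the mutable_data case is gone
// from both the node check and the per-user check.
bool crop_can_run_in_place(const Node& node, const std::vector<Node>& users) {
    if (node.is_output || node.has_fused_ops || node.in_shape_of_subgraph)
        return false;
    for (const auto& user : users) {
        // do not optimize when the next node is a concatenation which is not an output
        if (user.is_concat && !user.is_output)
            return false;
        if (user.is_loop || user.is_nms)
            return false;
    }
    return true;
}

int main() {
    Node crop;                   // plain crop node, no blocking properties
    std::vector<Node> users(1);  // one unremarkable user
    return crop_can_run_in_place(crop, users) ? 0 : 1;
}
```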
8 changes: 4 additions & 4 deletions src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp
@@ -77,10 +77,10 @@ struct lstm_seq_impl : typed_primitive_impl_ocl<lstm_seq> {
         params.input_forget = primitive->input_forget;
         params.direction = primitive->direction;
         //Legacy multi-output
-        if (impl_param.input_layouts.size() > 7) {
-            params.outputs.push_back(convert_data_tensor(impl_param.input_layouts[1]));
-            params.outputs.push_back(convert_data_tensor(impl_param.input_layouts[1]));
-        }
+        params.outputs.push_back(convert_data_tensor(impl_param.input_layouts[1]));
+        params.outputs.push_back(convert_data_tensor(impl_param.input_layouts[1]));

         return params;
     }
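With the legacy guard removed, the parameter-filling code always registers the two extra outputs, both taking the layout of input 1 (the initial hidden state; in an LSTM the hidden and cell states share a shape). A minimal before/after sketch with a toy Tensor type standing in for the kernel-selector data tensor:

```cpp
#include <vector>

struct Tensor {};  // toy stand-in for the kernel-selector data tensor

// Before: the two extra outputs were only registered when the primitive
// carried the legacy out1/out2 pseudo-inputs (more than 7 input layouts).
void fill_outputs_legacy(std::vector<Tensor>& outputs,
                         const std::vector<Tensor>& input_layouts) {
    if (input_layouts.size() > 7) {
        outputs.push_back(input_layouts[1]);
        outputs.push_back(input_layouts[1]);
    }
}

// After: lstm_seq is always multi-output, so both extra outputs (shaped
// like input 1, the initial hidden state) are registered unconditionally.
void fill_outputs(std::vector<Tensor>& outputs,
                  const std::vector<Tensor>& input_layouts) {
    outputs.push_back(input_layouts[1]);
    outputs.push_back(input_layouts[1]);
}

int main() {
    std::vector<Tensor> inputs(7), legacy_outs, outs;
    fill_outputs_legacy(legacy_outs, inputs);  // legacy_outs stays empty
    fill_outputs(outs, inputs);                // outs.size() == 2
    return 0;
}
```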
19 changes: 3 additions & 16 deletions src/plugins/intel_gpu/src/graph/lstm_seq.cpp
@@ -25,37 +25,24 @@ std::vector<layout> lstm_seq_inst::calc_output_layouts(lstm_seq_node const& node
         input_shapes.push_back(input_shape);
     }

-    ov::op::v5::LSTMSequence op;
-
-    std::vector<ShapeType> output_shapes = ov::op::v5::shape_infer(&op, input_shapes);

     // input partial shape [batch, input_size (= hidden_size * 4)]
     auto input_layout_x = impl_param.get_input_layout(0);
     auto input_pshape_x = input_layout_x.get_partial_shape();
     auto input_layout_hidden = impl_param.get_input_layout(1);
     auto input_pshape_hidden = input_layout_hidden.get_partial_shape();
-    /*
-    if (impl_param.desc->output_data_types.size() > 0) {
-        OPENVINO_ASSERT(static_cast<bool>(impl_param.desc->output_data_types[0]) == false, "Output data type forcing is not supported for lstm_seq_node!");
-    }
-    */
     if (input_pshape_x.is_static()) {
         OPENVINO_ASSERT(input_pshape_x.rank().get_length() == 4, "input_layout rank should be 4 on static shape.");
     }
     int lstm_batch_size, lstm_seq_length, lstm_hidden_size;
-    if (input_pshape_x[input_pshape_x.size() - 3].is_static()) {
+    if (input_pshape_x[0].is_static()) {
         lstm_batch_size = input_pshape_x[0].get_length();
     } else {
         lstm_batch_size = -1;
     }

-    if (input_pshape_x[input_pshape_x.size() - 2].is_static()) {
+    if (input_pshape_x[1].is_static()) {
         lstm_seq_length = input_pshape_x[1].get_length();
     } else {
         lstm_seq_length = -1;
     }

-    if (input_pshape_hidden[input_pshape_hidden.size() - 1].is_static()) {
+    if (input_pshape_hidden[2].is_static()) {
         lstm_hidden_size = input_pshape_hidden[2].get_length();
     } else {
         lstm_hidden_size = -1;
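The change above swaps rank-relative indexing (input_pshape_x[size() - 3] and friends) for fixed positions: batch at x[0], sequence length at x[1], hidden size at hidden[2], recording -1 for any dynamic dimension. It also makes each checked index match the index actually read, which the old code did not. A self-contained sketch of that pattern, with a toy Dim type standing in for ov::Dimension:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for ov::Dimension: value >= 0 is static, negative is dynamic.
struct Dim {
    int64_t value;
    bool is_static() const { return value >= 0; }
    int64_t get_length() const { return value; }
};

// The pattern used by the new code: read the dimension at a fixed index and
// record -1 when it is dynamic (unknown until runtime).
int64_t dim_or_dynamic(const std::vector<Dim>& pshape, size_t idx) {
    return pshape[idx].is_static() ? pshape[idx].get_length() : -1;
}

int main() {
    std::vector<Dim> x_pshape      = {{8}, {-1}, {16}};  // [batch, seq_len, input_size]
    std::vector<Dim> hidden_pshape = {{8}, {1}, {32}};   // [batch, num_dirs, hidden_size]

    int64_t lstm_batch_size  = dim_or_dynamic(x_pshape, 0);       // 8
    int64_t lstm_seq_length  = dim_or_dynamic(x_pshape, 1);       // -1 (dynamic)
    int64_t lstm_hidden_size = dim_or_dynamic(hidden_pshape, 2);  // 32

    std::cout << lstm_batch_size << ' ' << lstm_seq_length << ' '
              << lstm_hidden_size << '\n';
    return 0;
}
```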
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
@@ -247,7 +247,7 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
     int direction = op->get_direction() == ov::op::RecurrentSequenceDirection::REVERSE ? 1 : 0;
     cldnn::lstm_seq prim(layerName, inputs[0], inputs[1], \
                          inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), \
-                         "", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, direction, cldnn::padding(), op->get_output_size());
+                         "", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, direction, cldnn::padding(), op->get_output_size() );
     prim.output_paddings = get_output_paddings(op);
     prim.output_data_types = get_output_data_types(op);
     p.add_primitive(*op, prim);
