Skip to content

Commit

Permalink
#20927 support inputs that have no batch (#26778)
Browse files Browse the repository at this point in the history
#20927 
### Details:
 - *add batch dimension before pool*
 - *remove batch dimension after pool*
  • Loading branch information
tianyiSKY1 authored Oct 17, 2024
1 parent 55d8c47 commit 8822480
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 58 deletions.
71 changes: 66 additions & 5 deletions src/frontends/pytorch/src/op/avg_poolnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/avg_pool.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -17,10 +22,31 @@ namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_avg_poolnd(const NodeContext& context) {
OutputVector translate_avg_pool_base(const NodeContext& context, int dims) {
num_inputs_check(context, 2, 7);
auto input = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));

auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});
bool is_static = input.get_partial_shape().rank().is_static();
bool no_batch_dim = is_static && input.get_partial_shape().rank().get_length() == dims + 1;

if (is_static) {
if (no_batch_dim) {
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
}
} else {
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
auto unsqueeze_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
auto rank = context.mark_node(std::make_shared<v0::ShapeOf>(unsqueeze_shape));
auto end_index = context.mark_node(std::make_shared<v1::Add>(rank, const_1));
auto start_index = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims - 2}));
auto reshape_pattern =
context.mark_node(std::make_shared<v8::Slice>(unsqueeze_shape, start_index, end_index, const_1, const_0));
input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_pattern, true));
}

auto kernel = context.const_input<Shape>(1);
Strides strides;
if (!context.input_is_none(2)) {
Expand All @@ -47,8 +73,43 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
}
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6),
"Translation for aten::avg_pool2d do not support divisor_override input.");
return {context.mark_node(
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
auto res = context.mark_node(
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type));

if (is_static) {
if (no_batch_dim) {
res = context.mark_node(std::make_shared<v0::Squeeze>(res, const_0));
}
} else {
auto pooled_output_shape = context.mark_node(std::make_shared<v3::ShapeOf>(res));

auto start_index_input = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
auto slice_input_shape =
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, start_index_input, const_1, const_0));

auto start_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
auto end_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + dims}));
auto slice_pooled_output_shape = context.mark_node(
std::make_shared<v8::Slice>(pooled_output_shape, start_index_pooled, end_index_pooled, const_1, const_0));

auto concat_shape = context.mark_node(
std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
res = context.mark_node(std::make_shared<v1::Reshape>(res, concat_shape, true));
}

return {res};
};

// Converts aten::avg_pool1d by delegating to the shared average-pooling
// translator with a single spatial dimension.
OutputVector translate_avg_pool1d(const NodeContext& context) {
    constexpr int kSpatialDims = 1;
    return translate_avg_pool_base(context, kSpatialDims);
};

// Converts aten::avg_pool2d by delegating to the shared average-pooling
// translator with two spatial dimensions.
OutputVector translate_avg_pool2d(const NodeContext& context) {
    constexpr int kSpatialDims = 2;
    return translate_avg_pool_base(context, kSpatialDims);
};

// Converts aten::avg_pool3d by delegating to the shared average-pooling
// translator with three spatial dimensions.
OutputVector translate_avg_pool3d(const NodeContext& context) {
    constexpr int kSpatialDims = 3;
    return translate_avg_pool_base(context, kSpatialDims);
};

} // namespace op
Expand Down
105 changes: 94 additions & 11 deletions src/frontends/pytorch/src/op/max_poolnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
#include "openvino/op/multiply.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"

Expand All @@ -24,9 +28,31 @@ namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_max_poolnd(const NodeContext& context) {
OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
num_inputs_check(context, 3, 6);
auto input = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));

auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});
bool is_static = input.get_partial_shape().rank().is_static();
bool no_batch_dim = is_static && input.get_partial_shape().rank().get_length() == dims + 1;

if (is_static) {
if (no_batch_dim) {
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
}
} else {
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
auto unsqueeze_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
auto rank = context.mark_node(std::make_shared<v0::ShapeOf>(unsqueeze_shape));
auto end_index = context.mark_node(std::make_shared<v1::Add>(rank, const_1));
auto start_index = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims - 2}));
auto reshape_pattern =
context.mark_node(std::make_shared<v8::Slice>(unsqueeze_shape, start_index, end_index, const_1, const_0));
input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_pattern, true));
}

auto kernel = context.const_input<Shape>(1);
Strides strides;
if (!context.input_is_none(2)) {
Expand All @@ -53,7 +79,7 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL_TORCH : RoundingType::FLOOR;
}

auto res = context.mark_node(std::make_shared<v14::MaxPool>(context.get_input(0),
auto res = context.mark_node(std::make_shared<v14::MaxPool>(input,
strides,
dilations,
pads,
Expand All @@ -63,19 +89,76 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
PadType::EXPLICIT,
element::i64,
2));
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
return {std::move(out1), std::move(out2)};
if (is_static) {
if (no_batch_dim) {
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
out1 = context.mark_node(std::make_shared<v0::Squeeze>(out1, const_0));
out2 = context.mark_node(std::make_shared<v0::Squeeze>(out2, const_0));
return {std::move(out1), std::move(out2)};
} else {
res = context.mark_node(std::make_shared<v0::Squeeze>(res, const_0));
return {res};
}
} else {
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
return {std::move(out1), std::move(out2)};
} else {
return {res};
}
}

} else {
return {res};
auto pooled_output_shape = context.mark_node(std::make_shared<v3::ShapeOf>(res));

auto start_index_input = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
auto slice_input_shape =
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, start_index_input, const_1, const_0));

auto start_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
auto end_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + dims}));
auto slice_pooled_output_shape = context.mark_node(
std::make_shared<v8::Slice>(pooled_output_shape, start_index_pooled, end_index_pooled, const_1, const_0));

auto concat_shape = context.mark_node(
std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
out1 = context.mark_node(std::make_shared<v1::Reshape>(out1, concat_shape, true));
out2 = context.mark_node(std::make_shared<v1::Reshape>(out2, concat_shape, true));
return {std::move(out1), std::move(out2)};
} else {
res = context.mark_node(std::make_shared<v1::Reshape>(res, concat_shape, true));
return {res};
}
}
};

OutputVector translate_max_poolnd_fx(const NodeContext& context) {
auto output = translate_max_poolnd(context);
// Converts aten::max_pool1d (and the _with_indices variant) via the shared
// max-pooling translator, fixing the spatial rank to 1.
OutputVector translate_max_pool1d(const NodeContext& context) {
    constexpr int kSpatialDims = 1;
    return translate_max_pool_base(context, kSpatialDims);
};

// Converts aten::max_pool2d (and the _with_indices variant) via the shared
// max-pooling translator, fixing the spatial rank to 2.
OutputVector translate_max_pool2d(const NodeContext& context) {
    constexpr int kSpatialDims = 2;
    return translate_max_pool_base(context, kSpatialDims);
};

// Converts aten::max_pool3d (and the _with_indices variant) via the shared
// max-pooling translator, fixing the spatial rank to 3.
OutputVector translate_max_pool3d(const NodeContext& context) {
    constexpr int kSpatialDims = 3;
    return translate_max_pool_base(context, kSpatialDims);
};

// FX-graph flavor of max_pool2d: the pooled outputs are packed into a
// list construct, as FX consumers expect a single list-valued result.
OutputVector translate_max_pool2d_fx(const NodeContext& context) {
    return {context.mark_node(make_list_construct(translate_max_pool2d(context)))};
}
};

// FX-graph flavor of max_pool3d: the pooled outputs are packed into a
// list construct, as FX consumers expect a single list-valued result.
OutputVector translate_max_pool3d_fx(const NodeContext& context) {
    return {context.mark_node(make_list_construct(translate_max_pool3d(context)))};
};

} // namespace op
} // namespace pytorch
Expand Down
37 changes: 21 additions & 16 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argmin);
OP_CONVERTER(translate_as_strided);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_avg_pool1d);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_avg_pool3d);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_bitwise_and);
Expand Down Expand Up @@ -139,7 +141,9 @@ OP_CONVERTER(translate_masked_scatter);
OP_CONVERTER(translate_masked_select);
OP_CONVERTER(translate_max);
OP_CONVERTER(translate_maximum);
OP_CONVERTER(translate_max_poolnd);
OP_CONVERTER(translate_max_pool1d);
OP_CONVERTER(translate_max_pool2d);
OP_CONVERTER(translate_max_pool3d);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
Expand Down Expand Up @@ -281,7 +285,8 @@ OP_CONVERTER(translate_leaky_relu_fx);
OP_CONVERTER(translate_log_sigmoid_fx);
OP_CONVERTER(translate_log_softmax_fx);
OP_CONVERTER(translate_max_dim_fx);
OP_CONVERTER(translate_max_poolnd_fx);
OP_CONVERTER(translate_max_pool2d_fx);
OP_CONVERTER(translate_max_pool3d_fx);
OP_CONVERTER(translate_mean_fx);
OP_CONVERTER(translate_min_dim_fx);
OP_CONVERTER(translate_new_full_fx);
Expand Down Expand Up @@ -380,9 +385,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::atanh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_pool1d>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_pool2d>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_pool3d>},
{"aten::baddbmm", op::translate_addmm},
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bitwise_and", op::translate_bitwise_and},
Expand Down Expand Up @@ -534,12 +539,12 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::max", op::translate_max},
{"aten::mv", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::maximum", op::translate_maximum},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_pool1d>},
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_pool1d>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_pool2d>},
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_pool2d>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_pool3d>},
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_pool3d>},
{"aten::mean", op::quantizable_op<op::translate_mean>},
{"aten::meshgrid", op::translate_meshgrid},
{"aten::min", op::translate_min},
Expand Down Expand Up @@ -771,8 +776,8 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
{"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
{"aten.atanh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
{"aten.avg_pool2d.default", op::translate_avg_poolnd},
{"aten.avg_pool3d.default", op::translate_avg_poolnd},
{"aten.avg_pool2d.default", op::translate_avg_pool2d},
{"aten.avg_pool3d.default", op::translate_avg_pool3d},
{"aten.baddbmm.default", op::translate_addmm_fx},
{"aten.bitwise_and.Scalar", op::translate_bitwise_and},
{"aten.bitwise_and.Tensor", op::translate_bitwise_and},
Expand Down Expand Up @@ -870,8 +875,8 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.masked_fill_.Tensor", op::inplace_op<op::translate_masked_fill>},
{"aten.max.default", op::translate_max},
{"aten.max.dim", op::translate_max_dim_fx},
{"aten.max_pool2d_with_indices.default", op::translate_max_poolnd_fx},
{"aten.max_pool3d_with_indices.default", op::translate_max_poolnd_fx},
{"aten.max_pool2d_with_indices.default", op::translate_max_pool2d_fx},
{"aten.max_pool3d_with_indices.default", op::translate_max_pool3d_fx},
{"aten.maximum.default", op::translate_maximum},
{"aten.mean.default", op::translate_mean_fx},
{"aten.mean.dim", op::translate_mean_fx},
Expand Down
Loading

0 comments on commit 8822480

Please sign in to comment.