Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#20927 support inputs that have no batch #26778

Merged
merged 10 commits into from
Oct 17, 2024
54 changes: 49 additions & 5 deletions src/frontends/pytorch/src/op/avg_poolnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,39 @@
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/add.hpp"
#include "utils.hpp"


namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_avg_poolnd(const NodeContext& context) {
OutputVector translate_avg_pool_base(const NodeContext& context, int dims) {
num_inputs_check(context, 2, 7);
auto input = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));

auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});

// Unsqueeze axis to add batch dimension
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
tianyiSKY1 marked this conversation as resolved.
Show resolved Hide resolved

// Reshape pattern based on dimensions
auto unsqueeze_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
auto rank = context.mark_node(std::make_shared<v0::ShapeOf>(unsqueeze_shape));
auto end_index = context.mark_node(std::make_shared<v1::Add>(rank, const_1));
auto start_index = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims - 2}));
auto reshape_pattern = context.mark_node(std::make_shared<v8::Slice>(unsqueeze_shape, start_index, end_index, const_1, const_0));
input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_pattern, true));

auto kernel = context.const_input<Shape>(1);
Strides strides;
if (!context.input_is_none(2)) {
Expand All @@ -47,10 +67,34 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
}
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6),
"Translation for aten::avg_pool2d do not support divisor_override input.");
return {context.mark_node(
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
auto res = context.mark_node(
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type));

// Reshape back to original shape
auto pooled_output_shape = context.mark_node(std::make_shared<v3::ShapeOf>(res));
auto slice_input_shape = context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0,
context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims})), const_1, const_0));
auto slice_pooled_output_shape = context.mark_node(std::make_shared<v8::Slice>(pooled_output_shape, context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims})),
context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + dims})), const_1, const_0));
auto concat_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
res = context.mark_node(std::make_shared<v1::Reshape>(res, concat_shape, true));
return {res};

};

// Translate aten::avg_pool1d: delegate to the shared base with 1 spatial dimension.
OutputVector translate_avg_pool1d(const NodeContext& context) {
    return translate_avg_pool_base(context, 1);
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi

// Translate aten::avg_pool2d: delegate to the shared base with 2 spatial dimensions.
OutputVector translate_avg_pool2d(const NodeContext& context) {
    return translate_avg_pool_base(context, 2);
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi

// Translate aten::avg_pool3d: delegate to the shared base with 3 spatial dimensions.
OutputVector translate_avg_pool3d(const NodeContext& context) {
    return translate_avg_pool_base(context, 3);
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi


} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
60 changes: 54 additions & 6 deletions src/frontends/pytorch/src/op/max_poolnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"

Expand All @@ -24,9 +28,25 @@ namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_max_poolnd(const NodeContext& context) {
OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
num_inputs_check(context, 3, 6);
auto input = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));

auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});

// Unsqueeze axis to add batch dimension
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));

// Reshape pattern based on dimensions
auto unsqueeze_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
auto rank = context.mark_node(std::make_shared<v0::ShapeOf>(unsqueeze_shape));
auto end_index = context.mark_node(std::make_shared<v1::Add>(rank, const_1));
auto start_index = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims - 2}));
auto reshape_pattern = context.mark_node(std::make_shared<v8::Slice>(unsqueeze_shape, start_index, end_index, const_1, const_0));
input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_pattern, true));

auto kernel = context.const_input<Shape>(1);
Strides strides;
if (!context.input_is_none(2)) {
Expand All @@ -53,7 +73,7 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL_TORCH : RoundingType::FLOOR;
}

auto res = context.mark_node(std::make_shared<v14::MaxPool>(context.get_input(0),
auto res = context.mark_node(std::make_shared<v14::MaxPool>(input,
strides,
dilations,
pads,
Expand All @@ -63,19 +83,47 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
PadType::EXPLICIT,
element::i64,
2));

// Reshape back to original shape
auto pooled_output_shape = context.mark_node(std::make_shared<v3::ShapeOf>(res));
auto slice_input_shape = context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0,
context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims})), const_1, const_0));
auto slice_pooled_output_shape = context.mark_node(std::make_shared<v8::Slice>(pooled_output_shape, context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims})),
context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + dims})), const_1, const_0));
auto concat_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
out1 = context.mark_node(std::make_shared<v1::Reshape>(out1, concat_shape, true));
out2 = context.mark_node(std::make_shared<v1::Reshape>(out2, concat_shape, true));
return {std::move(out1), std::move(out2)};
} else {
res = context.mark_node(std::make_shared<v1::Reshape>(res, concat_shape, true));
return {res};
}
};

OutputVector translate_max_poolnd_fx(const NodeContext& context) {
auto output = translate_max_poolnd(context);
// Translate aten::max_pool1d / max_pool1d_with_indices: delegate to the shared base
// with 1 spatial dimension.
OutputVector translate_max_pool1d(const NodeContext& context) {
    return translate_max_pool_base(context, 1);
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi

// Translate aten::max_pool2d / max_pool2d_with_indices: delegate to the shared base
// with 2 spatial dimensions.
OutputVector translate_max_pool2d(const NodeContext& context) {
    return translate_max_pool_base(context, 2);
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi

// Translate aten::max_pool3d / max_pool3d_with_indices: delegate to the shared base
// with 3 spatial dimensions.
OutputVector translate_max_pool3d(const NodeContext& context) {
    return translate_max_pool_base(context, 3);
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi

// FX-graph variant of aten.max_pool2d_with_indices: the pooled outputs are wrapped
// into a single list-construct node, as the FX decomposition expects a list result.
OutputVector translate_max_pool2d_fx(const NodeContext& context) {
    return {context.mark_node(make_list_construct(translate_max_pool2d(context)))};
}
};

// FX-graph variant of aten.max_pool3d_with_indices: the pooled outputs are wrapped
// into a single list-construct node, as the FX decomposition expects a list result.
OutputVector translate_max_pool3d_fx(const NodeContext& context) {
    auto output = translate_max_pool3d(context);
    return {context.mark_node(make_list_construct(output))};
}  // note: no trailing ';' — an extra semicolon after a function body is flagged by -Wextra-semi

} // namespace op
} // namespace pytorch
Expand Down
37 changes: 21 additions & 16 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argmin);
OP_CONVERTER(translate_as_strided);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_avg_pool1d);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_avg_pool3d);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_bitwise_and);
Expand Down Expand Up @@ -139,7 +141,9 @@ OP_CONVERTER(translate_masked_scatter);
OP_CONVERTER(translate_masked_select);
OP_CONVERTER(translate_max);
OP_CONVERTER(translate_maximum);
OP_CONVERTER(translate_max_poolnd);
OP_CONVERTER(translate_max_pool1d);
OP_CONVERTER(translate_max_pool2d);
OP_CONVERTER(translate_max_pool3d);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
Expand Down Expand Up @@ -281,7 +285,8 @@ OP_CONVERTER(translate_leaky_relu_fx);
OP_CONVERTER(translate_log_sigmoid_fx);
OP_CONVERTER(translate_log_softmax_fx);
OP_CONVERTER(translate_max_dim_fx);
OP_CONVERTER(translate_max_poolnd_fx);
OP_CONVERTER(translate_max_pool2d_fx);
OP_CONVERTER(translate_max_pool3d_fx);
OP_CONVERTER(translate_mean_fx);
OP_CONVERTER(translate_min_dim_fx);
OP_CONVERTER(translate_new_full_fx);
Expand Down Expand Up @@ -380,9 +385,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::atanh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_pool1d>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_pool2d>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_pool3d>},
{"aten::baddbmm", op::translate_addmm},
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bitwise_and", op::translate_bitwise_and},
Expand Down Expand Up @@ -534,12 +539,12 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::max", op::translate_max},
{"aten::mv", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::maximum", op::translate_maximum},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_pool1d>},
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_pool1d>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_pool2d>},
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_pool2d>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_pool3d>},
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_pool3d>},
{"aten::mean", op::quantizable_op<op::translate_mean>},
{"aten::meshgrid", op::translate_meshgrid},
{"aten::min", op::translate_min},
Expand Down Expand Up @@ -768,8 +773,8 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
{"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
{"aten.atanh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
{"aten.avg_pool2d.default", op::translate_avg_poolnd},
{"aten.avg_pool3d.default", op::translate_avg_poolnd},
{"aten.avg_pool2d.default", op::translate_avg_pool2d},
{"aten.avg_pool3d.default", op::translate_avg_pool3d},
{"aten.baddbmm.default", op::translate_addmm_fx},
{"aten.bitwise_and.Scalar", op::translate_bitwise_and},
{"aten.bitwise_and.Tensor", op::translate_bitwise_and},
Expand Down Expand Up @@ -866,8 +871,8 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.masked_fill_.Tensor", op::inplace_op<op::translate_masked_fill>},
{"aten.max.default", op::translate_max},
{"aten.max.dim", op::translate_max_dim_fx},
{"aten.max_pool2d_with_indices.default", op::translate_max_poolnd_fx},
{"aten.max_pool3d_with_indices.default", op::translate_max_poolnd_fx},
{"aten.max_pool2d_with_indices.default", op::translate_max_pool2d_fx},
{"aten.max_pool3d_with_indices.default", op::translate_max_pool3d_fx},
{"aten.maximum.default", op::translate_maximum},
{"aten.mean.default", op::translate_mean_fx},
{"aten.mean.dim", op::translate_mean_fx},
Expand Down
Loading
Loading