Skip to content

Commit

Permalink
Fix Pruning for case with INT8 GroupConvolution operation (openvinoto…
Browse files Browse the repository at this point in the history
  • Loading branch information
Gleb Kazantaev authored and andrei-cv committed Aug 30, 2021
1 parent b5fe35a commit 824ffaf
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ namespace mask_propagation {

class Convolution;
class GroupConvolution;
class GroupConvolutionReshape;
class Elementwise;
class PassThrough;
class StopPropagation;
class FakeQuantize;
class Concat;
class Reshape;

} // namespace mask_propagation
} // namespace pass
Expand Down Expand Up @@ -192,9 +192,9 @@ class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass {
}
};

class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
class ngraph::pass::mask_propagation::GroupConvolutionReshape : public MatcherPass {
public:
Reshape() {
GroupConvolutionReshape() {
auto input = pattern::any_input(pattern::has_static_shape());
auto shape = pattern::any_input();
// Working only for Reshapes on Group Convolution weights
Expand Down Expand Up @@ -258,10 +258,12 @@ class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
ngraph::replace_node(old_shape_const, new_const);

setMask(m_output, output_mask);
return true;
// This transformation propagates only Reshape mask and doesn't do anything with GroupConvolution.
// So, not to disable GroupConvolution mask propagation we return false here.
return false;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "ReshapeMaskPropagation");
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ReshapeMaskPropagation");
register_matcher(m, callback);
}
};
Expand Down Expand Up @@ -604,11 +606,11 @@ class ngraph::pass::mask_propagation::StopPropagation : public MatcherPass {

ngraph::pass::PropagateMasks::PropagateMasks() {
add_matcher<mask_propagation::Convolution>();
add_matcher<mask_propagation::GroupConvolutionReshape>();
add_matcher<mask_propagation::GroupConvolution>();
add_matcher<mask_propagation::Elementwise>();
add_matcher<mask_propagation::PassThrough>();
add_matcher<mask_propagation::FakeQuantize>();
add_matcher<mask_propagation::Concat>();
add_matcher<mask_propagation::Reshape>();
add_matcher<mask_propagation::StopPropagation>();
}

0 comments on commit 824ffaf

Please sign in to comment.