Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LPT] Refactoring: PoC #5226

Merged
merged 4 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
103 changes: 78 additions & 25 deletions inference-engine/src/cldnn_engine/cldnn_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,12 @@
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
#include <low_precision/pull_reshape_through_dequantization.hpp>
#include <low_precision/pull_transpose_through_dequantization.hpp>
#include <low_precision/transformer.hpp>
#include <low_precision/convolution.hpp>
#include <low_precision/convolution_backprop_data.hpp>
#include <low_precision/group_convolution.hpp>
#include <low_precision/low_precision.hpp>
#include <low_precision/mat_mul.hpp>
#include <low_precision/multiply_to_group_convolution.hpp>
#include <low_precision/strided_slice.hpp>
#include <low_precision/network_helper.hpp>

Expand Down Expand Up @@ -150,10 +153,12 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
OV_ITT_SCOPED_TASK(itt::domains::CLDNNPlugin, "clDNNEngine::TransformNetwork");
auto nGraphFunc = clonedNetwork.getFunction();

using const_node_ptr = const std::shared_ptr<const ngraph::Node>;

bool enableInt8;
{
ngraph::pass::Manager manager;
enableInt8 = config.enableInt8 && ngraph::pass::low_precision::LowPrecisionTransformer::isFunctionQuantized(nGraphFunc);
enableInt8 = config.enableInt8 && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(nGraphFunc);
if (enableInt8) {
manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>(
std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
Expand Down Expand Up @@ -207,8 +212,6 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc

auto pass_config = manager.get_pass_config();

using const_node_ptr = const std::shared_ptr<const ngraph::Node>;

// SpaceToDepth/DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
ngraph::pass::ConvertDepthToSpace>(
Expand Down Expand Up @@ -390,28 +393,78 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
if (!config.enable_fp16_for_quantized_models) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
}
auto lptPrerequisites = manager.register_pass<ngraph::pass::GraphRewrite>();
const std::vector<ngraph::element::Type> supportedTypes = { ngraph::element::i8, ngraph::element::u8 };
lptPrerequisites->add_matcher<PullReshapeThroughDequantization>(supportedTypes);
lptPrerequisites->add_matcher<PullTransposeThroughDequantization>(supportedTypes);
lptPrerequisites->add_matcher<ngraph::pass::LinOpSequenceFusion>();
manager.run_passes(nGraphFunc);

auto params = LayerTransformation::Params(true, // updatePrecisions
LayerTransformation::QuantizedTensorAlignment::UpdateLevel, // quantizedTensorAlignmentOnActivations
LayerTransformation::QuantizedTensorAlignment::None, // quantizedTensorAlignmentOnWeights
true); // supportAsymmetricQuantization
LowPrecisionTransformer transformer(LowPrecisionTransformer::getAllTransformations(params)
.add<MatMulTransformation, ngraph::opset1::MatMul>(LayerTransformation::Params(params)
.setSupportAsymmetricQuantization(false)
.setSupport3DTensorOnActivations(false))
.add<ConvolutionBackpropDataTransformation, ngraph::opset1::ConvolutionBackpropData>(LayerTransformation::Params(params)
.setSupportAsymmetricQuantization(false)
.setDeconvolutionSpecificChannelsRatio(true))
// INT8 StridedSlice not supported
.remove<StridedSliceTransformation, ngraph::opset1::StridedSlice>());

transformer.transform(nGraphFunc);
auto supportedPrecisions = std::vector<OperationPrecisionRestriction>({
OperationPrecisionRestriction::create<ngraph::opset1::Convolution>({
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to have a variadic template here to avoid code duplication, as in this case with the Convolution restrictions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main idea of OperationPrecisionRestriction::create static template method: to create restriction which is based on two values:

  1. input argument: std::vector with restrictions as input argument,
  2. type (operation type) which is used for template instantiation.
    So the two values have different types, there is no operation-type instance at runtime, and we only ever pass a single vector of restrictions. I therefore think a variadic template is not appropriate for this method.

{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}},
}),
OperationPrecisionRestriction::create<ngraph::opset1::ConvolutionBackpropData>({
{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}}
}),
OperationPrecisionRestriction::create<ngraph::opset1::GroupConvolution>({
{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}}
}),
OperationPrecisionRestriction::create<ngraph::opset1::StridedSlice>({})
});

auto perTensorQuantization = std::vector<OperationPerTensorQuantizationRestriction>({
OperationPerTensorQuantizationRestriction::create<ngraph::opset1::Convolution>({0}),
OperationPerTensorQuantizationRestriction::create<ngraph::opset1::ConvolutionBackpropData>({0}),
});

ngraph::pass::Manager lptManager;

auto lptPassConfig = lptManager.get_pass_config();
lptPassConfig->disable<ngraph::pass::low_precision::StridedSliceTransformation>();
lptPassConfig->set_callback<ngraph::pass::low_precision::MarkupPrecisions>([](const_node_ptr& node) -> bool {
if (const auto mulitply = std::dynamic_pointer_cast<const ngraph::opset1::Multiply>(node)) {
return !MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(mulitply);
}
return false;
});
lptPassConfig->set_callback<ConvolutionBackpropDataTransformation>([](const_node_ptr& node) -> bool {
auto fillStaticChannel = [](const ngraph::PartialShape& shape, size_t& channel) -> bool {
const auto rank = shape.rank();
if (rank.is_dynamic()) {
return false;
}
if (rank.get_length() < 2ul) {
return false;
}
const auto dimension = shape[1];
if (dimension.is_dynamic()) {
return false;
}
channel = dimension.get_length();
return true;
};

size_t inputChannels;
if (!fillStaticChannel(node->get_input_partial_shape(0), inputChannels)) {
return true;
}

size_t outputChannels;
if (!fillStaticChannel(node->get_output_partial_shape(0), outputChannels)) {
return true;
}


if ((inputChannels % 4 != 0) || (outputChannels % 16 != 0)) {
return true;
}

return LayerTransformation::isAsymmetricQuantization(node) || WeightableLayerTransformation::isAsymmetricOnWeights(node);
});
lptPassConfig->set_callback<MatMulTransformation>([](const_node_ptr& node) -> bool {
return MatMulTransformation::is3DTensorOnActivations(node);
});

lptManager.register_pass<LowPrecision>(supportedPrecisions, perTensorQuantization);
lptManager.run_passes(nGraphFunc);
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ ie_faster_build(${TARGET_NAME}
ie_add_vs_version_file(NAME ${TARGET_NAME}
FILEDESCRIPTION "Inference Engine LP transformations library")

target_compile_definitions(${TARGET_NAME} PRIVATE inference_engine_transformations_EXPORTS)

target_link_libraries(${TARGET_NAME} PUBLIC inference_engine_transformations
PRIVATE openvino::itt)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API AddTransformation : public EltwiseBaseTransformation {
class LP_TRANSFORMATIONS_API AddTransformation : public EltwiseBaseTransformation {
public:
AddTransformation(const Params& params) : EltwiseBaseTransformation(params) {}
~AddTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
NGRAPH_RTTI_DECLARATION;
AddTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/pass/pass.hpp>
#include "low_precision/lpt_visibility.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

// Forward declaration carries the export macro; the definition below is
// written with a qualified name outside the namespace blocks.
class LP_TRANSFORMATIONS_API AlignQuantizationIntervals;

} // namespace low_precision
} // namespace pass
} // namespace ngraph

// Function-level pass run over a whole ngraph::Function.
// Presumably aligns quantization intervals across operations as part of the
// low-precision (LPT) markup pipeline — confirm against the implementation,
// which is not visible here.
class ngraph::pass::low_precision::AlignQuantizationIntervals : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;
    // Returns true if the function was modified (ngraph FunctionPass contract).
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include <ngraph/pass/pass.hpp>
#include "low_precision/lpt_visibility.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

// Forward declaration carries the export macro; the definition below is
// written with a qualified name outside the namespace blocks.
class LP_TRANSFORMATIONS_API AlignQuantizationParameters;

} // namespace low_precision
} // namespace pass
} // namespace ngraph

// Function-level pass run over a whole ngraph::Function.
// Presumably aligns quantization parameters across operations as part of the
// low-precision (LPT) markup pipeline — confirm against the implementation,
// which is not visible here.
class ngraph::pass::low_precision::AlignQuantizationParameters : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;
    // Returns true if the function was modified (ngraph FunctionPass contract).
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API AvgPoolTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API AvgPoolTransformation : public LayerTransformation {
public:
AvgPoolTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
NGRAPH_RTTI_DECLARATION;
AvgPoolTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <ngraph/node.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
// Include the visibility macro explicitly instead of relying on a transitive
// include, matching the other LPT headers added in this change.
#include "low_precision/lpt_visibility.hpp"
#include "rt_info/attribute_parameters.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

// Forward declaration carries the export macro; repeating it on the qualified
// definition below is redundant (and warns on some compilers), so the macro
// appears only here — consistent with the sibling LPT pass headers.
class LP_TRANSFORMATIONS_API BaseMatcherPass;

} // namespace low_precision
} // namespace pass
} // namespace ngraph

// Base class for LPT matcher passes that share a common set of attribute
// parameters. Derived passes read the stored `params`; the defaulted
// constructor argument keeps the class default-constructible.
class ngraph::pass::low_precision::BaseMatcherPass : public ngraph::pass::MatcherPass {
public:
    BaseMatcherPass(const AttributeParameters& params = AttributeParameters());
    // Parameters applied by derived matcher passes.
    AttributeParameters params;
};
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API ClampTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API ClampTransformation : public LayerTransformation {
public:
ClampTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
NGRAPH_RTTI_DECLARATION;
ClampTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
#include <ngraph/check.hpp>
#include <ngraph/opsets/opset1.hpp>

#include "transformations_visibility.hpp"
#include "low_precision/lpt_visibility.hpp"
#include "transformations/rt_info/dequantization_attribute.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

// template<typename BaseOp2>
// class TRANSFORMATIONS_API DequantizationOp : public BaseOp2 {
// class LP_TRANSFORMATIONS_API DequantizationOp : public BaseOp2 {
// public:
// template <typename ... Args>
// DequantizationOp(Args&&... args) : BaseOp2(std::forward<Args>(args)...) {
Expand Down Expand Up @@ -63,7 +63,7 @@ void copyRuntimeInfo(const ngraph::Node& from, ngraph::Node& to) {

} // namespace

class TRANSFORMATIONS_API DequantizationConvert : public ngraph::opset1::Convert {
class LP_TRANSFORMATIONS_API DequantizationConvert : public ngraph::opset1::Convert {
public:
DequantizationConvert(const ngraph::Output<Node>& arg, const ngraph::element::Type& destination_type) :
ngraph::opset1::Convert(arg, destination_type) {
Expand All @@ -77,7 +77,7 @@ class TRANSFORMATIONS_API DequantizationConvert : public ngraph::opset1::Convert
}
};

class TRANSFORMATIONS_API DequantizationSubtract : public ngraph::opset1::Subtract {
class LP_TRANSFORMATIONS_API DequantizationSubtract : public ngraph::opset1::Subtract {
public:
DequantizationSubtract(
const ngraph::Output<Node>& arg0,
Expand All @@ -94,7 +94,7 @@ class TRANSFORMATIONS_API DequantizationSubtract : public ngraph::opset1::Subtra
}
};

class TRANSFORMATIONS_API DequantizationMultiply : public ngraph::opset1::Multiply {
class LP_TRANSFORMATIONS_API DequantizationMultiply : public ngraph::opset1::Multiply {
public:
DequantizationMultiply(
const Output<Node>& arg0,
Expand All @@ -116,7 +116,7 @@ class TRANSFORMATIONS_API DequantizationMultiply : public ngraph::opset1::Multip
}
};

class TRANSFORMATIONS_API DequantizationAdd : public ngraph::opset1::Add {
class LP_TRANSFORMATIONS_API DequantizationAdd : public ngraph::opset1::Add {
public:
DequantizationAdd(
const ngraph::Output<Node>& arg0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
#include <tuple>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <low_precision/lpt_visibility.hpp>

namespace ngraph {
namespace pass {
namespace low_precision {

typedef std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> FakeQuantizeDequantizationValues;

class FakeQuantizeDequantization {
class LP_TRANSFORMATIONS_API FakeQuantizeDequantization {
public:
FakeQuantizeDequantization();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <exception>
#include <string>
#include <ngraph/node.hpp>
#include <transformations_visibility.hpp>
#include <low_precision/lpt_visibility.hpp>

/**
* @def THROW_TRANSFORMATION_EXCEPTION_LPT
Expand All @@ -19,7 +19,7 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API Exception : std::exception {
class LP_TRANSFORMATIONS_API Exception : std::exception {
std::shared_ptr<std::ostringstream> buffer;
mutable std::string buffer_str;
public:
Expand All @@ -42,7 +42,7 @@ class TRANSFORMATIONS_API Exception : std::exception {
#define THROW_TRANSFORMATION_EXCEPTION throw ::ngraph::pass::low_precision::Exception() << __FILE__ << ":" << __LINE__ << " "


class TRANSFORMATIONS_API InferenceEngineLptException : public Exception {
class LP_TRANSFORMATIONS_API InferenceEngineLptException : public Exception {
public:
InferenceEngineLptException(const std::string& filename, const size_t line, const Node& node) {
*this
Expand Down
Loading