Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Apr 14, 2021
1 parent 44ba180 commit a08b656
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ class TRANSFORMATIONS_API LowPrecision;
} // namespace pass
} // namespace ngraph

class ngraph::pass::low_precision::LowPrecision: public ngraph::pass::FunctionPass {
class ngraph::pass::low_precision::LowPrecision : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
LowPrecision(
const std::vector<OperationPrecisionRestriction>& restrictions = {},
// TODO: debug only
const LayerTransformation::Params = LayerTransformation::Params());
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
#pragma once

#include <memory>
#include <set>
#include <unordered_set>
#include <vector>
#include <string>

#include <ngraph/node.hpp>
#include <ngraph/variant.hpp>
Expand Down Expand Up @@ -40,9 +38,7 @@ class TRANSFORMATIONS_API ngraph::VariantWrapper<std::shared_ptr<IntervalsAlignm

std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes) override;

std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node) override;

std::shared_ptr<IntervalsAlignmentAttribute> get() { return this->m_value; };
std::shared_ptr<IntervalsAlignmentAttribute> get() const { return this->m_value; };

std::string get_string() override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,8 @@ class TRANSFORMATIONS_API ngraph::VariantWrapper<PrecisionPreservedAttribute> :

VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}

// TODO: not completed for several branches
std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes) override;

std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node) override;

PrecisionPreservedAttribute get() { return this->m_value; };

std::string get_string() override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::LowPrecision, "LowPrecision"
ngraph::pass::low_precision::LowPrecision::LowPrecision(
const std::vector<OperationPrecisionRestriction>& restrictions,
const LayerTransformation::Params params) : restrictions(restrictions), params(params){
//
}

bool ngraph::pass::low_precision::LowPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
// TODO: to debug only
TransformationContext context(f);
// TransformationContext context(f);

// pass config should be reused
const std::vector<ngraph::element::Type> supportedTypes = { ngraph::element::i8, ngraph::element::u8 };
Expand All @@ -81,7 +80,6 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_function(std::shared_ptr<
manager.register_pass<ngraph::pass::low_precision::PropagatePrecisions>();
manager.register_pass<ngraph::pass::low_precision::AlignConcatQuantizationParamters>();


//{
// // TODO: just to DEBUG: use the same manager
// ngraph::pass::Manager manager1;
Expand Down Expand Up @@ -120,12 +118,15 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_function(std::shared_ptr<
common->add_matcher<ngraph::pass::low_precision::ReluTransformation>();
common->add_matcher<ngraph::pass::low_precision::ReshapeTransformation>();
common->add_matcher<ngraph::pass::low_precision::SqueezeTransformation>();
//common->add_matcher<ngraph::pass::low_precision::SplitTransformation>();
//common->add_matcher<ngraph::pass::low_precision::StridedSliceTransformation>();
common->add_matcher<ngraph::pass::low_precision::TransposeTransformation>();
common->add_matcher<ngraph::pass::low_precision::UnsqueezeTransformation>();

//cleanupStep4.register_pass<ngraph::pass::low_precision::FuseConvertTransformation, opset1::Multiply>(params);
//common->add_matcher<ngraph::pass::low_precision::VariadicSplit>();

std::shared_ptr<ngraph::pass::GraphRewrite> cleanup = manager.register_pass<ngraph::pass::GraphRewrite>();
//cleanup->add_matcher<ngraph::pass::low_precision::FoldConvertTransformation>();
//cleanup->add_matcher<ngraph::pass::low_precision::FuseConvertTransformation>();
cleanup->add_matcher<ngraph::pass::low_precision::FakeQuantizeTransformation>();
cleanup->add_matcher<ngraph::pass::low_precision::FuseSubtractToFakeQuantizeTransformation>();
cleanup->add_matcher<ngraph::pass::low_precision::FuseMultiplyToFakeQuantizeTransformation>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ bool ngraph::pass::low_precision::MarkupPrecisions::run_on_function(std::shared_
setRestriction(node, precisionsByPort);
}
}


//Output<Node> output = node->output(0);
}
return true;
}
Expand All @@ -127,14 +124,33 @@ bool ngraph::pass::low_precision::MarkupPrecisions::isDisabled(const std::shared
return false;
}

template <class Operation>
std::string name() {
return Operation::get_type_info_static().name;
}

bool ngraph::pass::low_precision::MarkupPrecisions::isPrecisionPreserved(const std::shared_ptr<Node>& node) {
if (isDisabled(node)) {
return false;
}

// TODO: think how to handle conditions <= not mandatory for PoC
// TODO: operation set version is not affected <= not mandatory for PoC
static std::unordered_set<std::string> precisionPreserved = {
{ "Concat" },
{ "MaxPool" }
{ name<opset1::Concat>() },
// TODO: there are conditions
{ name<opset1::DepthToSpace>() },
{ name<opset1::MaxPool>() },
// TODO: there are conditions
{ name<opset1::Relu>() },
// TODO: there are conditions
{ name<opset1::Reshape>() },
{ name<opset1::Squeeze>() },
{ name<opset1::Split>() },
{ name<opset1::StridedSlice>() },
{ name<opset1::Transpose>() },
{ name<opset1::Unsqueeze>() },
{ name<opset1::VariadicSplit>() }
};

return precisionPreserved.find(node->get_type_name()) != precisionPreserved.end();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,64 +55,6 @@ std::vector<std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAtt
return parentAttributes;
}

std::vector<std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>> getChildrenInputRestrictions(const std::shared_ptr<ngraph::Node> node) {
std::vector<std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>> childAttributes;
for (Output<Node>& output : node->outputs()) {
for (const Input<Node>& input : output.get_target_inputs()) {
auto& inputRtInfo = input.get_rt_info();
auto inputAttributeIt = inputRtInfo.find(ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>::type_info.name);
if (inputAttributeIt != inputRtInfo.end()) {
const auto attribute = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>(inputAttributeIt->second);
childAttributes.push_back(attribute);
}
}
}
return childAttributes;
}

std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>> getParentInputRestriction(const std::shared_ptr<ngraph::Node> node, const size_t parentIndex) {
const auto& inputs = node->inputs();
Input<Node> input = inputs[parentIndex];
const auto& inputNode = input.get_source_output().get_node()->shared_from_this();
if (NetworkHelper::isPrecisionPreserved(inputNode)) {
for (const Input<Node>& input : inputNode->inputs()) {
auto& inputRtInfo = input.get_rt_info();
auto inputAttributeIt = inputRtInfo.find(ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>::type_info.name);
if (inputAttributeIt != inputRtInfo.end()) {
const auto& attribute = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>(inputAttributeIt->second);
return attribute;
}
}
}

if (is_type<opset1::FakeQuantize>(inputNode)) {
const auto& outputPortRtInfo = inputNode->outputs()[0].get_rt_info();
auto attributeIt = outputPortRtInfo.find(ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>::type_info.name);
if (attributeIt != outputPortRtInfo.end()) {
const auto& attribute = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>(attributeIt->second);
return attribute;
}
}

return nullptr;
}

std::vector<std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>> getParentOutputRestrictions(const std::shared_ptr<ngraph::Node> node) {
std::vector<std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>> parentOutputAttributes;
for (Input<Node>& input : node->inputs()) {
const auto& inputNode = input.get_source_output().get_node()->shared_from_this();
const auto& parentOutput = input.get_source_output();

auto& parentOutputRtInfo = parentOutput.get_rt_info();
auto outputAttributeIt = parentOutputRtInfo.find(ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>::type_info.name);
if (outputAttributeIt != parentOutputRtInfo.end()) {
const auto& attribute = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>>(outputAttributeIt->second);
parentOutputAttributes.push_back(attribute);
}
}
return parentOutputAttributes;
}

void replaceAttributeInInputs(
std::shared_ptr<ngraph::Function> f,
const std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttribute>>> newAttribute,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<IntervalsAlignmentAttributePtr>:
return resultAttributeWrapper;
}

std::shared_ptr<ngraph::Variant> VariantWrapper<IntervalsAlignmentAttributePtr>::init(const std::shared_ptr<ngraph::Node>& node) {
return nullptr;
}

std::string VariantWrapper<IntervalsAlignmentAttributePtr>::get_string() {
std::stringstream ss;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,6 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrecisionPreservedAttribute>::me
return newAttribute;
}

std::shared_ptr<ngraph::Variant> VariantWrapper<PrecisionPreservedAttribute>::init(const std::shared_ptr<ngraph::Node>& node) {
return nullptr;
}

std::string VariantWrapper<PrecisionPreservedAttribute>::get_string() {
auto value = this->m_value;
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,31 @@ class ConcatTransformation : public LayerTransformation, public testing::WithPar
ngraph::element::undefined,
{});

ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.actual").run_on_function(actualFunction);

//SimpleLowPrecisionTransformer transform;
//transform.register_pass<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>();
//transform.register_pass<ngraph::pass::low_precision::ConcatTransformation>();
//transform.transform(actualFunction);

auto supportedPrecisionsOnActivation = std::vector<ngraph::pass::low_precision::OperationPrecisionRestriction>({
ngraph::pass::low_precision::OperationPrecisionRestriction::create<ngraph::opset1::Convolution>({
{0, {ngraph::element::u8}},
{1, {ngraph::element::i8}}
})
});

//#define VISUALIZE_TREE
#ifndef VISUALIZE_TREE

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::low_precision::MarkupPrecisions>(supportedPrecisionsOnActivation);
manager.register_pass<ngraph::pass::low_precision::MarkupAvgPoolPrecisions>();
manager.register_pass<ngraph::pass::low_precision::PropagatePrecisions>();
manager.register_pass<ngraph::pass::low_precision::AlignConcatQuantizationParamters>();

std::shared_ptr<ngraph::pass::GraphRewrite> common = manager.register_pass<ngraph::pass::GraphRewrite>();
common->add_matcher<ngraph::pass::low_precision::ConcatTransformation>();
common->add_matcher<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>();

manager.run_passes(actualFunction);

#else
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.actual").run_on_function(actualFunction);

ngraph::pass::Manager manager1;
manager1.register_pass<ngraph::pass::low_precision::MarkupPrecisions>(supportedPrecisionsOnActivation);
manager1.run_passes(actualFunction);
Expand All @@ -166,17 +177,17 @@ class ConcatTransformation : public LayerTransformation, public testing::WithPar
manager4.run_passes(actualFunction);
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.transforming4").run_on_function(actualFunction);

// TODO: debug only
ngraph::pass::low_precision::TransformationContext context(actualFunction);
{
ngraph::pass::Manager manager;
std::shared_ptr<ngraph::pass::GraphRewrite> common = manager.register_pass<ngraph::pass::GraphRewrite>();
common->add_matcher<ngraph::pass::low_precision::ConcatTransformation>();
common->add_matcher<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>(ngraph::pass::low_precision::LayerTransformation::Params(), context);
common->add_matcher<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>();
manager.run_passes(actualFunction);
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.transformed").run_on_function(actualFunction);
}

#endif

// dequantization output precision depends on input precision
// to avoid huge amount of tests cases let's define dequantization output precision as input precision
if (!testValues.result.dequantizationAfter.multiply.empty()) {
Expand Down Expand Up @@ -206,7 +217,9 @@ class ConcatTransformation : public LayerTransformation, public testing::WithPar
testValues.result.precisionAfterOperation,
testValues.result.dequantizationAfter);

#ifdef VISUALIZE_TREE
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.reference").run_on_function(referenceFunction);
#endif
}

static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class ConcatWithNeighborsWithConvolutionTransformation : public LayerTransformat
})
});

#define VISUALIZE_TREE_NOT
#define VISUALIZE_TREE
#ifndef VISUALIZE_TREE

ngraph::pass::Manager manager;
Expand All @@ -143,6 +143,8 @@ class ConcatWithNeighborsWithConvolutionTransformation : public LayerTransformat
manager.run_passes(actualFunction);

#else
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.actual").run_on_function(actualFunction);

{
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::low_precision::MarkupPrecisions>(supportedPrecisionsOnActivation);
Expand Down Expand Up @@ -174,11 +176,7 @@ class ConcatWithNeighborsWithConvolutionTransformation : public LayerTransformat
{
ngraph::pass::Manager manager;
std::shared_ptr<ngraph::pass::GraphRewrite> common = manager.register_pass<ngraph::pass::GraphRewrite>();
//common->add_matcher<ngraph::pass::low_precision::ConcatTransformation>();
//common->add_matcher<ngraph::pass::low_precision::ConvolutionTransformation>();
common->add_matcher<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>();
//common->add_matcher<ngraph::pass::low_precision::MaxPoolTransformation>();

manager.run_passes(actualFunction);
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.transformed").run_on_function(actualFunction);
}
Expand All @@ -188,7 +186,6 @@ class ConcatWithNeighborsWithConvolutionTransformation : public LayerTransformat
std::shared_ptr<ngraph::pass::GraphRewrite> common = manager.register_pass<ngraph::pass::GraphRewrite>();
common->add_matcher<ngraph::pass::low_precision::ConcatTransformation>();
common->add_matcher<ngraph::pass::low_precision::ConvolutionTransformation>();
//common->add_matcher<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>();
common->add_matcher<ngraph::pass::low_precision::MaxPoolTransformation>();

manager.run_passes(actualFunction);
Expand All @@ -208,7 +205,9 @@ class ConcatWithNeighborsWithConvolutionTransformation : public LayerTransformat
testValues.result.dequantizationAfter1,
testValues.result.dequantizationAfter2);

#ifdef VISUALIZE_TREE
ngraph::pass::VisualizeTree("c:\\Projects\\temp\\test.reference").run_on_function(referenceFunction);
#endif
}

static std::string getTestCaseName(testing::TestParamInfo<ConcatWithNeighborsWithConvolutionParams> obj) {
Expand Down
2 changes: 2 additions & 0 deletions ngraph/core/include/ngraph/variant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ namespace ngraph

virtual std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node);
virtual std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes);

// TODO: to debug
virtual std::string get_string() { return ""; }
};

Expand Down

0 comments on commit a08b656

Please sign in to comment.