From 9ea8becee1e12355661e0d756c938cc3b602ca35 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Sun, 2 May 2021 23:12:06 +0300 Subject: [PATCH] [LPT] separateInStandaloneBranch fix --- .../src/network_helper.cpp | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/inference-engine/src/low_precision_transformations/src/network_helper.cpp b/inference-engine/src/low_precision_transformations/src/network_helper.cpp index 4a1e942e5753ba..56bfaaa4eee869 100644 --- a/inference-engine/src/low_precision_transformations/src/network_helper.cpp +++ b/inference-engine/src/low_precision_transformations/src/network_helper.cpp @@ -524,8 +524,10 @@ std::shared_ptr NetworkHelper::separateInStandaloneBranch(std::sha if (dequantization.isShared()) { Output parent = dequantization.data; if (dequantization.convert != nullptr) { - parent = dequantization.convert->clone_with_new_inputs({ parent }); - parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new"); + auto convert = dequantization.convert->clone_with_new_inputs({ parent }); + convert->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new"); + copy_runtime_info(parent.get_node_shared_ptr(), convert); + parent = convert->output(0); } if (dequantization.subtract != nullptr) { @@ -537,22 +539,26 @@ std::shared_ptr NetworkHelper::separateInStandaloneBranch(std::sha outputs.push_back(input.get_source_output()); } - parent = dequantization.subtract->clone_with_new_inputs({parent, parentOnWeights->clone_with_new_inputs(outputs) }); - parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new"); + auto subtract = dequantization.subtract->clone_with_new_inputs({parent, parentOnWeights->clone_with_new_inputs(outputs) }); + subtract->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new"); + copy_runtime_info(parent.get_node_shared_ptr(), subtract); + parent = subtract->output(0); } if (dequantization.multiply != nullptr) { - parent = dequantization.multiply->clone_with_new_inputs({ + auto multiply = dequantization.multiply->clone_with_new_inputs({ parent, dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) }); - parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new"); + multiply->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new"); + copy_runtime_info(parent.get_node_shared_ptr(), multiply); + parent = multiply->output(0); } std::vector> inputs = node->input_values(); const size_t inputIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, node); inputs[inputIndex] = parent; const std::shared_ptr newNode = node->clone_with_new_inputs(inputs); - + copy_runtime_info(node, newNode); replace_node(node, newNode); newNode->set_friendly_name(node->get_friendly_name());