Skip to content

Commit

Permalink
[LPT] separateInStandaloneBranch fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli authored and v-Golubev committed May 19, 2021
1 parent 16a08e6 commit 9ea8bec
Showing 1 changed file with 13 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,10 @@ std::shared_ptr<ngraph::Node> NetworkHelper::separateInStandaloneBranch(std::sha
if (dequantization.isShared()) {
Output<Node> parent = dequantization.data;
if (dequantization.convert != nullptr) {
parent = dequantization.convert->clone_with_new_inputs({ parent });
parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
auto convert = dequantization.convert->clone_with_new_inputs({ parent });
convert->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
copy_runtime_info(parent.get_node_shared_ptr(), convert);
parent = convert->output(0);
}

if (dequantization.subtract != nullptr) {
Expand All @@ -537,22 +539,26 @@ std::shared_ptr<ngraph::Node> NetworkHelper::separateInStandaloneBranch(std::sha
outputs.push_back(input.get_source_output());
}

parent = dequantization.subtract->clone_with_new_inputs({parent, parentOnWeights->clone_with_new_inputs(outputs) });
parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
auto subtract = dequantization.subtract->clone_with_new_inputs({parent, parentOnWeights->clone_with_new_inputs(outputs) });
subtract->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
copy_runtime_info(parent.get_node_shared_ptr(), subtract);
parent = subtract->output(0);
}

if (dequantization.multiply != nullptr) {
parent = dequantization.multiply->clone_with_new_inputs({
auto multiply = dequantization.multiply->clone_with_new_inputs({
parent,
dequantization.multiply->get_input_node_shared_ptr(1)->clone_with_new_inputs({}) });
parent.get_node_shared_ptr()->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
multiply->set_friendly_name(parent.get_node_shared_ptr()->get_name() + "_new");
copy_runtime_info(parent.get_node_shared_ptr(), multiply);
parent = multiply->output(0);
}

std::vector<Output<Node>> inputs = node->input_values();
const size_t inputIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, node);
inputs[inputIndex] = parent;
const std::shared_ptr<Node> newNode = node->clone_with_new_inputs(inputs);

copy_runtime_info(node, newNode);
replace_node(node, newNode);
newNode->set_friendly_name(node->get_friendly_name());

Expand Down

0 comments on commit 9ea8bec

Please sign in to comment.