Skip to content

Commit

Permalink
wip, yet to be anaylyzed
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousPanCake committed Nov 1, 2024
1 parent 8da8a30 commit 796880b
Showing 1 changed file with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ std::vector<ov::DiscreteTypeInfo> ov::pass::MoveEltwiseUpThroughDataMov::get_def
};
}

std::shared_ptr<ov::Node> recurse_up(std::shared_ptr<ov::Node> node, std::shared_ptr<ov::Node> stop_node, std::shared_ptr<ov::Node> new_eltwise) {
if (node->input_value(0).get_node_shared_ptr() == stop_node) {
ov::OutputVector node_inputs = node->input_values();
node_inputs[0] = new_eltwise;
auto new_node = node->clone_with_new_inputs(node_inputs);
ov::copy_runtime_info(node, new_node);
new_node->set_friendly_name(node->get_friendly_name());
return new_node;
}

auto node_above = recurse_up(node->input_value(0).get_node_shared_ptr(), stop_node, new_eltwise);

ov::OutputVector node_inputs = node->input_values();
node_inputs[0] = node_above;
auto new_node = node->clone_with_new_inputs(node_inputs);
ov::copy_runtime_info(node, new_node);
new_node->set_friendly_name(node->get_friendly_name());
return new_node;
}

ov::pass::MoveEltwiseUpThroughDataMovScalar::MoveEltwiseUpThroughDataMovScalar(
std::vector<DiscreteTypeInfo> allowed_data_movement_ops) {
MATCHER_SCOPE(MoveEltwiseUpThroughDataMovScalar);
Expand Down Expand Up @@ -110,7 +130,7 @@ ov::pass::MoveEltwiseUpThroughDataMovScalar::MoveEltwiseUpThroughDataMovScalar(
ov::replace_output_update_name(eltwise->output(0), eltwise->input_value(0));

ov::OutputVector eltwise_inputs = eltwise->input_values();
eltwise_inputs[0] = child->input_value(0);
eltwise_inputs[0] = current;
auto new_eltwise = eltwise->clone_with_new_inputs(eltwise_inputs);
// WA: it's necessary to set empty friendly name here
// to avoid name duplication in TypeRelaxed cases
Expand All @@ -123,7 +143,11 @@ ov::pass::MoveEltwiseUpThroughDataMovScalar::MoveEltwiseUpThroughDataMovScalar(
ov::copy_runtime_info(child, new_child);
new_child->set_friendly_name(child->get_friendly_name());

ov::replace_node(child, new_child);
auto bottom = recurse_up(eltwise->input_value(0).get_node_shared_ptr(), current, new_eltwise);
ov::replace_node(eltwise->input_value(0).get_node_shared_ptr(), bottom);

std::cout << "working on " << eltwise << std::endl;

return true;
};

Expand Down

0 comments on commit 796880b

Please sign in to comment.