Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize away Passthrough nodes #2896

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __init__(self):
self.__edges_by_source: dict[NodeId, list[Edge]] = {}
self.__edges_by_target: dict[NodeId, list[Edge]] = {}

def nodes_with_schema_id(self, schema_id: str) -> list[Node]:
return [node for node in self.nodes.values() if node.schema_id == schema_id]

def add_node(self, node: Node):
assert node.id not in self.nodes, f"Duplicate node id {node.id}"
self.nodes[node.id] = node
Expand Down
94 changes: 58 additions & 36 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,45 @@ def __removed_dead_nodes(chain: Chain, mutation: _Mutation):
logger.debug(f"Chain optimization: Removed {node.schema_id} node {node.id}")


def __removed_pass_through(chain: Chain, mutation: _Mutation):
"""
Remove Passthrough nodes where possible.
"""

# We only remove Passthrough nodes with a single input-output pair
# For more information, see:
# https://github.com/chaiNNer-org/chaiNNer/issues/2555
# https://github.com/chaiNNer-org/chaiNNer/issues/2556
for node in chain.nodes_with_schema_id("chainner:utility:pass_through"):
out_edges = chain.edges_from(node.id)
if len(out_edges) == 1 and len(chain.edges_to(node.id)) == 1:
edge = out_edges[0]
__passthrough(
chain,
node,
input_id=InputId(edge.source.output_id),
output_id=edge.source.output_id,
)
chain.remove_node(node.id)
mutation.signal()


def __static_switch(chain: Chain, mutation: _Mutation):
"""
If the selected variant of the Switch node is statically known (which should always be the case), then we can statically resolve and remove the Switch node.
"""

for node in list(chain.nodes.values()):
if node.schema_id == "chainner:utility:switch":
value_index = chain.inputs.get(node.id, node.data.inputs[0].id)
if isinstance(value_index, int):
passed = False
for index, i in enumerate(node.data.inputs[1:]):
if index == value_index:
passed = __passthrough(chain, node, i.id)
for node in chain.nodes_with_schema_id("chainner:utility:switch"):
value_index = chain.inputs.get(node.id, node.data.inputs[0].id)
if isinstance(value_index, int):
passed = False
for index, i in enumerate(node.data.inputs[1:]):
if index == value_index:
passed = __passthrough(chain, node, i.id)

if passed:
chain.remove_node(node.id)
mutation.signal()
if passed:
chain.remove_node(node.id)
mutation.signal()


def __useless_conditional(chain: Chain, mutation: _Mutation):
Expand All @@ -95,31 +117,30 @@ def as_bool(value: object):
return True
return None

for node in list(chain.nodes.values()):
if node.schema_id == "chainner:utility:conditional":
# the condition is a constant
const_condition = as_bool(chain.inputs.get(node.id, InputId(0)))
if const_condition is not None:
__passthrough(
chain,
node,
input_id=if_true if const_condition else if_false,
)
chain.remove_node(node.id)
mutation.signal()
continue

# identical true and false branches
true_edge = chain.edge_to(node.id, if_true)
false_edge = chain.edge_to(node.id, if_false)
if (
true_edge is not None
and false_edge is not None
and true_edge.source == false_edge.source
):
__passthrough(chain, node, if_true)
chain.remove_node(node.id)
mutation.signal()
for node in chain.nodes_with_schema_id("chainner:utility:conditional"):
# the condition is a constant
const_condition = as_bool(chain.inputs.get(node.id, InputId(0)))
if const_condition is not None:
__passthrough(
chain,
node,
input_id=if_true if const_condition else if_false,
)
chain.remove_node(node.id)
mutation.signal()
continue

# identical true and false branches
true_edge = chain.edge_to(node.id, if_true)
false_edge = chain.edge_to(node.id, if_false)
if (
true_edge is not None
and false_edge is not None
and true_edge.source == false_edge.source
):
__passthrough(chain, node, if_true)
chain.remove_node(node.id)
mutation.signal()


def optimize(chain: Chain):
Expand All @@ -129,6 +150,7 @@ def optimize(chain: Chain):

__removed_dead_nodes(chain, mutation)
__static_switch(chain, mutation)
__removed_pass_through(chain, mutation)
__useless_conditional(chain, mutation)

if not mutation.changed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
),
],
outputs=[
AnyOutput(output_type="Input0", label="Value 1"),
AnyOutput(output_type="Input1", label="Value 2"),
AnyOutput(output_type="Input2", label="Value 3"),
AnyOutput(output_type="Input3", label="Value 4"),
AnyOutput(output_type="Input4", label="Value 5"),
AnyOutput(output_type="Input5", label="Value 6"),
AnyOutput(output_type="Input6", label="Value 7"),
AnyOutput(output_type="Input7", label="Value 8"),
AnyOutput(output_type="Input8", label="Value 9"),
AnyOutput(output_type="Input9", label="Value 10"),
AnyOutput(output_type="Input0", label="Value 1").with_id(0),
AnyOutput(output_type="Input1", label="Value 2").with_id(1),
AnyOutput(output_type="Input2", label="Value 3").with_id(2),
AnyOutput(output_type="Input3", label="Value 4").with_id(3),
AnyOutput(output_type="Input4", label="Value 5").with_id(4),
AnyOutput(output_type="Input5", label="Value 6").with_id(5),
AnyOutput(output_type="Input6", label="Value 7").with_id(6),
AnyOutput(output_type="Input7", label="Value 8").with_id(7),
AnyOutput(output_type="Input8", label="Value 9").with_id(8),
AnyOutput(output_type="Input9", label="Value 10").with_id(9),
],
)
def pass_through_node(
Expand Down
Loading