Skip to content

Commit

Permalink
Optimize away Passthrough nodes (#2896)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored May 22, 2024
1 parent d5cfda0 commit 07db5bc
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 46 deletions.
3 changes: 3 additions & 0 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __init__(self):
self.__edges_by_source: dict[NodeId, list[Edge]] = {}
self.__edges_by_target: dict[NodeId, list[Edge]] = {}

def nodes_with_schema_id(self, schema_id: str) -> list[Node]:
return [node for node in self.nodes.values() if node.schema_id == schema_id]

def add_node(self, node: Node):
assert node.id not in self.nodes, f"Duplicate node id {node.id}"
self.nodes[node.id] = node
Expand Down
94 changes: 58 additions & 36 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,45 @@ def __removed_dead_nodes(chain: Chain, mutation: _Mutation):
logger.debug(f"Chain optimization: Removed {node.schema_id} node {node.id}")


def __removed_pass_through(chain: Chain, mutation: _Mutation):
"""
Remove Passthrough nodes where possible.
"""

# We only remove Passthrough nodes with a single input-output pair
# For more information, see:
# https://github.com/chaiNNer-org/chaiNNer/issues/2555
# https://github.com/chaiNNer-org/chaiNNer/issues/2556
for node in chain.nodes_with_schema_id("chainner:utility:pass_through"):
out_edges = chain.edges_from(node.id)
if len(out_edges) == 1 and len(chain.edges_to(node.id)) == 1:
edge = out_edges[0]
__passthrough(
chain,
node,
input_id=InputId(edge.source.output_id),
output_id=edge.source.output_id,
)
chain.remove_node(node.id)
mutation.signal()


def __static_switch(chain: Chain, mutation: _Mutation):
"""
If the selected variant of the Switch node is statically known (which should always be the case), then we can statically resolve and remove the Switch node.
"""

for node in list(chain.nodes.values()):
if node.schema_id == "chainner:utility:switch":
value_index = chain.inputs.get(node.id, node.data.inputs[0].id)
if isinstance(value_index, int):
passed = False
for index, i in enumerate(node.data.inputs[1:]):
if index == value_index:
passed = __passthrough(chain, node, i.id)
for node in chain.nodes_with_schema_id("chainner:utility:switch"):
value_index = chain.inputs.get(node.id, node.data.inputs[0].id)
if isinstance(value_index, int):
passed = False
for index, i in enumerate(node.data.inputs[1:]):
if index == value_index:
passed = __passthrough(chain, node, i.id)

if passed:
chain.remove_node(node.id)
mutation.signal()
if passed:
chain.remove_node(node.id)
mutation.signal()


def __useless_conditional(chain: Chain, mutation: _Mutation):
Expand All @@ -95,31 +117,30 @@ def as_bool(value: object):
return True
return None

for node in list(chain.nodes.values()):
if node.schema_id == "chainner:utility:conditional":
# the condition is a constant
const_condition = as_bool(chain.inputs.get(node.id, InputId(0)))
if const_condition is not None:
__passthrough(
chain,
node,
input_id=if_true if const_condition else if_false,
)
chain.remove_node(node.id)
mutation.signal()
continue

# identical true and false branches
true_edge = chain.edge_to(node.id, if_true)
false_edge = chain.edge_to(node.id, if_false)
if (
true_edge is not None
and false_edge is not None
and true_edge.source == false_edge.source
):
__passthrough(chain, node, if_true)
chain.remove_node(node.id)
mutation.signal()
for node in chain.nodes_with_schema_id("chainner:utility:conditional"):
# the condition is a constant
const_condition = as_bool(chain.inputs.get(node.id, InputId(0)))
if const_condition is not None:
__passthrough(
chain,
node,
input_id=if_true if const_condition else if_false,
)
chain.remove_node(node.id)
mutation.signal()
continue

# identical true and false branches
true_edge = chain.edge_to(node.id, if_true)
false_edge = chain.edge_to(node.id, if_false)
if (
true_edge is not None
and false_edge is not None
and true_edge.source == false_edge.source
):
__passthrough(chain, node, if_true)
chain.remove_node(node.id)
mutation.signal()


def optimize(chain: Chain):
Expand All @@ -129,6 +150,7 @@ def optimize(chain: Chain):

__removed_dead_nodes(chain, mutation)
__static_switch(chain, mutation)
__removed_pass_through(chain, mutation)
__useless_conditional(chain, mutation)

if not mutation.changed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
),
],
outputs=[
AnyOutput(output_type="Input0", label="Value 1"),
AnyOutput(output_type="Input1", label="Value 2"),
AnyOutput(output_type="Input2", label="Value 3"),
AnyOutput(output_type="Input3", label="Value 4"),
AnyOutput(output_type="Input4", label="Value 5"),
AnyOutput(output_type="Input5", label="Value 6"),
AnyOutput(output_type="Input6", label="Value 7"),
AnyOutput(output_type="Input7", label="Value 8"),
AnyOutput(output_type="Input8", label="Value 9"),
AnyOutput(output_type="Input9", label="Value 10"),
AnyOutput(output_type="Input0", label="Value 1").with_id(0),
AnyOutput(output_type="Input1", label="Value 2").with_id(1),
AnyOutput(output_type="Input2", label="Value 3").with_id(2),
AnyOutput(output_type="Input3", label="Value 4").with_id(3),
AnyOutput(output_type="Input4", label="Value 5").with_id(4),
AnyOutput(output_type="Input5", label="Value 6").with_id(5),
AnyOutput(output_type="Input6", label="Value 7").with_id(6),
AnyOutput(output_type="Input7", label="Value 8").with_id(7),
AnyOutput(output_type="Input8", label="Value 9").with_id(8),
AnyOutput(output_type="Input9", label="Value 10").with_id(9),
],
)
def pass_through_node(
Expand Down

0 comments on commit 07db5bc

Please sign in to comment.