Skip to content

Commit

Permalink
Merge pull request neo-ai#12 from trevor-m/trevmorr-speedup-reduce-su…
Browse files Browse the repository at this point in the history
…bgraph

Trevmorr speedup reduce subgraph size
  • Loading branch information
jianzhong-xu committed Jul 3, 2020
2 parents d2d88ee + 30c7341 commit f9719db
Showing 1 changed file with 75 additions and 66 deletions.
141 changes: 75 additions & 66 deletions python/tvm/relay/backend/contrib/tidl_reduce_subgraph_size.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -195,70 +195,13 @@ def visit_call(self, call):
last_op_args = []
if isinstance(last_op, tvm.relay.expr.Tuple):
# Subgraph has multiple outputs!
ancestor, distances = find_common_ancestor(last_op)

def get_field(field):
"""Get field as it is, unless it is a TupleGetItem which we will remove."""
if isinstance(field, tvm.relay.expr.Call):
# Handle concat
if isinstance(field.args[0], tvm.relay.expr.Tuple):
args = []
for f in field.args[0].fields:
args.append(f)
return args
return [field]
if isinstance(field, tvm.relay.expr.TupleGetItem):
args = []
for arg in field.tuple_value.args:
args.append(arg)
return args
if isinstance(field, tvm.relay.expr.Tuple):
args = []
for arg in field.fields:
args.append(arg)
return args
raise ValueError("New output of subgraph must be Call node.")

def get_args(field):
"""Gather args from field, excluding exclude node"""
args = []
if isinstance(field, tvm.relay.expr.Call):
for arg in field.args:
# Handle concat
if isinstance(arg, tvm.relay.expr.Tuple):
for f in arg.fields:
args.append(f)
else:
args.append(arg)
elif isinstance(field, tvm.relay.expr.TupleGetItem):
for arg in field.tuple_value.args:
args.append(arg)
elif isinstance(field, tvm.relay.expr.Tuple):
for arg in field.fields:
args.append(arg)
else:
raise ValueError("New output of subgraph must be Call node.")
return args

# All nodes come from same parent.
if all([dist == 0 for dist in distances]):
last_op_args = ancestor.args
else:
# Remove node with longest path
index_to_remove = np.argmax(distances)
# field[index_to_remove] is further from LCA, remove it
# by replacing it with its args.
last_op_args = []
for i in range(0, len(last_op.fields)):
if i == index_to_remove:
last_op_args += get_args(last_op.fields[i])
else:
last_op_args += get_field(last_op.fields[i])

# Remove duplicates.
seen = set()
seen_add = seen.add
last_op_args = [x for x in last_op_args if not (x in seen or seen_add(x))]
ancestor, _ = find_common_ancestor(last_op)
# Removing only the op furthest from the LCA greatly increases the time taken by this pass.
# Instead, always delete all the way up to the lowest common ancestor. This may
# cause more ops to be removed than is required, but it is much faster.
# TODO(trevmorr): Consider rewriting in C++ to improve speed.
# last_op_args = self._remove_op_furthest_from_lca(last_op, ancestor, distances)
last_op_args = ancestor.args
elif isinstance(last_op, tvm.relay.expr.Call):
last_op_args = last_op.args
elif isinstance(last_op, tvm.relay.expr.TupleGetItem):
Expand All @@ -268,8 +211,7 @@ def get_args(field):
else:
raise ValueError("Last op is not Call, Tuple, or TupleGetItem")
# Gather new outputs of the subgraph - from removed op's inputs
# This map will map Expr to index in new_outputs tuple
#print('last_op_args', last_op_args)
# This map will map Expr to index in new_outputs tuple
new_outputs = []
last_op_input_to_new_output_map = {}
if len(last_op_args) > 1:
Expand Down Expand Up @@ -318,6 +260,73 @@ def get_args(field):
return subgraph_gv(*args)
return super().visit_call(call)

def _remove_op_furthest_from_lca(self, last_op, ancestor, distances):
"""For subgraph with multiple outputs, pick output with logest path to least common
ancestor. Returns list of new outputs.
"""
def get_field(field):
"""Get field as it is, unless it is a TupleGetItem which we will remove."""
if isinstance(field, tvm.relay.expr.Call):
# Handle concat
if isinstance(field.args[0], tvm.relay.expr.Tuple):
args = []
for f in field.args[0].fields:
args.append(f)
return args
return [field]
if isinstance(field, tvm.relay.expr.TupleGetItem):
args = []
for arg in field.tuple_value.args:
args.append(arg)
return args
if isinstance(field, tvm.relay.expr.Tuple):
args = []
for arg in field.fields:
args.append(arg)
return args
raise ValueError("New output of subgraph must be Call node.")

def get_args(field):
"""Gather args from field, excluding exclude node"""
args = []
if isinstance(field, tvm.relay.expr.Call):
for arg in field.args:
# Handle concat
if isinstance(arg, tvm.relay.expr.Tuple):
for f in arg.fields:
args.append(f)
else:
args.append(arg)
elif isinstance(field, tvm.relay.expr.TupleGetItem):
for arg in field.tuple_value.args:
args.append(arg)
elif isinstance(field, tvm.relay.expr.Tuple):
for arg in field.fields:
args.append(arg)
else:
raise ValueError("New output of subgraph must be Call node.")
return args

# All nodes come from same parent.
if all([dist == 0 for dist in distances]):
return ancestor.args
# Remove node with longest path
index_to_remove = np.argmax(distances)
# field[index_to_remove] is further from LCA, remove it
# by replacing it with its args.
last_op_args = []
for i in range(0, len(last_op.fields)):
if i == index_to_remove:
last_op_args += get_args(last_op.fields[i])
else:
last_op_args += get_field(last_op.fields[i])

# Remove duplicates.
seen = set()
seen_add = seen.add
last_op_args = [x for x in last_op_args if not (x in seen or seen_add(x))]
return last_op_args

def reduce_subgraph_size(mod, max_num_layers=256, max_total_memory_mb=512):
"""
Reduces size of subgraph to fit limitations.
Expand Down

0 comments on commit f9719db

Please sign in to comment.