diff --git a/python/tvm/autotvm/graph_tuner/_base.py b/python/tvm/autotvm/graph_tuner/_base.py index e8d35ac35780..ae220bb5e2f8 100644 --- a/python/tvm/autotvm/graph_tuner/_base.py +++ b/python/tvm/autotvm/graph_tuner/_base.py @@ -23,3 +23,5 @@ INVALID_LAYOUT_TIME = 10e9 MAX_OUTPUT_NODES = 16 + +OPT_OUT_OP = ["layout_transform"] diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index bb9c52d8da69..f1a075637ae9 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -34,6 +34,7 @@ bind_inputs, expr2graph from ._base import INVALID_LAYOUT_TIME +from ._base import OPT_OUT_OP def get_infer_layout(task_name): if task_name.startswith("conv2d"): @@ -153,6 +154,7 @@ def __init__(self, graph, input_shapes, records, target_ops, self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys()) self._out_nodes_dict = get_out_nodes(self._in_nodes_dict) self._fetch_cfg() + self._opt_out_op = OPT_OUT_OP # Setup infer_layout for elemwise-like nodes # Note: graph tuner currently only supports tuning of single input and single output @@ -162,7 +164,7 @@ def __init__(self, graph, input_shapes, records, target_ops, # elemwise-like node, and use infer_layout function from input op to generate layouts. input_names = self._input_shapes.keys() for idx in sorted(self._in_nodes_dict.keys()): - if has_multiple_inputs(self._node_list, idx, input_names): + if has_multiple_inputs(self._node_list, idx, input_names, self._opt_out_op): node_entry = self._node_list[idx] node_entry["topi_op"] = [] node_entry["workloads"] = [] @@ -246,7 +248,7 @@ def _iterate_layout_transform(self, callback): node_entry = self._node_list[key] target_input_idx = -1 target_input_pos = -1 - if has_multiple_inputs(self._node_list, key, input_names): + if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op): for i, item in enumerate(val): node = self._node_list[item] if not is_boundary_node(node, input_names): diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py index e3e4d1137afd..b9d40c85ba58 100644 --- a/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py +++ b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py @@ -144,7 +144,7 @@ def _backward(self): continue optimal_sch_idx = optimal_record_dict[node_idx] full_states = self._stage_dict[node_idx].full_states - if not has_multiple_inputs(self._node_list, node_idx, input_names): + if not has_multiple_inputs(self._node_list, node_idx, input_names, self._opt_out_op): input_idx = self._in_nodes_dict[node_idx][0] input_node = self._node_list[input_idx] if is_boundary_node(input_node, input_names): diff --git a/python/tvm/autotvm/graph_tuner/pbqp_tuner.py b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py index 36090f4109cf..d58694c26329 100644 --- a/python/tvm/autotvm/graph_tuner/pbqp_tuner.py +++ b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py @@ -249,7 +249,7 @@ def run(self, **kwargs): for key, val in self._in_nodes_dict.items(): target_input_idx = -1 target_input_pos = -1 - if has_multiple_inputs(self._node_list, key, input_names): + if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op): for i, item in enumerate(val): node = self._node_list[item] if not is_boundary_node(node, input_names): diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 17450ca3e7f3..f1dd40440532 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -26,7 +26,7 @@ from tvm.autotvm.task import TaskExtractEnv from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node - +from .._base import OPT_OUT_OP def expr2graph(expr, target_ops, node_dict, node_list): """Convert relay expr to graph data structure @@ -204,7 +204,8 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam node_direct_ancestor = [] for item_idx in node["inputs"]: item = node_list[item_idx[0]] - is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names) + is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \ + input_names, OPT_OUT_OP) if item["op"] in target_ops or is_multiple_inputs: node_direct_ancestor.append(item_idx[0]) else: @@ -245,7 +246,8 @@ def get_in_nodes(node_list, target_ops, input_names): get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names) for key, val in visited_dict.items(): node = node_list[key] - is_multiple_inputs = has_multiple_inputs(node_list, key, input_names) + is_multiple_inputs = has_multiple_inputs(node_list, key, \ + input_names, OPT_OUT_OP) if node["op"] in target_ops or is_multiple_inputs: in_node_dict[key] = val diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py index 2486d0c0bda0..70e95c904dec 100644 --- a/python/tvm/autotvm/graph_tuner/utils/utils.py +++ b/python/tvm/autotvm/graph_tuner/utils/utils.py @@ -20,8 +20,7 @@ from tvm import relay from tvm.relay import transform - -def has_multiple_inputs(node_list, node_idx, input_names): +def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op): """Check whether a node has multiple input nodes except variable nodes. @@ -47,7 +46,14 @@ def has_multiple_inputs(node_list, node_idx, input_names): in_idx = in_idx[0] in_node = node_list[in_idx] # Exclude parameter nodes - if in_node["op"] is not None or \ + if(in_node["op"] is not None and in_node["op"].name in opt_out_op): + increase = False + for t_idx in in_node["inputs"]: + increase = has_multiple_inputs(node_list, t_idx[0], \ + input_names, opt_out_op) + if increase: + num_inputs += 1 + elif in_node["op"] is not None or \ ("name" in in_node and in_node["name"] in input_names): num_inputs += 1 return num_inputs > 1 diff --git a/tests/python/unittest/test_graph_tuner_utils.py b/tests/python/unittest/test_graph_tuner_utils.py index f620accb1719..bd0ebe0cd3f5 100644 --- a/tests/python/unittest/test_graph_tuner_utils.py +++ b/tests/python/unittest/test_graph_tuner_utils.py @@ -27,11 +27,12 @@ from tvm.relay.testing import resnet from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ get_out_nodes, expr2graph, bind_inputs +from tvm.autotvm.graph_tuner._base import OPT_OUT_OP from tvm.relay.expr import Call, TupleGetItem, Tuple, Var def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result): - out = has_multiple_inputs(node_list, node_idx, input_names) + out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP) assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \ % (node_list[node_idx]["op"], str(expected_result), str(out))