diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 8560a57bc902..7aa5de0e9c1d 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -25,16 +25,17 @@
 Basically this is a simplified TVM IR with schedule primitives.
 We don't use the existing TVM IR because
 1. We want fast incremental change to the loop structures
-2. We want serializable history for replay and backtracking
+2. We want serializable transformation history for replay, backtracking, and mutation
 3. We may create some new macro schedule primitives
 
-After search is done, we will lower this IR to TVM IR with TVM's schedule primitives.
+After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives.
 Because we share a lot of common objects during search, the transformation is
 implemented in a copy-on-write style. All objects are immutable, which is
 similar to TVM IR.
 """
 
 import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
 from tvm.runtime import Object
 
 from . import _ffi_api
@@ -80,43 +81,9 @@ def __init__(self, state_object, dag):
         self.state_object = state_object
         self.compute_dag = dag
 
-        self.stages_cache = None
-        self.stage_id_map = {}
-        self.__update_tensor_stage_map()
-
-    def __getitem__(self, k):
-        if not self.stages_cache:
-            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
-        if isinstance(k, tvm.te.Tensor):
-            return self.stages_cache[self.stage_id_map[k.op]]
-        raise ValueError("Item must be Tensor")
-
-    def __update_tensor_stage_map(self):
-        if not self.stages_cache:
-            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
-        for index, stage in enumerate(self.stages_cache):
-            self.stage_id_map[stage.op] = index
-
-    def __insert_new_stage(self, new_stage_id):
-        new_stage_id = int(new_stage_id)
-        self.stages_cache = _ffi_api.StateGetStages(self.state_object)
-        added_stage_tensor = self.stages_cache[new_stage_id].op.output(0)
-
-        for key, value in self.stage_id_map.items():
-            if value >= new_stage_id:
-                self.stage_id_map[key] = value + 1
-        self.stage_id_map[added_stage_tensor.op] = new_stage_id
-        self.__update_tensor_stage_map()
-
-        return added_stage_tensor
-
-    def clear_cache(self):
-        self.stages_cache = None
-
-    def copy(self):
-        state = State(self.state_object, self.compute_dag)
-        state.stage_id_map = self.stage_id_map.copy()
-        return state
+        self.stages_cache = None  # A list to cache all stages
+        self.stage_id_map = {}    # A dict mapping operation to stage id
+        self._update_stage_id_map()
 
     @property
     def stages(self):
@@ -130,15 +97,15 @@ def stages(self):
         return self.stages_cache
 
     @property
-    def stage_tensors(self):
+    def stage_ops(self):
         """
         Returns
         -------
-        Tensor
+        ops: List[Operation]
         """
         if not self.stages_cache:
             self.stages_cache = _ffi_api.StateGetStages(self.state_object)
-        return [stage.op.output(0) for stage in self.stages_cache]
+        return [stage.op for stage in self.stages_cache]
 
     def transform_steps_size(self):
         """ Return the size of transform_steps
@@ -149,30 +116,27 @@ def reorder(self, stage_id, order):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to reorder
         order : List[Iterator]
             Iterators in the expected order
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
-        self.clear_cache()
+        self._clear_cache()
 
     def split(self, stage_id, iterator, lengths, inner_to_outer=True):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to split
         iterator : Iterator
             The iterator to split
-        lengths: List[Int]
+        lengths: List[int]
             The split factors
-        inner_to_outer: Bool
+        inner_to_outer: bool
             True to use `factor` to split from inner to outer,
             False to use `nparts` to split from outer to inner
 
@@ -181,27 +145,24 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True):
         Returns
         -------
         res_its : List[Iterator]
             The new Iterators after split
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator,
                                                      lengths, inner_to_outer)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def follow_split(self, stage_id, iterator, src_step_id, n_split):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to split
         iterator : Iterator
             The iterator to split
-        src_step_id : Int
+        src_step_id : int
             The index of the split step to follow in the history
-        n_split : Int
+        n_split : int
             The number of split levels
 
@@ -209,14 +170,11 @@ def follow_split(self, stage_id, iterator, src_step_id, n_split):
         Returns
         -------
         res_its : List[Iterator]
             The new Iterators after split
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id,
                                                            iterator, src_step_id, n_split)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def follow_fused_split(self, stage_id, iterator, src_step_ids, level,
@@ -224,15 +182,15 @@ def follow_fused_split(self, stage_id, iterator, src_step_ids, level,
                            factor_or_nparts):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to split
         iterator : Iterator
             The iterator to split
-        src_step_ids : List[Int]
+        src_step_ids : List[int]
             The indices of the split steps to follow in the history
-        level : Int
+        level : int
             Use the length in this split level
-        factor_or_nparts : Bool
+        factor_or_nparts : bool
             True to use `factor` for split from inner to outer,
             False to use `nparts` for split from outer to inner
 
@@ -241,22 +199,19 @@ def follow_fused_split(self, stage_id, iterator, src_step_ids, level,
         Returns
         -------
         res_its : List[Iterator]
             The new Iterators after split
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id,
                                                                 iterator, src_step_ids,
                                                                 level, factor_or_nparts)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def fuse(self, stage_id, iters):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to fuse
         iters : List[Iterator]
             The iterators to be fused
 
@@ -266,20 +221,17 @@ def fuse(self, stage_id, iters):
         Returns
         -------
         res_it : Iterator
             The fused Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) - self.clear_cache() + self._clear_cache() return res def vectorize(self, stage_id, iterator): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to vectorize iterator : Iterator The iterator to be vectorized @@ -289,20 +241,17 @@ def vectorize(self, stage_id, iterator): res_it : Iterator The vectorized Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) - self.clear_cache() + self._clear_cache() return res def parallel(self, stage_id, iterator): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to parallel iterator : Iterator The iterator to be parallelized @@ -312,24 +261,21 @@ def parallel(self, stage_id, iterator): res_it : Iterator The parallelized Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) - self.clear_cache() + self._clear_cache() return res def unroll(self, stage_id, iterator, max_unroll=-1): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to unroll iterator : Iterator The iterator to be unrolled - max_unroll: Int + max_unroll: int The maximum length of the iterator that can be unrolled Returns @@ -337,21 +283,18 @@ def unroll(self, stage_id, iterator, max_unroll=-1): res_it : Iterator The unrolled Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, max_unroll) - self.clear_cache() + self._clear_cache() return res def bind_thread(self, stage_id, iterator, thread_name): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to bind iterator : Iterator The iterator to be bound @@ -372,201 +315,167 @@ def bind_thread(self, stage_id, iterator, thread_name): } thread_id = trans_table[thread_name] - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, thread_id) - self.clear_cache() + self._clear_cache() return res def compute_at(self, stage_id, target_stage_id, target_iter): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of source stage - target_stage_id : Int + target_stage_id : Union[int, Operation, Tensor] The index of the target stage of compute_at target_iter : Iterator The target Iterator of compute_at """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = 
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
-        if isinstance(target_stage_id, tvm.te.Tensor):
-            target_stage_id = self.stage_id_map[target_stage_id.op]
-        elif not isinstance(target_stage_id, int):
-            raise ValueError("target_stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
+        target_stage_id = self._resolve_stage_id(target_stage_id)
 
         self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id,
                                                     target_stage_id, target_iter)
-        self.clear_cache()
+        self._clear_cache()
 
     def compute_root(self, stage_id):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to compute root
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id)
-        self.clear_cache()
+        self._clear_cache()
 
     def compute_inline(self, stage_id):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to compute inline
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id)
-        self.clear_cache()
+        self._clear_cache()
 
     def cache_read(self, stage_id, scope_name, reader_stage_ids):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do cache_read
-        scope_name : Str
-        reader_stage_ids : List[Int]
+        scope_name : str
+        reader_stage_ids : List[int]
 
         Returns
         -------
-        new_stage_id : Int
-            The added staged id
+        new_stage_op : Operation
+            The Operation of the added stage
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
+
         if isinstance(reader_stage_ids, list):
             tmp_list = []
             for reader_stage_id in reader_stage_ids:
-                if isinstance(reader_stage_id, tvm.te.Tensor):
-                    tmp_list.append(self.stage_id_map[reader_stage_id.op])
-                elif isinstance(reader_stage_id, int):
-                    tmp_list.append(reader_stage_id)
-                else:
-                    raise ValueError("reader_stage_id must be Tensor or Int")
+                tmp_list.append(self._resolve_stage_id(reader_stage_id))
             reader_stage_ids = tmp_list
         else:
-            raise ValueError("reader_stage_ids must be list of Tensor or Int")
+            raise ValueError("reader_stage_ids must be a list of int, Operation or Tensor")
 
         self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id,
                                                                   scope_name, reader_stage_ids,
                                                                   self.compute_dag)
-        return self.__insert_new_stage(new_stage_id)
+        return self._insert_new_stage(new_stage_id)
 
     def cache_write(self, stage_id, scope_name):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do cache_write
-        scope_name : Str
+        scope_name : str
 
         Returns
         -------
-        new_stage_id : Int
-            The added staged id
+        new_stage_op : Operation
+            The Operation of the added stage
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id,
                                                                    scope_name, self.compute_dag)
-        return self.__insert_new_stage(new_stage_id)
+        return self._insert_new_stage(new_stage_id)
 
     def pragma(self, stage_id, iterator, pragma_type):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to add pragma
         iterator : Iterator
             The iterator to add pragma
-        pragma_type : Str
+        pragma_type : str
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StatePragma(self.state_object, stage_id,
                                                  iterator, pragma_type)
-        self.clear_cache()
+        self._clear_cache()
 
     def rfactor(self, stage_id, iterator, factor_iter_id):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do reduction factor
         iterator : Iterator
-        factor_iter_id : Int
+        factor_iter_id : int
 
         Returns
         -------
-        new_stage_id : Int
-            The added staged id
+        new_stage_op : Operation
+            The Operation of the added stage
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id,
                                                                 iterator, factor_iter_id,
                                                                 self.compute_dag)
-        return self.__insert_new_stage(new_stage_id)
+        return self._insert_new_stage(new_stage_id)
 
     def storage_align(self, stage_id, iterator, factor, offset):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do storage align
         iterator : Iterator
-        factor : Int
-        offset : Int
+        factor : int
+        offset : int
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id,
                                                        iterator, factor, offset)
-        self.clear_cache()
+        self._clear_cache()
 
     def tensorize(self, stage_id, iterator, ti_func_name):
-        """ The `ti_func_name` corresponds to a global registered funcion
+        """ The `ti_func_name` corresponds to a globally registered function
         that returns a TensorIntrin
 
         Parameters
         ----------
-        stage_id : Int
-            The index of the stage to do storage align
+        stage_id : Union[int, Operation, Tensor]
+            The index of the stage to be tensorized
         iterator : Iterator
-            The target iterator
-        ti_func_name : Str
+            The iterator to be tensorized
+        ti_func_name : str
             Tensorize intrinsic function name
 
@@ -574,17 +483,66 @@ def tensorize(self, stage_id, iterator, ti_func_name):
         Returns
         -------
         res_it : Iterator
             The tensorized Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateTensorize(self.state_object, stage_id,
                                                          iterator, ti_func_name)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
+    def _resolve_stage_id(self, stage_id):
+        if isinstance(stage_id, Operation):
+            return self.stage_id_map[stage_id]
+        if isinstance(stage_id, Tensor):
+            return self.stage_id_map[stage_id.op]
+        if isinstance(stage_id, int):
+            return stage_id
+        raise ValueError("Invalid stage_id: expected int, Operation or Tensor")
+
+    def _update_stage_id_map(self):
+        if not self.stages_cache:
+            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        for index, stage in enumerate(self.stages_cache):
+            self.stage_id_map[stage.op] = index
+
+    def _insert_new_stage(self, new_stage_id):
+        new_stage_id = int(new_stage_id)
+        self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        added_op = self.stages_cache[new_stage_id].op
+
+        # Adding a new stage shifts the ids of all stages behind it. We still want
+        # the old ops to be usable as stage keys, so we update their cached ids
+        # here instead of removing them.
+        for key, value in self.stage_id_map.items():
+            if value >= new_stage_id:
+                self.stage_id_map[key] = value + 1
+        self.stage_id_map[added_op] = new_stage_id
+
+        # Register the ops of the updated stages in stage_id_map
+        self._update_stage_id_map()
+
+        return added_op
+
+    def _clear_cache(self):
+        self.stages_cache = None
+
+    def copy(self):
+        state = State(self.state_object, self.compute_dag)
+        state.stage_id_map = self.stage_id_map.copy()
+        return state
+
+    def __getitem__(self, key):
+        if not self.stages_cache:
+            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        if isinstance(key, Tensor):
+            key = key.op
+        if isinstance(key, Operation):
+            return self.stages_cache[self.stage_id_map[key]]
+        raise ValueError("Key must be a Tensor or an Operation")
+
     def __str__(self):
         return str(self.state_object)
 
diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py
index 313dc1f89902..0768f82b805a 100644
--- a/tests/python/unittest/test_ansor_compute_dag.py
+++ b/tests/python/unittest/test_ansor_compute_dag.py
@@ -34,9 +34,9 @@ def test_infer_bound():
     dag, s = get_tiled_matmul()
     s = dag.infer_bound_from_state(s)
 
-    A_global = s.stage_tensors[1]
-    B_global = s.stage_tensors[3]
-    C_global = s.stage_tensors[4]
+    A_global = s.stage_ops[1]
+    B_global = s.stage_ops[3]
+    C_global = s.stage_ops[4]
     assert s[B_global].iters[0].range.extent == 512
     assert s[B_global].iters[1].range.extent == 16
     assert s[A_global].iters[0].range.extent == 1
diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py
index bcc7683b3f4a..705556c65edf 100644
--- a/tests/python/unittest/test_ansor_feature.py
+++ b/tests/python/unittest/test_ansor_feature.py
@@ -33,7 +33,7 @@ def fequal(a, b):
 def test_cpu_matmul():
     dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512))
     s = dag.get_init_state()
-    C = s.stage_tensors[2]
+    C = s.stage_ops[2]
 
     i, j, k = s[C].iters
     io, ii = s.split(C, i, [16])
@@ -42,7 +42,7 @@ def test_cpu_matmul():
     s.vectorize(C, ji)
     s.parallel(C, io)
     s.parallel(C, jo)
-    s.unroll(2, k)
+    s.unroll(C, k)
 
     target = tvm.target.create('llvm')
     task = ansor.SearchTask(dag, "test", target)
diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py
index 87688e276469..d90be1a78421 100644
--- a/tests/python/unittest/test_ansor_loop_state.py
+++ b/tests/python/unittest/test_ansor_loop_state.py
@@ -115,14 +115,14 @@ def test_compute_at_root_inline():
     s0 = dag.get_init_state()
 
     # data, padding, kernel = 0, 1, 2
-    conv = s0.stage_tensors[3]
+    conv = s0.stage_ops[3]
     # bias = 4
-    bias_add = s0.stage_tensors[5]
+    bias_add = s0.stage_ops[5]
     # bn_scale = 6
-    bn_mul = s0.stage_tensors[7]
+    bn_mul = s0.stage_ops[7]
     # bn_offset = 8
-    bn_add = s0.stage_tensors[9]
-    relu = s0.stage_tensors[10]
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
 
     s0.compute_inline(bn_add)
     s0.compute_inline(bn_mul)
@@ -193,8 +193,8 @@ def test_cache_read_write():
     dag = ansor.ComputeDAG([data, kernel_data, add])
     s0 = dag.get_init_state()
 
-    pad_temp = s0.stage_tensors[1]
-    kernel_split = s0.stage_tensors[3]
+    pad_temp = s0.stage_ops[1]
+    kernel_split = s0.stage_ops[3]
 
     # 0: init state
     ori_its = s0[add].iters
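
A minimal usage sketch of the resulting Python API (illustration only, not part of the patch). It assumes the `matmul_ansor_test` helper used by the tests above is importable; the stage index 2 follows `test_cpu_matmul`:

    import tvm
    from tvm import ansor

    # matmul_ansor_test is the same test helper used above (assumed importable
    # here); it builds the tensors of a 512x512x512 matmul.
    dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512))
    s = dag.get_init_state()

    # After this change a stage can be referenced by its int id, its Operation,
    # or a Tensor; stage_ops returns te.Operation objects.
    C = s.stage_ops[2]
    i, j, k = s[C].iters          # State.__getitem__ accepts an Operation (or Tensor)
    io, ii = s.split(C, i, [16])  # primitives resolve C through _resolve_stage_id
    s.unroll(C, k)                # no raw stage id needed, unlike the old s.unroll(2, k)

    # Stages added by cache_read/cache_write/rfactor return their Operation, and
    # the old op C remains a valid key afterwards because _insert_new_stage shifts
    # the cached ids instead of dropping them.
    C_global = s.cache_write(C, "global")
    print(s[C_global].iters)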