Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps #6073

Merged
merged 15 commits into from
Jul 21, 2020
260 changes: 130 additions & 130 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,108 +126,100 @@ def stage_ops(self):
"""
return [stage.op for stage in self.stages]

def reorder(self, stage, order):
""" Schedule primitive corresponds to te.reorder.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be reordered, which can be specified by the integer index, Operation,
or output tensor of the stage.
order : List[Iterator]
Iterators in the expected order.
"""
self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
order)

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to te.compute_at.
def bind(self, stage, iterator, thread_name):
""" Schedule primitive corresponds to te.bind.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be compute at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_stage : Union[int, Operation, Tensor]
The target stage of compute_at, which can be specified by the integer index, Operation,
The Stage to be bound, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_iter : Iterator
The target Iterator of compute_at.
iterator : Iterator
The iterator to be bound.
thread_name : str
The thread type to bind to. Candidates:
- vthread
- blockIdx.x
- threadIdx.x
- blockIdx.y
- threadIdx.y
- blockIdx.z
- threadIdx.z

Notes
-----
After compute_at, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
Returns
-------
res_it : Iterator
The bound Iterator.
"""
self.state_object = _ffi_api.StateComputeAt(self.state_object,
self._resolve_stage_id(stage),
self._resolve_stage_id(target_stage),
target_iter)
if not thread_name in State.ANNOTATION_TRANS_TABLE.keys():
raise ValueError("Invalid thread_name: ", thread_name)

def compute_root(self, stage):
""" Schedule primitive corresponds to te.compute_root.
self.state_object, res = _ffi_api.StateBind(self.state_object,
self._resolve_stage_id(stage), iterator,
State.ANNOTATION_TRANS_TABLE[thread_name])
return res

def parallel(self, stage, iterator):
""" Schedule primitive corresponds to te.parallel.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be compute root, which can be specified by the integer index, Operation,
The Stage to be parallelized, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be parallelized.

Notes
-----
After compute_root, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
Returns
-------
res_it : Iterator
The parallelized Iterator.
"""
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))
self.state_object, res = _ffi_api.StateParallel(self.state_object,
self._resolve_stage_id(stage), iterator)
return res

def compute_inline(self, stage):
""" Schedule primitive corresponds to te.compute_inline.
def unroll(self, stage, iterator, max_unroll=None):
""" Schedule primitive corresponds to te.unroll.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be compute inlined, which can be specified by the integer index, Operation,
The Stage to be unrolled, which can be specified by the integer index, Operation,
or output tensor of the stage.
"""
self.state_object = _ffi_api.StateComputeInline(self.state_object,
self._resolve_stage_id(stage))
iterator : Iterator
The iterator to be unrolled.
max_unroll : Optional[int]
The max unroll limit. Iterator with extent larger than this limit will be skipped.

def split(self, stage, iterator, lengths, inner_to_outer=True):
""" Schedule primitive corresponds to te.split.
Returns
-------
res_it : Iterator
The unrolled Iterator.
"""
self.state_object, res = _ffi_api.StateUnroll(self.state_object,
self._resolve_stage_id(stage), iterator,
max_unroll if max_unroll else -1)
return res

This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
will be split into 3 parts; use `inner_to_outer` to control the split order)
def vectorize(self, stage, iterator):
""" Schedule primitive corresponds to te.vectorize.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
The Stage to be vectorized, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be split.
lengths: List[int]
The multiple split factors. Can be None to be filled by search policy.
inner_to_outer: boolean = True
Whether the factors go from inner to outer, or from outer to inner.
The iterator to be vectorized.

Returns
-------
res_its : List[Iterator]
The newly split Iterators.

Notes
-----
If we do split on an iterator which has stages attached to it (by compute_at), the inner
most iterator of the split results will become the new attach point.
res_it : Iterator
The vectorized Iterator.
"""
self.state_object, res = _ffi_api.StateSplit(self.state_object,
self._resolve_stage_id(stage),
iterator, lengths, inner_to_outer)
self.state_object, res = _ffi_api.StateVectorize(self.state_object,
self._resolve_stage_id(stage), iterator)
return res

def fuse(self, stage, iters):
Expand Down Expand Up @@ -255,101 +247,109 @@ def fuse(self, stage, iters):
self._resolve_stage_id(stage), iters)
return res

def vectorize(self, stage, iterator):
""" Schedule primitive corresponds to te.vectorize.
def reorder(self, stage, order):
""" Schedule primitive corresponds to te.reorder.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be vectorized, which can be specified by the integer index, Operation,
The Stage to be reordered, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be vectorized.

Returns
-------
res_it : Iterator
The vectorized Iterator.
order : List[Iterator]
Iterators in the expected order.
"""
self.state_object, res = _ffi_api.StateVectorize(self.state_object,
self._resolve_stage_id(stage), iterator)
return res
self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
order)

def parallel(self, stage, iterator):
""" Schedule primitive corresponds to te.parallel.
def split(self, stage, iterator, lengths, inner_to_outer=True):
""" Schedule primitive corresponds to te.split.

This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
will be split into 3 parts; use `inner_to_outer` to control the split order)

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be paralleled, which can be specified by the integer index, Operation,
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be paralleled.
The iterator to be split.
lengths: List[int]
The multiple split factors. Can be None to be filled by search policy.
inner_to_outer: boolean = True
Whether the factors go from inner to outer, or from outer to inner.

Returns
-------
res_it : Iterator
The paralleled Iterator.
res_its : List[Iterator]
The newly split Iterators.

Notes
-----
If we do split on an iterator which has stages attached to it (by compute_at), the inner
most iterator of the split results will become the new attach point.
"""
self.state_object, res = _ffi_api.StateParallel(self.state_object,
self._resolve_stage_id(stage), iterator)
self.state_object, res = _ffi_api.StateSplit(self.state_object,
self._resolve_stage_id(stage),
iterator, lengths, inner_to_outer)
return res

def unroll(self, stage, iterator, max_unroll=None):
""" Schedule primitive corresponds to te.unroll.
def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to te.compute_at.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be unrolled, which can be specified by the integer index, Operation,
The Stage to be compute at, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be unrolled.
max_unroll : Optional[int]
The max unroll limit. Iterator with extent larger than this limit will be skipped.
target_stage : Union[int, Operation, Tensor]
The target stage of compute_at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_iter : Iterator
The target Iterator of compute_at.

Returns
-------
res_it : Iterator
The unrolled Iterator.
Notes
-----
After compute_at, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
"""
self.state_object, res = _ffi_api.StateUnroll(self.state_object,
self._resolve_stage_id(stage), iterator,
max_unroll if max_unroll else -1)
return res
self.state_object = _ffi_api.StateComputeAt(self.state_object,
self._resolve_stage_id(stage),
self._resolve_stage_id(target_stage),
target_iter)

def bind(self, stage, iterator, thread_name):
""" Schedule primitive corresponds to te.bind.
def compute_inline(self, stage):
""" Schedule primitive corresponds to te.compute_inline.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be bound, which can be specified by the integer index, Operation,
The Stage to be compute inlined, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be bound.
thread_name : str
The thread type to bind to. Candidates:
- vthread
- blockIdx.x
- threadIdx.x
- blockIdx.y
- threadIdx.y
- blockIdx.z
- threadIdx.z

Returns
-------
res_it : Iterator
The bound Iterator.
"""
if not thread_name in State.ANNOTATION_TRANS_TABLE.keys():
raise ValueError("Invalid thread_name: ", thread_name)
self.state_object = _ffi_api.StateComputeInline(self.state_object,
self._resolve_stage_id(stage))

self.state_object, res = _ffi_api.StateBind(self.state_object,
self._resolve_stage_id(stage), iterator,
State.ANNOTATION_TRANS_TABLE[thread_name])
return res
def compute_root(self, stage):
""" Schedule primitive corresponds to te.compute_root.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be compute root, which can be specified by the integer index, Operation,
or output tensor of the stage.

Notes
-----
After compute_root, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
"""
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))

def copy(self):
""" Do deep copy of this State. """
Expand Down
Loading