Skip to content

Commit

Permalink
[AutoScheduler] Improve doc string (#6176)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Jul 30, 2020
1 parent 760a547 commit a26ac93
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 255 deletions.
4 changes: 2 additions & 2 deletions include/tvm/auto_scheduler/auto_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ class TuningOptions : public ObjectRef {
/*!
* \brief Run schedule search for a given compute declaration.
* \param task The search task of the compute declaration.
* \param search_policy The search policy to be used.
* \param search_policy The search policy.
* \param tuning_options Tuning and measurement options.
* \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or
* \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
* `tvm.build`.
*/
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
Expand Down
37 changes: 18 additions & 19 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
* \brief The auto-scheduler's computational graph and related program analyses.
*
* We convert a compute declaration described by `tvm.compute` (could be a single operator or a
* subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
* a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
* total float operation count, consumer/producer relations of each operation stage, whether an
* operation stage should be tiled/compute inlined ...). These analyses can help the search policy
* to make decisions during search process.
* ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
* TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
* subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and
* some static analysis results for the DAG (e.g. the total float operation count, consumer/producer
* relations of operations, whether an operation stage should be tiled/compute inlined ...).
* These analyses can help the search policy to make decisions during the search.
* ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and
* TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing
* `LoopState` with extra information got from TVM schedule ...).
*/

Expand All @@ -47,18 +46,18 @@
namespace tvm {
namespace auto_scheduler {

/*! \brief Static analysis result for a ComputeDAG */
/*! \brief Static analyzer for a ComputeDAG */
class AccessAnalyzerNode : public Object {
public:
template <class T>
using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;

/*! \brief Map an operation to all operations it reads from.
* For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
* For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses
* The inner vector represents the indices of multi-dimensional access.*/
OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
/*! \brief Map an operation to all operations it is read by.
* For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
* For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses
* The inner vector represents the indices of multi-dimensional access.*/
OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
/*! \brief Store the number of common outer iterators for operation pairs that have
Expand Down Expand Up @@ -92,7 +91,7 @@ class AccessAnalyzer : public ObjectRef {
explicit AccessAnalyzer(const Array<te::Tensor>& tensors);

/*!
* \brief Return whether this operation is an injective operation
* \brief Return whether this operation is an op with simple access
* (e.g., injective, broadcast and elementwise ops without reduction)
* \param op The operation
*/
Expand All @@ -113,13 +112,13 @@ class AccessAnalyzer : public ObjectRef {
TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;

/*!
* \brief Return whether this operation is an output op
* \brief Return whether this operation is an output operation
* \param op The operation
*/
TVM_DLL bool IsOutput(const te::Operation& op) const;

/*!
* \brief Get all consumers of on operation
* \brief Get all consumers of an operation
* \param state The current loop state
* \param op The operation
* \return The set of consumers
Expand All @@ -129,7 +128,7 @@ class AccessAnalyzer : public ObjectRef {
const State& state, const te::Operation& op) const;

/*!
* \brief Get all producers of on operation
* \brief Get all producers of an operation
* \param state The current loop state
* \param op The operation
* \return The set of producers
Expand All @@ -139,7 +138,7 @@ class AccessAnalyzer : public ObjectRef {
const State& state, const te::Operation& op) const;

/*!
* \brief Get all direct producers of on operation
* \brief Get all direct producers of an operation
* \param op The operation
* \return The set of direct producers
* \note This function DOES NOT propagate the relation for inlined ops
Expand All @@ -158,25 +157,25 @@ class AccessAnalyzer : public ObjectRef {

/*!
* \brief Return whether two operations are elementwise-matched
* (e.g. conv2d and relu are elementwise matched)
* (e.g. conv2d and relu are elementwise-matched)
* \note This function propagates the relation for chains with multiple ops.
*/
TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;

TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
};

/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
/*! \brief The auto-scheduler's computational graph and related program analyses. */
class ComputeDAGNode : public Object {
public:
/*!
* \brief Input and output tensors.
* This is used as the input of `tvm.lower` or `tvm.build`.
*/
Array<te::Tensor> tensors;
/*! \brief All related operations in topo order. */
/*! \brief All used operations in topo order. */
Array<te::Operation> ops;
/*! \brief The number of total float operations for this ComputeDAG. */
/*! \brief The number of float operations in this ComputeDAG. */
double flop_ct;
/*! \brief The initial state without any transform steps. */
State init_state;
Expand Down
Loading

0 comments on commit a26ac93

Please sign in to comment.