Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jul 23, 2020
1 parent bea7dcf commit b01fcf8
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 123 deletions.
55 changes: 30 additions & 25 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,27 @@ namespace auto_scheduler {
class AccessAnalyzerNode : public Object {
public:
template <class T>
using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;

/*! \brief Map an operation to all operations it reads from.
* For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
* For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
* The inner vector represents the indices of multi-dimensional access.*/
OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
/*! \brief Map an operation to all operations it is read by.
* For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
* For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
* The inner vector represents the indices of multi-dimensional access.*/
OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
/*! \brief Store the number of common outer iterators for operation pairs that have
* read-write relations. */
OperationMap<OperationMap<int>> num_common_outer_iterators;
/*! \brief Store whether the operation is injective */
OperationMap<bool> is_injective;
/*! \brief Store whether the operation is strictly-inlineable */
/*! \brief Store whether the operation is an op with only simple access.
* (e.g., injective, broadcast and elementwise ops without reduction) */
OperationMap<bool> is_simple_access;
/*! \brief Store whether the operation is strictly-inlineable
* (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) */
OperationMap<bool> is_strict_inlineable;
/*! \brief Store whether the operation needs multi-level tiling */
/*! \brief Store whether the operation needs multi-level tiling
* (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
OperationMap<bool> needs_multi_level_tiling;
/*! \brief Store whether the operation is an output operation */
OperationMap<bool> is_output;
Expand All @@ -86,22 +91,25 @@ class AccessAnalyzer : public ObjectRef {
explicit AccessAnalyzer(const Array<te::Tensor>& tensors);

/*!
* \brief Return whether this operation needs multi-level tiling
* \brief Return whether this operation is an injective operation
* (e.g., injective, broadcast and elementwise ops without reduction)
* \param op The operation
*/
TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;

/*!
* \brief Return whether this operation is an injective operation
* \brief Return whether this operation is strictly inlinable
* (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations)
* \param op The operation
*/
TVM_DLL bool IsInjective(const te::Operation& op) const;
TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;

/*!
* \brief Return whether this operation is strictly inlinable
* \brief Return whether this operation needs multi-level tiling
* (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d)
* \param op The operation
*/
TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;

/*!
* \brief Return whether this operation is an output op
Expand All @@ -113,33 +121,30 @@ class AccessAnalyzer : public ObjectRef {
* \brief Get all consumers of on operation
* \param state The current loop state
* \param op The operation
* \param consumers The return consumer set
* \return The set of consumers
* \note This function propagates the relation for inlined ops
*/
TVM_DLL void GetConsumers(
const State& state, const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) const;
TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
const State& state, const te::Operation& op) const;

/*!
* \brief Get all producers of on operation
* \param state The current loop state
* \param op The operation
* \param producers The return producer set
* \return The set of producers
* \note This function propagates the relation for inlined ops
*/
TVM_DLL void GetProducers(
const State& state, const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;
TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
const State& state, const te::Operation& op) const;

/*!
* \brief Get all direct producers of on operation
* \param op The operation
* \param producers The return producer set
* \return The set of direct producers
* \note This function DOES NOT propagate the relation for inlined ops
*/
TVM_DLL void GetDirectProducers(
const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;
TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
const te::Operation& op) const;

/*!
* \brief Get the number of common outer iterators.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ using IterKey = std::pair<int, int>;
*/
class AttachMapNode : public Object {
public:
struct key_hash : public std::function<std::size_t(IterKey)> {
struct IterKeyHash {
std::size_t operator()(const IterKey& k) const {
return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
}
Expand All @@ -168,7 +168,7 @@ class AttachMapNode : public Object {
/*! \brief A Map to store the mapping of stage to its attached iterator. */
std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
/*! \brief A Map to store the mapping of iterator to the stage attached to it. */
std::unordered_map<IterKey, std::vector<StageKey>, key_hash> iter_to_attached_stages;
std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;

static constexpr const char* _type_key = "auto_scheduler.AttachMap";
TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
Expand Down
1 change: 1 addition & 0 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import tvm
from .task import create
from .topi_integration import TaskExtractEnv
from .dispatcher import FallbackContext

logger = logging.getLogger('autotvm')

Expand Down
Loading

0 comments on commit b01fcf8

Please sign in to comment.