Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] New layout rewrite option: Weight pre-transpose #6750

Merged
merged 15 commits into from
Nov 2, 2020
32 changes: 26 additions & 6 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,24 @@ class ComputeDAGNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
};

/*!
* \brief Options for applying layout rewrite.
* This is an optimization to rewrite the layout of input tensors according to the schedule we get.
*/
enum class LayoutRewriteOption : int {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enum class LayoutRewriteOption : uint8 should be enough.

/*! \brief Do not process layout rewrite. */
NoRewrite = 0,
/*! \brief Insert layout transformation stages for input placeholders in the compute DAG */
InsertTransformStage = 1,
/*!
* \brief Do not insert layout transformation stages and assume the input placeholders
* are pre-transformed.
* \note The lowered function with this option does not accept the origial input shapes,
* so this option must be used along with a layout conversion pass in Relay.
*/
RewriteForPreTransformed = 2,
};

/*!
* \brief Managed reference to ComputeDAGNode.
* \sa ComputeDAGNode
Expand All @@ -214,8 +232,10 @@ class ComputeDAG : public ObjectRef {
* \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
* according to the loop nest derived with `transform_steps`.
* \param transform_steps Transform steps of a state.
* \param layout_rewrite Different options in layout rewrite.
* \return The updated ComputeDAG after layout rewrite.
*/
void RewriteLayout(const Array<Step>& transform_steps);
ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;

/*!
* \brief Apply the history transform steps to get a TVM schedule.
Expand All @@ -225,14 +245,14 @@ class ComputeDAG : public ObjectRef {
* \param stage_to_axes The map that stores all axes for one stage.
* Pass a valid pointer if this information needs to be used outside this function.
* \param layout_rewrite Rewrite the layout of placeholders specified by
* attr `layout_free_placeholders`
* attr `layout_free_placeholders`.
* \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
* or `tvm.build`.
*/
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
bool layout_rewrite = false) const;
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;

/*!
* \brief Print transform steps as equivalent python schedule API.
Expand Down
46 changes: 31 additions & 15 deletions include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,23 @@ class StepNode : public Object {
*/
class Step : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
/*!
* \brief CopyOnWrite function for Step.
* This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
* steps.
* \return A base StepNode pointer, need to cast to its real StepNode type before doing any
* modifications.
* \code
*
* SplitStep ref;
* StepNode* mutable_ref = ref.CopyOnWrite();
* dynamic_cast<SplitStepNode*>(mutable_ref)->... = ...;
*
* \endcode
*/
StepNode* CopyOnWrite();

TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
};

// Forward declaration
Expand Down Expand Up @@ -267,7 +283,7 @@ class AnnotationStepNode : public StepNode {
static constexpr const char* record_prefix_str = "AN";

static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -330,7 +346,7 @@ class FuseStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FU";

static constexpr const char* _type_key = "auto_scheduler.FuseStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -390,7 +406,7 @@ class PragmaStepNode : public StepNode {
static constexpr const char* record_prefix_str = "PR";

static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -452,7 +468,7 @@ class ReorderStepNode : public StepNode {
static constexpr const char* record_prefix_str = "RE";

static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -527,7 +543,7 @@ class SplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "SP";

static constexpr const char* _type_key = "auto_scheduler.SplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -607,7 +623,7 @@ class FollowSplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FSP";

static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -688,7 +704,7 @@ class FollowFusedSplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FFSP";

static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -754,7 +770,7 @@ class StorageAlignStepNode : public StepNode {
static constexpr const char* record_prefix_str = "SA";

static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -822,7 +838,7 @@ class ComputeAtStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CA";

static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -879,7 +895,7 @@ class ComputeInlineStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CI";

static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -938,7 +954,7 @@ class ComputeRootStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CR";

static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1010,7 +1026,7 @@ class CacheReadStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CHR";

static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1081,7 +1097,7 @@ class CacheWriteStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CHW";

static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1148,7 +1164,7 @@ class RfactorStepNode : public StepNode {
static constexpr const char* record_prefix_str = "RF";

static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode);
};

/*!
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class ComputeDAG(Object):
Input/output tensors or workload key for a compute declaration.
"""

# Layout Rewrite Options
NoRewrite = 0
InsertTransformStage = 1
RewriteForPreTransformed = 2

def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
compute = workload_key_to_tensors(compute_or_sche)
Expand Down Expand Up @@ -81,7 +86,7 @@ def get_init_state(self):
"""
return State(self.init_state, self)

def apply_steps_from_state(self, state, layout_rewrite=False):
def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
"""
Apply the history transform steps from a State to get a TVM schedule.

Expand Down
Loading