-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps #6142
Conversation
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: include/tvm/auto_scheduler/loop_state.h include/tvm/auto_scheduler/transform_step.h src/auto_scheduler/compute_dag.cc src/auto_scheduler/compute_dag.h src/auto_scheduler/loop_state.cc src/auto_scheduler/transform_step.cc tests/python/unittest/test_auto_scheduler_loop_state.py tests/python/unittest/test_auto_scheduler_measure.py
Hi, all. This is an student intern of us, who is now helping us with the Ansor upstreaming. 😄 The follow_split & follow_fused_split are two steps extent to FollowSplitThis is mainly used in stage fusion using compute at. FollowFusedSplitThis is mainly used in GPU cooperative fetching.
In Ansor's search policy, the outer stage has been tiled. The the threadIdx.x axis is binded to a iterator generated by split & fuse step. |
Thanks @jiuqi-yang for the nice work, @merrymercy @tqchen @FrozenGene @comaniac , would you please take a look at this PR? We are trying to accelerate the auto-schedule upstreaming process. |
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some nitpicking comments.
@@ -136,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes | |||
ps->ApplyToSchedule(stages, stage_to_axes); | |||
} else if (auto ps = step.as<SplitStepNode>()) { | |||
ps->ApplyToSchedule(stages, stage_to_axes); | |||
} else if (auto ps = step.as<FollowSplitStepNode>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { | ||
writer->WriteArraySeperator(); | ||
writer->WriteString(record_prefix_str); | ||
writer->WriteArrayItem(stage_id); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What will happen if the order of writing is changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The order here corresponds to the read order defined in the constructor of this step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Others look good to me. Just update these descriptions to follow the other functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine to me, I'll rebase the #6141 after this has been merged.
…t steps (apache#6142) * Add cache_read/cache_write step * Update * Add follow split and follow fused split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py * add loop_state.py Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update * Update state->current_compute_dag to Optional * Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Check code using c-lint Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_fused_split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add doc strings for some functions and variables Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update doc * Update * Update * Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Doc update * Update some doc strings Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix code style and some function definitions. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add comments on parameters. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and fix some. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Co-authored-by: chengfan.jcf <chengfan.jcf@alibaba-inc.com> Co-authored-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
…t steps (apache#6142) * Add cache_read/cache_write step * Update * Add follow split and follow fused split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py * add loop_state.py Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update * Update state->current_compute_dag to Optional * Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Check code using c-lint Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_fused_split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add doc strings for some functions and variables Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update doc * Update * Update * Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Doc update * Update some doc strings Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix code style and some function definitions. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add comments on parameters. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and fix some. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Co-authored-by: chengfan.jcf <chengfan.jcf@alibaba-inc.com> Co-authored-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
…t steps (apache#6142) * Add cache_read/cache_write step * Update * Add follow split and follow fused split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py * add loop_state.py Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update * Update state->current_compute_dag to Optional * Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Check code using c-lint Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_fused_split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add doc strings for some functions and variables Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update doc * Update * Update * Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Doc update * Update some doc strings Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix code style and some function definitions. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add comments on parameters. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and fix some. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Co-authored-by: chengfan.jcf <chengfan.jcf@alibaba-inc.com> Co-authored-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
…t steps (apache#6142) * Add cache_read/cache_write step * Update * Add follow split and follow fused split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py * add loop_state.py Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update * Update state->current_compute_dag to Optional * Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Check code using c-lint Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_fused_split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add doc strings for some functions and variables Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update doc * Update * Update * Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Doc update * Update some doc strings Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix code style and some function definitions. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add comments on parameters. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and fix some. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Co-authored-by: chengfan.jcf <chengfan.jcf@alibaba-inc.com> Co-authored-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
…t steps (apache#6142) * Add cache_read/cache_write step * Update * Add follow split and follow fused split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py * add loop_state.py Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update * Update state->current_compute_dag to Optional * Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Check code using c-lint Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add record test for follow_fused_split. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add doc strings for some functions and variables Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update * Update doc * Update * Update * Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Doc update * Update some doc strings Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Fix code style and some function definitions. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add comments on parameters. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Add more doc strings and fix some. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> * Update. Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com> Co-authored-by: chengfan.jcf <chengfan.jcf@alibaba-inc.com> Co-authored-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
For the full upstream plan, see Ansor RFC.
In this PR, we bring follow split and follow fused split steps for Ansor auto_scheduler.
cc @merrymercy @comaniac @junrushao1994 @FrozenGene @jroesch