Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Jul 27, 2020
1 parent aeb8c7b commit e4408da
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
6 changes: 5 additions & 1 deletion python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,11 @@ def rfactor(self, stage, iterator, factor_iter_id):
self._resolve_stage_id(stage),
iterator, factor_iter_id,
self.compute_dag)
return self._insert_new_stage(int(new_stage_id))
# Add a new stage will change all ops behind the added stage. But we still want to keep the
# original ops map, apply stage id offset to stage_id_map to make them work.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op

def copy(self):
""" Do deep copy of this State. """
Expand Down
15 changes: 5 additions & 10 deletions src/auto_scheduler/transform_step.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1476,16 +1476,10 @@ void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
StateNode* pstate = state->CopyOnWrite();
const auto& compute_at_type = pstate->stages[stage_id]->compute_at;
Array<Step> replay_steps;
for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
break;
}
}
const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));

// target -> target_compute + target
// target_stage -> target_compute + target_stage
// Should insert new stage, update target stage, update the later stage's op
pstate->stages.insert(pstate->stages.begin() + stage_id,
Stage(current_compute_dag->ops[stage_id]));
Expand All @@ -1499,7 +1493,8 @@ int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
stage.CopyOnWrite()->op = current_compute_dag->ops[i];
pstate->stages.Set(i, std::move(stage));
}
pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(stage_id, 1);
pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, 1);
pstate->current_compute_dag = std::move(current_compute_dag);

return stage_id;
}
Expand Down

0 comments on commit e4408da

Please sign in to comment.