Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【CINN】move ir_copy from namespace optim to ir_utils #57582

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/analysis/analyze_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
for (const ir::Expr& e : indices) {
// Whether we have to convert other types, like const numbers to Var?
if (e.As<ir::_Var_>() != nullptr) {
ir::Expr copy_e = optim::IRCopy(e);
ir::Expr copy_e = ir::ir_utils::IRCopy(e);
ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
result.emplace_back(ir::Var(var_ref));
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ void FeatureExtractor::Visit(const For *x) {
}

void FeatureExtractor::Visit(const PolyFor *x) {
Expr copy = optim::IRCopy(Expr(x));
Expr copy = ir::ir_utils::IRCopy(Expr(x));
feature_.IntoLoopBlock();
optim::TransformPolyForToFor(&copy);
ir::For *loop = copy.As<For>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
const std::string& task_key) {
std::vector<Expr> exprs;
for (auto&& func : lowered_funcs) {
exprs.emplace_back(optim::IRCopy(func->body));
exprs.emplace_back(ir::ir_utils::IRCopy(func->body));
}
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(task_key, ir::ModuleExpr(exprs));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) {
ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
Expand Down Expand Up @@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1,

for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
cross_over_exprs.push_back(optim::IRCopy(father_exprs[i]));
cross_over_exprs.push_back(ir::ir_utils::IRCopy(father_exprs[i]));
} else {
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
cross_over_exprs.push_back(ir::ir_utils::IRCopy(mother_exprs[i]));
}
}
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
Expand Down Expand Up @@ -217,7 +217,7 @@ SearchState EvolutionarySearch::Mutate(
const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/auto_schedule/task/task_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(
auto& optimized_funcs = result.functions;
auto& best_cost = result.cost;
// use initial lowered function as default result
optimized_funcs = optim::IRCopy(task_->lowered_funcs);
optimized_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
if (options.num_measure_trials ==
0) { // no need to measure and simply return the best searched
std::vector<MeasureInput> measure_candidates;
Expand Down Expand Up @@ -347,7 +347,7 @@ std::vector<SearchState> TaskOptimizer::SearchOneRound(
CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size())
<< "RuntimeError: Expr size is not equal to LoweredFunc size in "
"TaskOptimizer";
auto init_funcs = optim::IRCopy(task_->lowered_funcs);
auto init_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
std::vector<ir::LoweredFunc> valid_funcs;
for (size_t j = 0; j < best_exprs.size(); ++j) {
auto updated_f =
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/task/task_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
std::lock_guard<std::mutex> guard(registering_mutex);
if (fmap_.count(task_key) == 0) {
InitialTaskInfo* task_info =
new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
new InitialTaskInfo(task_key, ir::ir_utils::IRCopy(module_expr));
__REGISTER__(task_key, task_info);
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
}

Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) {
auto copied = optim::IRCopy(expr);
auto copied = ir::ir_utils::IRCopy(expr);
auto* lowered_func = copied.as_lowered_func();
lowered_func->name = GenDeviceKernelName(lowered_func->name);
return copied;
Expand Down
16 changes: 8 additions & 8 deletions paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ bool CASasSymbol(Expr expr) {

Expr ConvertCinnToCAS(Expr expr) {
VLOG(7) << "Begin ConvertCinnToCAS " << expr;
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand Down Expand Up @@ -1710,7 +1710,7 @@ Expr ConvertCinnToCAS(Expr expr) {
* simplify the condition ensures correctness, though not sufficient.
*/
Expr ReplaceMinToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand All @@ -1727,10 +1727,10 @@ Expr ReplaceMinToConstant(Expr expr) {
auto min_b = op->b();
if (min_a.is_constant() && !min_b.is_constant()) {
CHECK(min_a->type().is_integer());
*expr = optim::IRCopy(min_a);
*expr = ir::ir_utils::IRCopy(min_a);
} else if (min_b.is_constant() && !min_a.is_constant()) {
CHECK(min_b->type().is_integer());
*expr = optim::IRCopy(min_b);
*expr = ir::ir_utils::IRCopy(min_b);
}
}
};
Expand All @@ -1743,7 +1743,7 @@ Expr ReplaceMinToConstant(Expr expr) {
* constant value and 1 inconstant value, return the constant max value.
*/
Expr ReplaceMaxToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand All @@ -1760,10 +1760,10 @@ Expr ReplaceMaxToConstant(Expr expr) {
auto max_b = op->b();
if (max_a.is_constant() && !max_b.is_constant()) {
CHECK(max_a->type().is_integer());
*expr = optim::IRCopy(max_a);
*expr = ir::ir_utils::IRCopy(max_a);
} else if (max_b.is_constant() && !max_a.is_constant()) {
CHECK(max_b->type().is_integer());
*expr = optim::IRCopy(max_b);
*expr = ir::ir_utils::IRCopy(max_b);
}
}
};
Expand All @@ -1773,7 +1773,7 @@ Expr ReplaceMaxToConstant(Expr expr) {

Expr ConvertCasToCinn(Expr expr) {
VLOG(7) << "Begin ConvertCasToCinn : " << expr;
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);

struct Mutator : ir::IRMutator<Expr*> {
void operator()(Expr* expr) { Visit(expr); }
Expand Down
43 changes: 22 additions & 21 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
new_loop_vars.push_back(temp_var);
}
substitute_value = common::AutoSimplify(substitute_value);
Expr new_node = optim::IRCopy(for_node->body);
Expr new_node = ir::ir_utils::IRCopy(for_node->body);
ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
std::vector<Expr> splited_loops;
splited_loops.resize(processed_factors.size());
Expand Down Expand Up @@ -252,7 +252,7 @@ Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
}
substitute_value[0] = fused_expr;

Expr fused_body = optim::IRCopy(for_nodes.back()->body);
Expr fused_body = ir::ir_utils::IRCopy(for_nodes.back()->body);
ReplaceExpr(&fused_body, loop_vars, substitute_value);
optim::Simplify(&fused_body);
Expr fused_extent(1);
Expand Down Expand Up @@ -321,7 +321,7 @@ void ScheduleImpl::MutateForType(const Expr& loop,
<< "loop is not serial, current forloop type is "
<< static_cast<int>(for_node->for_type()) << ", and it cannot become "
<< static_cast<int>(for_type);
auto loop_copy = optim::IRCopy(loop);
auto loop_copy = ir::ir_utils::IRCopy(loop);
auto* new_for_node = loop_copy.As<ir::For>();
CHECK(new_for_node);
new_for_node->set_for_type(for_type);
Expand Down Expand Up @@ -674,7 +674,7 @@ struct RfCreater : public ir::IRMutator<> {
CHECK(root_realize);
auto root_block = root_realize->schedule_block.As<ScheduleBlock>();
CHECK(root_block);
Expr root_loop = optim::IRCopy(root_block->body);
Expr root_loop = ir::ir_utils::IRCopy(root_block->body);
if (auto block = root_loop.As<Block>()) {
CHECK_EQ(block->stmts.size(), 1U)
<< "rfactor root should only have one block stmt";
Expand All @@ -685,13 +685,13 @@ struct RfCreater : public ir::IRMutator<> {
auto rf_for = rf_loop_.As<For>();
CHECK(rf_for);
// create new rfactor forloops
Expr new_rf_forloop = optim::IRCopy(root_loop);
Expr new_rf_forloop = ir::ir_utils::IRCopy(root_loop);
RfMutator rf_mutator(rf_loop_, rf_axis_);
rf_mutator(&new_rf_forloop);
VLOG(3) << "After RfMutator, new rf_forloop is\n" << new_rf_forloop;
auto new_rf_tensor = rf_mutator.GetNewRfTensor();
// create final write-back forloops
Expr final_forloop = optim::IRCopy(root_loop);
Expr final_forloop = ir::ir_utils::IRCopy(root_loop);
FinalMutator final_mutator(rf_loop_, rf_axis_, new_rf_tensor);
final_mutator(&final_forloop);
VLOG(3) << "After FinalMuator, final write-back forloop is\n"
Expand Down Expand Up @@ -721,7 +721,7 @@ struct CacheReadRewriter : public ir::IRMutator<> {
public:
static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
CacheReadRewriter rewriter(root, info);
Expr new_root = optim::IRCopy(root);
Expr new_root = ir::ir_utils::IRCopy(root);
rewriter(&new_root);
return new_root;
}
Expand Down Expand Up @@ -762,7 +762,7 @@ struct CacheWriteRewriter : public ir::IRMutator<> {
public:
static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
CacheWriteRewriter rewriter(root, info);
Expr new_root = optim::IRCopy(root);
Expr new_root = ir::ir_utils::IRCopy(root);
rewriter.mutate_cache_block = true;
rewriter(&info->cache_block);
rewriter.mutate_cache_block = false;
Expand Down Expand Up @@ -1194,7 +1194,7 @@ struct LoopReconstructor : public ir::IRMutator<> {
loop_.As<ir::For>()->device_api,
std::move(loop_body));
}
new_loop_ = optim::IRCopy(loop_);
new_loop_ = ir::ir_utils::IRCopy(loop_);

// Replace the copied Tensor object with the original Tensor object,
// to ensure that the same Tensor in a AST is the same object.
Expand Down Expand Up @@ -1431,9 +1431,9 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
}

Expr result = loops.size() < block_loops.size()
? optim::IRCopy(block_loops[loops.size()])
: optim::IRCopy(this_block);
Expr new_loop = optim::IRCopy(this_loop);
? ir::ir_utils::IRCopy(block_loops[loops.size()])
: ir::ir_utils::IRCopy(this_block);
Expr new_loop = ir::ir_utils::IRCopy(this_loop);

// Get the body of block_loop under the same loops
auto body = block_loops.at(loops.size() - 1).As<ir::For>()->body;
Expand Down Expand Up @@ -1608,7 +1608,7 @@ void ComputeInliner::Visit(const ir::Load* expr, Expr* op) {
Expr ComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
Expr value_copy = ir::ir_utils::IRCopy(inlined_store_.As<Store>()->value);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}
Expand Down Expand Up @@ -1684,7 +1684,7 @@ void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) {
Expr ReverseComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
Expr value_copy = ir::ir_utils::IRCopy(inlined_store_.As<Store>()->value);
return value_copy;
}

Expand All @@ -1699,7 +1699,7 @@ Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) {
idx_sub_expr_.emplace_back(idx_vars_[i]);
}

Expr value_copy = optim::IRCopy(target_store_);
Expr value_copy = ir::ir_utils::IRCopy(target_store_);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}
Expand Down Expand Up @@ -1936,7 +1936,7 @@ void ScheduleImpl::Annotate(const Expr& block,
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
auto copied_block = optim::IRCopy(block);
auto copied_block = ir::ir_utils::IRCopy(block);
auto* schedule_block = copied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
schedule_block->attrs.emplace(key, value);
Expand Down Expand Up @@ -2195,7 +2195,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
}
CHECK(!used_target_loop_vars.empty());
std::vector<Expr> used_target_loops;
auto expr_copy = optim::IRCopy(expr);
auto expr_copy = ir::ir_utils::IRCopy(expr);
for (auto& var : used_target_loop_vars) {
auto find_loop_var = ir::ir_utils::CollectIRNodesWithoutTensor(
expr_copy,
Expand All @@ -2220,7 +2220,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
VLOG(3) << "changed_loop_num is : " << changed_loop_num;
VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size();
if (changed_loop_num >= static_cast<int>(old_iter_values.size())) {
new_loop = optim::IRCopy(block);
new_loop = ir::ir_utils::IRCopy(block);
new_loop.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
} else {
CHECK(old_iter_values[changed_loop_num].as_var());
Expand All @@ -2234,7 +2234,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
},
true);
CHECK_EQ(find_partial_loop.size(), 1U);
new_loop = optim::IRCopy(*find_partial_loop.begin());
new_loop = ir::ir_utils::IRCopy(*find_partial_loop.begin());
auto find_schedule_block = ir::ir_utils::CollectIRNodesWithoutTensor(
new_loop,
[&](const Expr* x) { return x->As<ir::ScheduleBlockRealize>(); },
Expand Down Expand Up @@ -2332,13 +2332,14 @@ IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr,
}

IRSchedule::IRSchedule(const IRSchedule& other)
: impl_(std::make_unique<ScheduleImpl>(optim::IRCopy(other.GetModule()))),
: impl_(std::make_unique<ScheduleImpl>(
ir::ir_utils::IRCopy(other.GetModule()))),
trace_(other.trace_) {
this->InitSeed(other.ForkSeed());
}

IRSchedule& IRSchedule::operator=(const IRSchedule& src) {
impl_ = std::make_unique<ScheduleImpl>(optim::IRCopy(src.GetModule()));
impl_ = std::make_unique<ScheduleImpl>(ir::ir_utils::IRCopy(src.GetModule()));
trace_ = src.trace_;
this->InitSeed(src.ForkSeed());
return *this;
Expand Down
18 changes: 9 additions & 9 deletions paddle/cinn/ir/schedule/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ IterRange GetAccessedRange(const Expr& index,
var_maxs.emplace_back(range.min + range.extent - 1);
}

Expr indice_min = optim::IRCopy(index);
Expr indice_max = optim::IRCopy(index);
Expr indice_min = ir::ir_utils::IRCopy(index);
Expr indice_max = ir::ir_utils::IRCopy(index);
// replace the var by the corresponding iter_value
ReplaceExpr(&indice_min, iter_vars, var_mins);
ReplaceExpr(&indice_max, iter_vars, var_maxs);
Expand Down Expand Up @@ -408,7 +408,7 @@ std::vector<IterRange> CalculateTensorRegions(

std::vector<IterRange> result;
for (int i = 0; i < tensor_indices.size(); ++i) {
Expr binded_index = optim::IRCopy(tensor_indices[i]);
Expr binded_index = ir::ir_utils::IRCopy(tensor_indices[i]);
ReplaceExpr(&binded_index, iter_vars, iter_values);
auto range = GetAccessedRange(binded_index, loop_vars, loop_ranges);

Expand Down Expand Up @@ -656,7 +656,7 @@ Expr ConstructOtherStmtChain(const std::vector<Expr>& stmts,
const std::vector<int> reordered_indices) {
Expr new_loop;
for (int i = reordered_indices.size() - 1; i >= 0; --i) {
Expr temp = optim::IRCopy(loops[reordered_indices[i]]);
Expr temp = ir::ir_utils::IRCopy(loops[reordered_indices[i]]);
CHECK(temp.defined());
CHECK(temp.As<ir::For>());
if (new_loop.defined()) {
Expand Down Expand Up @@ -695,10 +695,10 @@ Expr ConstructNewLoopChain(const std::vector<Expr>& chain,
Expr temp;
if (loop_set.count(loop_in_chain)) {
CHECK_GE(index, 0);
temp = optim::IRCopy(ordered_loops[index]);
temp = ir::ir_utils::IRCopy(ordered_loops[index]);
--index;
} else {
temp = optim::IRCopy(loop_in_chain);
temp = ir::ir_utils::IRCopy(loop_in_chain);
}
CHECK(temp.defined());
CHECK(temp.As<ir::For>());
Expand Down Expand Up @@ -1029,9 +1029,9 @@ std::vector<IterRange> CalculateRequiredRegions(
for (const Expr& req_block : required_blocks) {
CHECK(req_block.As<ir::ScheduleBlockRealize>());
Expr block_body =
optim::IRCopy(req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
ir::ir_utils::IRCopy(req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
auto iter_vars = req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/ir/test/ir_copy_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
#include "paddle/cinn/ir/utils/ir_printer.h"

namespace cinn {
namespace optim {
namespace ir {
namespace ir_utils {

TEST(IrCopy, basic) {
Expr a(1.f);
auto aa = IRCopy(a);
LOG(INFO) << "aa " << aa;
}

} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
Loading