Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR][refactor] Convert loop_var into LoopIndexStmt #953

Merged
merged 22 commits into from
May 14, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,44 @@ class IRVerifier : public BasicStmtVisitor {
TI_ASSERT(stmt->ptr->is<AllocaStmt>());
}

void visit(LoopIndexStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->loop);
if (stmt->loop->is<OffloadedStmt>()) {
TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::struct_for ||
stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::range_for);
} else {
TI_ASSERT(stmt->loop->is<StructForStmt>() ||
stmt->loop->is<RangeForStmt>());
}
}

void visit(RangeForStmt *for_stmt) override {
basic_verify(for_stmt);
TI_ASSERT(for_stmt->loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
for_stmt->loop_var->parent,
[&](Stmt *s) {
if (auto store = s->cast<LocalStoreStmt>())
return store->ptr == for_stmt->loop_var;
else if (auto atomic = s->cast<AtomicOpStmt>()) {
return atomic->dest == for_stmt->loop_var;
} else {
return false;
}
})
.empty(),
"loop_var of {} modified", for_stmt->id);
if (for_stmt->loop_var) {
TI_ASSERT(for_stmt->loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
for_stmt->loop_var->parent,
[&](Stmt *s) {
if (auto store = s->cast<LocalStoreStmt>())
return store->ptr == for_stmt->loop_var;
else if (auto atomic = s->cast<AtomicOpStmt>()) {
return atomic->dest == for_stmt->loop_var;
} else {
return false;
}
}).empty(), "loop_var of {} modified", for_stmt->id);
}
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
basic_verify(for_stmt);
for (auto &loop_var : for_stmt->loop_vars) {
if (!loop_var)
continue;
TI_ASSERT(loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
loop_var->parent,
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
offloaded_loop_vars_llvm[stmt].push_back(loop_var);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(1), loop_var);
stmt->body->accept(this);

Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
offloaded_loop_vars_llvm[stmt].push_back(loop_var);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(1), loop_var);
stmt->body->accept(this);

Expand Down
56 changes: 27 additions & 29 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,22 @@ class KernelCodegen : public IRVisitor {
}

void visit(LoopIndexStmt *stmt) override {
using TaskType = OffloadedStmt::TaskType;
const auto type = current_kernel_attribs_->task_type;
const auto stmt_name = stmt->raw_name();
if (type == TaskType::range_for) {
if (stmt->loop->is<OffloadedStmt>()) {
using TaskType = OffloadedStmt::TaskType;
const auto type = stmt->loop->as<OffloadedStmt>()->task_type;
if (type == TaskType::range_for) {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
stmt->index);
} else {
TI_NOT_IMPLEMENTED;
}
} else if (stmt->loop->is<RangeForStmt>()) {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
stmt->index);
emit("const int {} = {};", stmt_name, stmt->loop->raw_name());
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -445,29 +452,20 @@ class KernelCodegen : public IRVisitor {

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->width() == 1);
auto *loop_var = for_stmt->loop_var;
if (loop_var->ret_type.data_type == DataType::i32) {
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), 1);
emit(" int {} = {}_;", loop_var->raw_name(), loop_var->raw_name());
} else {
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), 1);
emit(" int {} = {}_;", loop_var->raw_name(), loop_var->raw_name());
}
auto loop_var_name = for_stmt->raw_name();
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var_name, for_stmt->begin->raw_name(),
loop_var_name, for_stmt->end->raw_name(),
loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
} else {
TI_ASSERT(!for_stmt->reversed);
const auto type_name = metal_data_type_name(loop_var->element_type());
emit("for ({} {} = {}; {} < {}; {} = {} + ({})1) {{", type_name,
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), type_name);
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var_name, for_stmt->end->raw_name(),
loop_var_name, for_stmt->begin->raw_name(),
loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
}
for_stmt->body->accept(this);
emit("}}");
Expand Down
48 changes: 22 additions & 26 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,38 +602,34 @@ class KernelGen : public IRVisitor {
}

void visit(LoopIndexStmt *stmt) override {
TI_ASSERT(!stmt->is_struct_for);
TI_ASSERT(stmt->index == 0); // TODO: multiple indices
emit("int {} = _itv;", stmt->short_name());
if (stmt->loop->is<OffloadedStmt>()) {
TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::range_for);
emit("int {} = _itv;", stmt->short_name());
} else if (stmt->loop->is<RangeForStmt>()) {
emit("int {} = {};", stmt->short_name(), stmt->loop->short_name());
} else {
TI_NOT_IMPLEMENTED
}
}

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->width() == 1);
auto *loop_var = for_stmt->loop_var;
if (loop_var->ret_type.data_type == DataType::i32) {
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), loop_var->short_name(), 1);
// variable named `loop_var->short_name()` is already allocated by
// alloca
emit(" {} = {}_;", loop_var->short_name(), loop_var->short_name());
} else {
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), loop_var->short_name(), 1);
emit(" {} = {}_;", loop_var->short_name(), loop_var->short_name());
}
auto loop_var_name = for_stmt->short_name();
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var_name, for_stmt->begin->short_name(),
loop_var_name, for_stmt->end->short_name(),
loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
} else {
TI_ASSERT(!for_stmt->reversed);
const auto type_name = opengl_data_type_name(loop_var->element_type());
emit("for ({} {} = {}; {} < {}; {} = {} + 1) {{", type_name,
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), loop_var->short_name());
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var_name, for_stmt->end->short_name(),
loop_var_name, for_stmt->begin->short_name(),
loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
}
for_stmt->body->accept(this);
emit("}}");
Expand Down
23 changes: 14 additions & 9 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,13 +790,16 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
BasicBlock *loop_test =
BasicBlock::Create(*llvm_context, "for_loop_test", func);

auto loop_var = create_entry_block_alloca(DataType::i32);
loop_vars_llvm[for_stmt].push_back(loop_var);

if (!for_stmt->reversed) {
builder->CreateStore(llvm_val[for_stmt->begin],
llvm_val[for_stmt->loop_var]);
builder->CreateStore(llvm_val[for_stmt->begin], loop_var);
} else {
builder->CreateStore(
builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
llvm_val[for_stmt->loop_var]);
loop_var);
}
builder->CreateBr(loop_test);

Expand All @@ -807,12 +810,12 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
if (!for_stmt->reversed) {
cond =
builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
builder->CreateLoad(llvm_val[for_stmt->loop_var]),
builder->CreateLoad(loop_var),
llvm_val[for_stmt->end]);
} else {
cond =
builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
builder->CreateLoad(llvm_val[for_stmt->loop_var]),
builder->CreateLoad(loop_var),
llvm_val[for_stmt->begin]);
}
builder->CreateCondBr(cond, body, after_loop);
Expand All @@ -833,9 +836,9 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->SetInsertPoint(loop_inc);

if (!for_stmt->reversed) {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1));
create_increment(loop_var, tlctx->get_constant(1));
} else {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1));
create_increment(loop_var, tlctx->get_constant(-1));
}
builder->CreateBr(loop_test);
}
Expand Down Expand Up @@ -1408,13 +1411,15 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {

void CodeGenLLVM::visit(LoopIndexStmt *stmt) {
TI_ASSERT(&module->getContext() == tlctx->get_this_thread_context());
if (stmt->is_struct_for) {
if (stmt->loop->is<OffloadedStmt>() &&
stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::struct_for) {
llvm_val[stmt] = builder->CreateLoad(builder->CreateGEP(
current_coordinates, {tlctx->get_constant(0), tlctx->get_constant(0),
tlctx->get_constant(stmt->index)}));
} else {
llvm_val[stmt] = builder->CreateLoad(
offloaded_loop_vars_llvm[current_offloaded_stmt][stmt->index]);
loop_vars_llvm[stmt->loop][stmt->index]);
}
}

Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;

std::unordered_map<OffloadedStmt *, std::vector<llvm::Value *>>
offloaded_loop_vars_llvm;
std::unordered_map<Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is it possible to use const Stmt* here? i don't think the key is modified?


using IRVisitor::visit;
using LLVMModuleBuilder::call;
Expand Down
5 changes: 3 additions & 2 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,11 @@ void Block::replace_statements_in_range(int start,
}

void Block::replace_with(Stmt *old_statement,
std::unique_ptr<Stmt> &&new_statement) {
std::unique_ptr<Stmt> &&new_statement,
bool replace_usages) {
VecStatement vec;
vec.push_back(std::move(new_statement));
replace_with(old_statement, std::move(vec));
replace_with(old_statement, std::move(vec), replace_usages);
}

Stmt *Block::lookup_var(const Identifier &ident) const {
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,8 @@ class Block : public IRNode {
void insert(VecStatement &&stmt, int location = -1);
void replace_statements_in_range(int start, int end, VecStatement &&stmts);
void set_statements(VecStatement &&stmts);
void replace_with(Stmt *old_statement, std::unique_ptr<Stmt> &&new_statement);
void replace_with(Stmt *old_statement, std::unique_ptr<Stmt> &&new_statement,
bool replace_usages = true);
void insert_before(Stmt *old_statement, VecStatement &&new_statements);
void replace_with(Stmt *old_statement,
VecStatement &&new_statements,
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,19 @@ class OffloadedStmt : public Stmt {

class LoopIndexStmt : public Stmt {
public:
Stmt *loop;
int index;
bool is_struct_for;

LoopIndexStmt(int index, bool is_struct_for)
: index(index), is_struct_for(is_struct_for) {
LoopIndexStmt(Stmt *loop, int index)
: loop(loop), index(index) {
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, index, is_struct_for);
TI_STMT_DEF_FIELDS(ret_type, loop, index);
DEFINE_ACCEPT
};

Expand Down
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void full_simplify(IRNode *root,
Kernel *kernel = nullptr);
void print(IRNode *root, std::string *output = nullptr);
void lower(IRNode *root);
void convert_into_loop_index(IRNode *root);
void typecheck(IRNode *root, Kernel *kernel = nullptr);
void loop_vectorize(IRNode *root);
void slp_vectorize(IRNode *root);
Expand Down
Loading