[IR][refactor] Convert loop_var into LoopIndexStmt (#953)
* [skip ci] Convert loop_var into LoopIndexStmt (stage 1)

* [skip ci] Fix irpass::offload

* Should work on CPU now

* Move convert_into_loop_index to right after lower_ast

* Merge master

* Fix tests

* Fix CUDA compilation

* [skip ci] Remove comments

* Metal backend

* [skip ci] Metal backend (continued)

* opengl backend

* [skip ci] enforce code format

* Revert "[skip ci] enforce code format"

This reverts commit f284688.

* [skip ci] Remove LoopIndexStmt::is_struct_for

* Finally move the pass into `lower`

* [skip ci] add comment

* Fix opengl backend

* Use const Stmt*

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
xumingkuan and taichi-gardener authored May 14, 2020
1 parent 0cef342 commit c6086e0
Showing 20 changed files with 243 additions and 161 deletions.
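At a glance, the refactor retires per-loop `loop_var` allocas in favor of a dedicated statement that points back at its owning loop. Below is a minimal sketch of that statement, reconstructed from the diffs that follow; the real class in taichi/ir/ir.h carries the usual Stmt machinery (RTTI, cloning, operand registration) that is simplified away here.

struct Stmt {};  // stand-in for Taichi's real Stmt base class

class LoopIndexStmt : public Stmt {
 public:
  Stmt *loop;  // owning loop: an OffloadedStmt, RangeForStmt, or StructForStmt
  int index;   // which dimension of the loop to read (always 0 for range-fors)

  // Sketch only; the actual constructor and bookkeeping are simplified.
  LoopIndexStmt(Stmt *loop, int index) : loop(loop), index(index) {
  }
};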
6 changes: 3 additions & 3 deletions docs/hello.rst
@@ -55,9 +55,9 @@ Let's dive into this simple Taichi program.

import taichi as ti
-------------------
Taichi is a domain-specific language (DSL) embedded in Python. To make Taichi as easy to use as a Python package,
we have done heavy engineering with this goal in mind - letting every Python programmer write Taichi code with
minimal learning effort. You can even use your favorite Python package management system, Python IDEs and other
Python packages in conjunction with Taichi.

Portability
46 changes: 32 additions & 14 deletions taichi/analysis/verify.cpp
@@ -73,28 +73,46 @@ class IRVerifier : public BasicStmtVisitor {
TI_ASSERT(stmt->ptr->is<AllocaStmt>());
}

void visit(LoopIndexStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->loop);
if (stmt->loop->is<OffloadedStmt>()) {
TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::struct_for ||
stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::range_for);
} else {
TI_ASSERT(stmt->loop->is<StructForStmt>() ||
stmt->loop->is<RangeForStmt>());
}
}

void visit(RangeForStmt *for_stmt) override {
basic_verify(for_stmt);
TI_ASSERT(for_stmt->loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
for_stmt->loop_var->parent,
[&](Stmt *s) {
if (auto store = s->cast<LocalStoreStmt>())
return store->ptr == for_stmt->loop_var;
else if (auto atomic = s->cast<AtomicOpStmt>()) {
return atomic->dest == for_stmt->loop_var;
} else {
return false;
}
})
.empty(),
"loop_var of {} modified", for_stmt->id);
if (for_stmt->loop_var) {
TI_ASSERT(for_stmt->loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
for_stmt->loop_var->parent,
[&](Stmt *s) {
if (auto store = s->cast<LocalStoreStmt>())
return store->ptr == for_stmt->loop_var;
else if (auto atomic = s->cast<AtomicOpStmt>()) {
return atomic->dest == for_stmt->loop_var;
} else {
return false;
}
})
.empty(),
"loop_var of {} modified", for_stmt->id);
}
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
basic_verify(for_stmt);
for (auto &loop_var : for_stmt->loop_vars) {
if (!loop_var)
continue;
TI_ASSERT(loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
loop_var->parent,
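The loop_var checks above lean on a single pattern: gather every statement in a subtree that matches a predicate, then assert the result is empty. A self-contained sketch of that idea, with Node and gather as illustrative stand-ins rather than Taichi's real API:

#include <functional>
#include <vector>

struct Node {
  std::vector<Node *> children;
};

// Collect every node in the subtree rooted at `root` that satisfies `pred`.
std::vector<Node *> gather(Node *root,
                           const std::function<bool(Node *)> &pred) {
  std::vector<Node *> hits;
  std::function<void(Node *)> walk = [&](Node *n) {
    if (pred(n))
      hits.push_back(n);
    for (auto *child : n->children)
      walk(child);
  };
  walk(root);
  return hits;
}

The verifier's predicate matches LocalStoreStmt and AtomicOpStmt writes whose destination is the loop variable; an empty result proves the variable stays read-only inside the loop body.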
2 changes: 1 addition & 1 deletion taichi/backends/cpu/codegen_cpu.cpp
@@ -32,7 +32,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
offloaded_loop_vars_llvm[stmt].push_back(loop_var);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(1), loop_var);
stmt->body->accept(this);

2 changes: 1 addition & 1 deletion taichi/backends/cuda/codegen_cuda.cpp
@@ -346,7 +346,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
offloaded_loop_vars_llvm[stmt].push_back(loop_var);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(1), loop_var);
stmt->body->accept(this);

54 changes: 25 additions & 29 deletions taichi/backends/metal/codegen_metal.cpp
@@ -309,15 +309,22 @@ class KernelCodegen : public IRVisitor {
}

void visit(LoopIndexStmt *stmt) override {
using TaskType = OffloadedStmt::TaskType;
const auto type = current_kernel_attribs_->task_type;
const auto stmt_name = stmt->raw_name();
if (type == TaskType::range_for) {
if (stmt->loop->is<OffloadedStmt>()) {
using TaskType = OffloadedStmt::TaskType;
const auto type = stmt->loop->as<OffloadedStmt>()->task_type;
if (type == TaskType::range_for) {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
stmt->index);
} else {
TI_NOT_IMPLEMENTED;
}
} else if (stmt->loop->is<RangeForStmt>()) {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
stmt->index);
emit("const int {} = {};", stmt_name, stmt->loop->raw_name());
} else {
TI_NOT_IMPLEMENTED;
}
@@ -445,29 +452,18 @@

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->width() == 1);
auto *loop_var = for_stmt->loop_var;
if (loop_var->ret_type.data_type == DataType::i32) {
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), 1);
emit(" int {} = {}_;", loop_var->raw_name(), loop_var->raw_name());
} else {
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), 1);
emit(" int {} = {}_;", loop_var->raw_name(), loop_var->raw_name());
}
auto loop_var_name = for_stmt->raw_name();
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{", loop_var_name,
for_stmt->begin->raw_name(), loop_var_name,
for_stmt->end->raw_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
} else {
TI_ASSERT(!for_stmt->reversed);
const auto type_name = metal_data_type_name(loop_var->element_type());
emit("for ({} {} = {}; {} < {}; {} = {} + ({})1) {{", type_name,
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), type_name);
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var_name, for_stmt->end->raw_name(), loop_var_name,
for_stmt->begin->raw_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
}
for_stmt->body->accept(this);
emit("}}");
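With the Metal range-for now always emitting a plain int counter, the generated source for a reversed loop comes out roughly as follows; the variable name tmp12 and the bounds b and e are invented for illustration:

for (int tmp12_ = e - 1; tmp12_ >= b; tmp12_ = tmp12_ - 1) {
  int tmp12 = tmp12_;  // immutable per-iteration copy read via LoopIndexStmt
  // ... loop body ...
}

The trailing-underscore counter keeps the mutable induction variable separate from the per-iteration value the body reads.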
5 changes: 3 additions & 2 deletions taichi/backends/metal/shaders/helpers.metal.h
@@ -33,8 +33,9 @@ STR(

inline int ifloordiv(int lhs, int rhs) {
const int intm = (lhs / rhs);
return (((lhs < 0) != (rhs < 0) && lhs && (rhs * intm != lhs))
? (intm - 1)
: intm);
}

int32_t pow_i32(int32_t x, int32_t n) {
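The hunk above only re-wraps ifloordiv, but its correction branch deserves a worked check. The same logic compiles as host C++ and rounds toward negative infinity where plain / truncates toward zero:

#include <cassert>

// Same logic as the Metal helper above, restated for a host-side check.
int ifloordiv(int lhs, int rhs) {
  const int intm = lhs / rhs;  // C-style division truncates toward zero
  return ((lhs < 0) != (rhs < 0) && lhs && (rhs * intm != lhs)) ? intm - 1
                                                                : intm;
}

int main() {
  assert(ifloordiv(-7, 2) == -4);  // floor(-3.5); truncation would give -3
  assert(ifloordiv(7, 2) == 3);    // same signs: plain truncation is correct
  assert(ifloordiv(-6, 3) == -2);  // divides evenly: no correction needed
  return 0;
}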
71 changes: 33 additions & 38 deletions taichi/backends/opengl/codegen_opengl.cpp
@@ -414,14 +414,13 @@ class KernelGen : public IRVisitor {
emit("{} {} = atan({}, {});", dt_name, bin_name, lhs_name, rhs_name);
}
return;
} else if (bin->op_type == BinaryOpType::pow &&
is_integral(bin->rhs->element_type())) {
// The GLSL `pow` is not precise for `int`, e.g. `pow(5, 3)` yields 124,
// so we have to use a hand-written helper to make it exact. Discussion:
// https://github.com/taichi-dev/taichi/pull/943#issuecomment-626354902
emit("{} {} = {}(fast_pow_{}({}, {}));", dt_name, bin_name, dt_name,
data_type_short_name(bin->lhs->element_type()), lhs_name, rhs_name);
used.fast_pow = true;
return;
}
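The fast_pow_* helper referenced above is emitted elsewhere in the backend when used.fast_pow is set and is not part of this diff. A plausible shape for it is exponentiation by squaring, which stays in integer arithmetic, so fast_pow_i32(5, 3) is exactly 125 where float-based pow can round down to 124. A hypothetical sketch:

// Hypothetical helper, assuming n >= 0; not the backend's actual code.
int fast_pow_i32(int x, int n) {
  int result = 1;
  while (n > 0) {
    if (n & 1)
      result *= x;  // fold in the current bit of the exponent
    x *= x;         // square the base for the next bit
    n >>= 1;
  }
  return result;
}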
@@ -602,38 +601,32 @@
}

void visit(LoopIndexStmt *stmt) override {
TI_ASSERT(!stmt->is_struct_for);
TI_ASSERT(stmt->index == 0); // TODO: multiple indices
emit("int {} = _itv;", stmt->short_name());
if (stmt->loop->is<OffloadedStmt>()) {
TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::range_for);
emit("int {} = _itv;", stmt->short_name());
} else if (stmt->loop->is<RangeForStmt>()) {
emit("int {} = {};", stmt->short_name(), stmt->loop->short_name());
} else {
TI_NOT_IMPLEMENTED
}
}

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->width() == 1);
auto *loop_var = for_stmt->loop_var;
if (loop_var->ret_type.data_type == DataType::i32) {
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), loop_var->short_name(), 1);
// variable named `loop_var->short_name()` is already allocated by
// alloca
emit(" {} = {}_;", loop_var->short_name(), loop_var->short_name());
} else {
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), loop_var->short_name(), 1);
emit(" {} = {}_;", loop_var->short_name(), loop_var->short_name());
}
auto loop_var_name = for_stmt->short_name();
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{", loop_var_name,
for_stmt->begin->short_name(), loop_var_name,
for_stmt->end->short_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
} else {
TI_ASSERT(!for_stmt->reversed);
const auto type_name = opengl_data_type_name(loop_var->element_type());
emit("for ({} {} = {}; {} < {}; {} = {} + 1) {{", type_name,
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), loop_var->short_name());
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var_name, for_stmt->end->short_name(), loop_var_name,
for_stmt->begin->short_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
}
for_stmt->body->accept(this);
emit("}}");
@@ -730,12 +723,14 @@ void OpenglCodeGen::lower() {
auto ir = kernel_->ir;
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
auto res =
irpass::compile_to_offloads(ir, config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir,
/*lower_global_access*/ true);
global_tmps_buffer_size_ = res.total_size;
TI_TRACE("[glsl] Global temporary buffer size {} B",
global_tmps_buffer_size_);
#ifdef _GLSL_DEBUG
irpass::print(ir);
#endif
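Putting the two visitors together, a serial range-for nested inside an offloaded range-for would generate GLSL along these lines; every name below is invented:

int s1 = _itv;                      // LoopIndexStmt of the offloaded loop
for (int s2_ = b; s2_ < e; s2_ = s2_ + 1) {
  int s2 = s2_;                     // per-iteration copy of the counter
  int s3 = s2;                      // LoopIndexStmt of the inner range-for
  // ... body ...
}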
35 changes: 19 additions & 16 deletions taichi/codegen/codegen_llvm.cpp
@@ -790,13 +790,16 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
BasicBlock *loop_test =
BasicBlock::Create(*llvm_context, "for_loop_test", func);

auto loop_var = create_entry_block_alloca(DataType::i32);
loop_vars_llvm[for_stmt].push_back(loop_var);

if (!for_stmt->reversed) {
builder->CreateStore(llvm_val[for_stmt->begin],
llvm_val[for_stmt->loop_var]);
builder->CreateStore(llvm_val[for_stmt->begin], loop_var);
} else {
builder->CreateStore(
builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
llvm_val[for_stmt->loop_var]);
loop_var);
}
builder->CreateBr(loop_test);

@@ -805,15 +808,13 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->SetInsertPoint(loop_test);
llvm::Value *cond;
if (!for_stmt->reversed) {
cond =
builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
builder->CreateLoad(llvm_val[for_stmt->loop_var]),
llvm_val[for_stmt->end]);
cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
builder->CreateLoad(loop_var),
llvm_val[for_stmt->end]);
} else {
cond =
builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
builder->CreateLoad(llvm_val[for_stmt->loop_var]),
llvm_val[for_stmt->begin]);
cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
builder->CreateLoad(loop_var),
llvm_val[for_stmt->begin]);
}
builder->CreateCondBr(cond, body, after_loop);
}
@@ -833,9 +834,9 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->SetInsertPoint(loop_inc);

if (!for_stmt->reversed) {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1));
create_increment(loop_var, tlctx->get_constant(1));
} else {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1));
create_increment(loop_var, tlctx->get_constant(-1));
}
builder->CreateBr(loop_test);
}
@@ -1408,13 +1409,15 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {

void CodeGenLLVM::visit(LoopIndexStmt *stmt) {
TI_ASSERT(&module->getContext() == tlctx->get_this_thread_context());
if (stmt->is_struct_for) {
if (stmt->loop->is<OffloadedStmt>() &&
stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::struct_for) {
llvm_val[stmt] = builder->CreateLoad(builder->CreateGEP(
current_coordinates, {tlctx->get_constant(0), tlctx->get_constant(0),
tlctx->get_constant(stmt->index)}));
} else {
llvm_val[stmt] = builder->CreateLoad(
offloaded_loop_vars_llvm[current_offloaded_stmt][stmt->index]);
llvm_val[stmt] =
builder->CreateLoad(loop_vars_llvm[stmt->loop][stmt->index]);
}
}

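With the alloca now owned by codegen instead of borrowed from for_stmt->loop_var, the structure create_naive_range_for builds corresponds to this host-language loop; a sketch only, since the real code emits the same shape as LLVM basic blocks (for_loop_test, the body, and an increment edge):

#include <functional>

// `body` stands in for visiting for_stmt->body, where LoopIndexStmt codegen
// loads the counter slot.
void naive_range_for(int begin, int end, bool reversed,
                     const std::function<void(int)> &body) {
  int loop_var = reversed ? end - 1 : begin;  // the entry-block alloca
  while (reversed ? loop_var >= begin : loop_var < end) {
    body(loop_var);
    loop_var += reversed ? -1 : 1;
  }
}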
3 changes: 1 addition & 2 deletions taichi/codegen/codegen_llvm.h
@@ -72,8 +72,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;

std::unordered_map<OffloadedStmt *, std::vector<llvm::Value *>>
offloaded_loop_vars_llvm;
std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

using IRVisitor::visit;
using LLVMModuleBuilder::call;
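Keying the map by const Stmt * instead of OffloadedStmt * is what lets non-offloaded loops register index allocas too. Usage boils down to the following sketch, with simplified types and Value standing in for llvm::Value:

#include <unordered_map>
#include <vector>

struct Stmt {};
struct Value {};  // stand-in for llvm::Value

std::unordered_map<const Stmt *, std::vector<Value *>> loop_vars_llvm;

// Loop codegen registers one alloca per loop dimension...
void register_index(const Stmt *loop, Value *index_alloca) {
  loop_vars_llvm[loop].push_back(index_alloca);
}

// ...and LoopIndexStmt codegen loads it back by (loop, index).
Value *lookup_index(const Stmt *loop, int index) {
  return loop_vars_llvm.at(loop)[index];
}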
5 changes: 3 additions & 2 deletions taichi/ir/ir.cpp
@@ -655,10 +655,11 @@ void Block::replace_statements_in_range(int start,
}

void Block::replace_with(Stmt *old_statement,
std::unique_ptr<Stmt> &&new_statement) {
std::unique_ptr<Stmt> &&new_statement,
bool replace_usages) {
VecStatement vec;
vec.push_back(std::move(new_statement));
replace_with(old_statement, std::move(vec));
replace_with(old_statement, std::move(vec), replace_usages);
}

Stmt *Block::lookup_var(const Identifier &ident) const {
4 changes: 3 additions & 1 deletion taichi/ir/ir.h
@@ -835,7 +835,9 @@ class Block : public IRNode {
void insert(VecStatement &&stmt, int location = -1);
void replace_statements_in_range(int start, int end, VecStatement &&stmts);
void set_statements(VecStatement &&stmts);
void replace_with(Stmt *old_statement, std::unique_ptr<Stmt> &&new_statement);
void replace_with(Stmt *old_statement,
std::unique_ptr<Stmt> &&new_statement,
bool replace_usages = true);
void insert_before(Stmt *old_statement, VecStatement &&new_statements);
void replace_with(Stmt *old_statement,
VecStatement &&new_statements,
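A toy model of what the new replace_usages flag controls, with data structures simplified well past Taichi's real IR: when the flag is true (the default, matching the old behavior), every operand reference to the old statement is redirected to the replacement before the splice; a caller that rewires operands itself can now opt out.

#include <memory>
#include <vector>

struct Stmt {
  std::vector<Stmt *> operands;  // statements this one reads from
};

struct Block {
  std::vector<std::unique_ptr<Stmt>> stmts;

  void replace_with(Stmt *old_stmt, std::unique_ptr<Stmt> new_stmt,
                    bool replace_usages = true) {
    Stmt *replacement = new_stmt.get();
    if (replace_usages)
      for (auto &s : stmts)
        for (auto *&op : s->operands)
          if (op == old_stmt)
            op = replacement;  // redirect users to the replacement
    for (auto &s : stmts)
      if (s.get() == old_stmt)
        s = std::move(new_stmt);  // splice in place, freeing the old statement
  }
};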