From e71f2eb268c303c030d09cafbb09d96296eaae5d Mon Sep 17 00:00:00 2001 From: squarefk Date: Sat, 31 Jul 2021 10:29:35 +0800 Subject: [PATCH] [IR] Init GlobalTensorElementExpression and PtrOffsetStmt (#2543) * init GlobalTensorElementExpression and GlobalTensorElementExpr * fix compile error for global load * fix cfg optimization * fix offset=4 for i32 and f32 * support matrix indices access, AOS SOA, and arbitrary date type * Auto Format * do better serialize() for GlobalTensorElementExpression and GlobalTensorElementStmt * simplify make * fix ti test * fix test and add load_if_ptr for dynamic indices * fix test * fix more tests * fix more tests * fix test * fix test * update test * merge condition in load_if_ptr * replace is_AOS with is_aos * remove is_bit_vectorized * add default value for GlobalTensorElementStmt * add default value for GlobalTensorElementExpression * update computation for field size * Update taichi/ir/expr.cpp Co-authored-by: xumingkuan * comment is_global_ptr() * nit * Auto Format * add extension and fix tests * Auto Format * do actual type check * rename GlobalTensorElementStmt into ShiftGlobalPtrStmt * Auto Format * nit address_offset * rename ShiftGlobalPtrStmt into PtrOffsetStmt * Auto Format Co-authored-by: Taichi Gardener Co-authored-by: xumingkuan --- python/taichi/lang/impl.py | 7 ++++ python/taichi/lang/matrix.py | 10 +++++- taichi/codegen/codegen_llvm.cpp | 11 ++++++ taichi/codegen/codegen_llvm.h | 2 ++ taichi/inc/extensions.inc.h | 2 ++ taichi/inc/statements.inc.h | 1 + taichi/ir/control_flow_graph.cpp | 2 +- taichi/ir/expr.cpp | 6 ++-- taichi/ir/frontend_ir.cpp | 57 ++++++++++++++++++++++++++++++ taichi/ir/frontend_ir.h | 33 +++++++++++++++++ taichi/ir/statements.h | 24 +++++++++++++ taichi/program/async_utils.cpp | 9 +++++ taichi/program/extension.cpp | 8 +++-- taichi/python/export_lang.cpp | 8 +++++ taichi/transforms/flag_access.cpp | 6 ++++ taichi/transforms/ir_printer.cpp | 7 ++++ taichi/transforms/lower_access.cpp | 40 +++++++++++++-------- taichi/transforms/lower_ast.cpp | 4 +++ taichi/transforms/type_check.cpp | 5 +++ tests/python/test_matrix.py | 26 +++++++++++--- 20 files changed, 242 insertions(+), 26 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 58a3c6db83397..dcb9b6566fb25 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -154,6 +154,13 @@ def subscript(value, *indices): return value[indices] +@taichi_scope +def subscript_with_offset(var, indices, cols, is_aos): + return Expr( + _ti_core.subscript_with_offset(var.ptr, make_expr_group(*indices), + cols, is_aos)) + + @taichi_scope def chain_compare(comparators, ops): _taichi_skip_traceback = 1 diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 21c88d75f57ac..6f3339b9215cc 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -297,7 +297,15 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - return self(i, j) + # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). + if isinstance(self.entries[0], + ti.Expr) and self.entries[0].ptr.is_global_ptr( + ) and ti.is_extension_supported( + ti.cfg.arch, ti.extension.dynamic_index): + return ti.subscript_with_offset(self.entries[0], (i, j), + self.m, True) + else: + return self(i, j) @property def x(self): diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 3e1bd8c059685..cd31437b1d392 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1378,6 +1378,17 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } } +void CodeGenLLVM::visit(PtrOffsetStmt *stmt) { + auto origin_address = builder->CreatePtrToInt( + llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); + auto address_offset = builder->CreateSExt( + llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); + auto target_address = builder->CreateAdd(origin_address, address_offset); + auto dt = stmt->ret_type.ptr_removed(); + llvm_val[stmt] = builder->CreateIntToPtr( + target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); +} + void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { TI_ASSERT(stmt->width() == 1); diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 255f677bcb5e2..d2b99f6bde66e 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -212,6 +212,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(GlobalPtrStmt *stmt) override; + void visit(PtrOffsetStmt *stmt) override; + void store_custom_int(llvm::Value *bit_ptr, CustomIntType *cit, llvm::Value *value, diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index 5c0ba10343346..dae6b432b30a2 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -9,3 +9,5 @@ PER_EXTENSION(bls) // Block-local storage PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels PER_EXTENSION(extfunc) // Invoke external functions or backend source PER_EXTENSION(packed) // Shape will not be padded to a power of two +PER_EXTENSION( + dynamic_index) // Dynamic index support for both global and local tensors diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index 6cd0a449bafb6..a631fe23884c7 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -29,6 +29,7 @@ PER_STATEMENT(ReturnStmt) PER_STATEMENT(ArgLoadStmt) PER_STATEMENT(ExternalPtrStmt) +PER_STATEMENT(PtrOffsetStmt) PER_STATEMENT(ConstStmt) PER_STATEMENT(AllocaStmt) PER_STATEMENT(UnaryOpStmt) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 8e6f33e6943d7..609daebacb50d 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -626,7 +626,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { auto stmt = nodes[i]->block->statements[j].get(); if (stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is()) { + stmt->is() || stmt->is()) { // TODO: unify them // A global pointer that may contain some data before this kernel. nodes[start_node]->reach_gen.insert(stmt); diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 4ff574a749c4b..5eeefcd6eca41 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -161,7 +161,8 @@ void Expr::operator/=(const Expr &o) { } Expr load_if_ptr(const Expr &ptr) { - if (ptr.is()) { + if (ptr.is() || + ptr.is()) { return load(ptr); } else if (ptr.is()) { TI_ASSERT(ptr.cast()->snode->num_active_indices == @@ -172,7 +173,8 @@ Expr load_if_ptr(const Expr &ptr) { } Expr load(const Expr &ptr) { - TI_ASSERT(ptr.is()); + TI_ASSERT(ptr.is() || + ptr.is()); return Expr::make(ptr); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 833b699693f6f..5cc1f1e6ae8f8 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -219,6 +219,59 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { + TI_ASSERT(var.is()) + var->flatten(ctx); + Stmt *var_stmt = ctx->back_stmt(); + SNode *snode = var.cast() + ->var.cast() + ->snode; + // Compute exact offset + // Type A[i, j][x, y] + // ^^^^ + TI_ASSERT(1 <= indices.size() && indices.size() <= 2) + if (indices.size() == 1) { + indices[0].set(load_if_ptr(indices[0])); + indices[0]->flatten(ctx); + } else { + indices[0].set(load_if_ptr(indices[0])); + indices[0]->flatten(ctx); + Stmt *i_stmt = ctx->back_stmt(); + Stmt *cols_stmt = + ctx->push_back(Stmt::make(TypedConstant(cols))); + Stmt *i_mul_cols_stmt = ctx->push_back( + Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); + indices[1].set(load_if_ptr(indices[1])); + indices[1]->flatten(ctx); + Stmt *j_stmt = ctx->back_stmt(); + ctx->push_back( + Stmt::make(BinaryOpType::add, i_mul_cols_stmt, j_stmt)); + } + // Type A[i, j][x, y] + // ^ ^ + if (!is_aos) { + TI_ASSERT(snode->is_path_all_dense) + int size = 1; + for (auto *s = snode; s != nullptr; s = s->parent) + size *= (int)s->max_num_elements(); + Stmt *offset_stmt = ctx->back_stmt(); + Stmt *field_size_stmt = + ctx->push_back(Stmt::make(TypedConstant(size))); + ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, + field_size_stmt)); + } + // Type A[i, j][x, y] + // ^^^^ + Stmt *offset_stmt = ctx->back_stmt(); + Stmt *dt_size_stmt = ctx->push_back( + Stmt::make(TypedConstant(data_type_size(snode->dt)))); + ctx->push_back( + Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); + + ctx->push_back(std::make_unique(var_stmt, ctx->back_stmt())); + stmt = ctx->back_stmt(); +} + void RangeAssumptionExpression::flatten(FlattenContext *ctx) { input->flatten(ctx); base->flatten(ctx); @@ -297,6 +350,10 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { // emit local store stmt auto alloca = ctx->current_block->lookup_var(dest.cast()->id); ctx->push_back(op_type, alloca, expr->stmt); + } else if (dest.is()) { + auto global_ptr = dest.cast(); + global_ptr->flatten(ctx); + ctx->push_back(op_type, ctx->back_stmt(), expr->stmt); } else { // global variable TI_ASSERT(dest.is()); auto global_ptr = dest.cast(); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a02f3bdc21262..2b219ddfdeb05 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -430,6 +430,39 @@ class GlobalPtrExpression : public Expression { } }; +class GlobalTensorElementExpression : public Expression { + public: + Expr var; + ExprGroup indices; + int cols{0}; + bool is_aos{false}; + + GlobalTensorElementExpression(const Expr &var, + const ExprGroup &indices, + int cols, + bool is_aos) + : var(var), indices(indices), cols(cols), is_aos(is_aos) { + } + + std::string serialize() override { + std::string s = fmt::format("{}[", var.serialize()); + for (int i = 0; i < (int)indices.size(); i++) { + s += indices.exprs[i]->serialize(); + if (i + 1 < (int)indices.size()) + s += ", "; + } + s += "]"; + s += " (col=" + std::to_string(cols) + (is_aos ? ", AOS)" : ", SOA)"); + return s; + } + + void flatten(FlattenContext *ctx) override; + + bool is_lvalue() const override { + return true; + } +}; + class EvalExpression : public Expression { public: Stmt *stmt_ptr; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 408d28e0df400..31cbaf14fd937 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -303,6 +303,30 @@ class GlobalPtrStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * An accessing tensor element operation. + */ +class PtrOffsetStmt : public Stmt { + public: + Stmt *origin{nullptr}; + Stmt *offset{nullptr}; + + PtrOffsetStmt(Stmt *origin, Stmt *offset) : origin(origin), offset(offset) { + element_type() = origin->cast()->ret_type; + TI_STMT_REG_FIELDS; + } + + bool has_global_side_effect() const override { + // After access lowered, activate info will be recorded in SNodeLookupStmt's + // activate for AOS sparse data structure. We don't support SOA sparse data + // structure for now. + return false; + } + + TI_STMT_DEF_FIELDS(ret_type, origin, offset); + TI_DEFINE_ACCEPT_AND_CLONE +}; + /** * An operation to a SNode (not necessarily a leaf SNode). */ diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 0ec4aefe33b5b..4d1b6205396b5 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -165,6 +165,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { ir_bank->get_async_state(snode, AsyncState::Type::value)); } } + if (auto global_tensor_element = + global_store->dest->cast()) { + if (auto dest = global_tensor_element->origin->cast()) { + for (auto &snode : dest->snodes.data) { + meta.output_states.insert( + ir_bank->get_async_state(snode, AsyncState::Type::value)); + } + } + } } if (auto global_atomic = stmt->cast()) { if (auto dest = global_atomic->dest->cast()) { diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index 7dbae325f3067..a33b098d3a701 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -11,15 +11,17 @@ bool is_extension_supported(Arch arch, Extension ext) { {Arch::x64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::assertion, Extension::extfunc, Extension::packed}}, + Extension::assertion, Extension::extfunc, Extension::packed, + Extension::dynamic_index}}, {Arch::arm64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::assertion, Extension::packed}}, + Extension::assertion, Extension::packed, Extension::dynamic_index}}, {Arch::cuda, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::bls, Extension::assertion, Extension::packed}}, + Extension::bls, Extension::assertion, Extension::packed, + Extension::dynamic_index}}, {Arch::metal, {Extension::adstack, Extension::assertion, Extension::quant_basic, Extension::async_mode, Extension::sparse}}, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 918aab9c638ec..360142acd66a1 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -371,6 +371,8 @@ void export_lang(py::module &m) { [](Expr *expr) { return expr->is(); }) .def("is_external_var", [](Expr *expr) { return expr->is(); }) + .def("is_global_ptr", + [](Expr *expr) { return expr->is(); }) .def("is_primal", [](Expr *expr) { return expr->cast()->is_primal; @@ -698,6 +700,12 @@ void export_lang(py::module &m) { return expr[expr_group]; }); + m.def("subscript_with_offset", + [](const Expr &var, const ExprGroup &indices, int cols, bool is_aos) { + return Expr::make(var, indices, cols, + is_aos); + }); + m.def("subscript", [](SNode *snode, const ExprGroup &indices) { return Expr::make(snode, indices.loaded()); }); diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 02831b376b4dd..d4cea6eac92f5 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -54,6 +54,12 @@ class FlagAccess : public IRVisitor { if (stmt->dest->is()) { stmt->dest->as()->activate = true; } + if (stmt->dest->is()) { + if (stmt->dest->as()->origin->is()) { + stmt->dest->as()->origin->as()->activate = + true; + } + } } void visit(AtomicOpStmt *stmt) { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 6c68cd3e8ab5f..4e44a9bedf5aa 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -383,6 +383,13 @@ class IRPrinter : public IRVisitor { print_raw(s); } + void visit(PtrOffsetStmt *stmt) override { + std::string s = + fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(), + stmt->name(), stmt->origin->name(), stmt->offset->name()); + print_raw(s); + } + void visit(ArgLoadStmt *stmt) override { print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id); } diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index c2c44562a9628..b83b5bd458de5 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -148,23 +148,35 @@ class LowerAccess : public IRVisitor { } void visit(GlobalLoadStmt *stmt) override { - if (stmt->src->is()) { - // No need to activate for all read accesses - auto lowered = lower_vector_ptr(stmt->src->as(), false); - stmt->src = lowered.back().get(); - modifier.insert_before(stmt, std::move(lowered)); - } + if (!stmt->src->is()) + return; + // No need to activate for all read accesses + auto lowered = lower_vector_ptr(stmt->src->as(), false); + stmt->src = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); + } + + // TODO: this seems to be redundant + void visit(PtrOffsetStmt *stmt) override { + if (!stmt->origin->is()) + return; + auto ptr = stmt->origin->as(); + // If ptr already has activate = false, no need to activate all the + // generated micro-access ops. Otherwise, activate the nodes. + auto lowered = lower_vector_ptr(ptr, ptr->activate); + stmt->origin = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); } void visit(GlobalStoreStmt *stmt) override { - if (stmt->dest->is()) { - auto ptr = stmt->dest->as(); - // If ptr already has activate = false, no need to activate all the - // generated micro-access ops. Otherwise, activate the nodes. - auto lowered = lower_vector_ptr(ptr, ptr->activate); - stmt->dest = lowered.back().get(); - modifier.insert_before(stmt, std::move(lowered)); - } + if (!stmt->dest->is()) + return; + auto ptr = stmt->dest->as(); + // If ptr already has activate = false, no need to activate all the + // generated micro-access ops. Otherwise, activate the nodes. + auto lowered = lower_vector_ptr(ptr, ptr->activate); + stmt->dest = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); } void visit(SNodeOpStmt *stmt) override { diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 97edef1aa09c2..49b959c5b2803 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -361,6 +361,10 @@ class LowerAST : public IRVisitor { fctx.push_back( assign->parent->lookup_var(assign->lhs.cast()->id), expr->stmt); + } else if (assign->lhs.is()) { + auto global_ptr = assign->lhs.cast(); + global_ptr->flatten(&fctx); + fctx.push_back(fctx.back_stmt(), expr->stmt); } else { // global variable TI_ASSERT(assign->lhs.is()); auto global_ptr = assign->lhs.cast(); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 9672996659ef6..943787b584464 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -158,6 +158,11 @@ class TypeCheck : public IRVisitor { } } + void visit(PtrOffsetStmt *stmt) override { + TI_ASSERT(stmt->offset->ret_type->is_primitive(PrimitiveTypeID::i32)); + stmt->ret_type.set_is_pointer(true); + } + void visit(GlobalStoreStmt *stmt) override { auto dst_value_type = stmt->dest->ret_type.ptr_removed(); if (dst_value_type->is() || diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 5b26574e19249..d12fc69ffb1ea 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -161,18 +161,34 @@ def run(): assert np.allclose(r2[None].value.to_numpy(), ops(a, c)) -@ti.test(arch=ti.cpu) +@ti.test(require=ti.extension.dynamic_index) def test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) + v = ti.Vector.field(10, ti.i32, 5) @ti.kernel - def func(): + def func1(): for i in range(5): for j, k in ti.ndrange(2, 2): - m[i][j, k] = 12 + m[i][j, k] = j * j + k * k + assert m[1][0, 1] == 1 + assert m[2][1, 0] == 1 + assert m[3][1, 1] == 2 - with pytest.raises(ti.TaichiSyntaxError): - func() + func1() + assert m[4][0, 1] == 1 + + @ti.kernel + def func2(): + for i in range(5): + for j in range(4): + v[i][j * j] = j * j + assert v[1][0] == 0 + assert v[1][1] == 1 + assert v[1][4] == 4 + + func2() + assert v[1][9] == 9 @ti.test(arch=ti.cpu)