From 25227b022e22fa4f514cb8590f849f5eb65908ff Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 19 Jul 2021 16:55:20 +0800 Subject: [PATCH 01/34] init GlobalTensorElementExpression and GlobalTensorElementExpr --- python/taichi/lang/impl.py | 3 +++ python/taichi/lang/matrix.py | 2 +- taichi/codegen/codegen_llvm.cpp | 10 ++++++++++ taichi/codegen/codegen_llvm.h | 2 ++ taichi/inc/statements.inc.h | 1 + taichi/ir/frontend_ir.cpp | 9 +++++++++ taichi/ir/frontend_ir.h | 20 ++++++++++++++++++++ taichi/ir/statements.h | 24 ++++++++++++++++++++++++ taichi/python/export_lang.cpp | 4 ++++ taichi/runtime/llvm/runtime.cpp | 8 ++++++++ taichi/transforms/ir_printer.cpp | 7 +++++++ taichi/transforms/lower_access.cpp | 12 ++++++++++++ taichi/transforms/lower_ast.cpp | 4 ++++ taichi/transforms/type_check.cpp | 5 +++++ 14 files changed, 110 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 97b975ee67a76..da3e69ec69180 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -145,6 +145,9 @@ def subscript(value, *indices): else: return value[indices] +@taichi_scope +def subscript_with_offset(origin, offset): + return Expr(_ti_core.subscript_with_offset(origin.ptr, make_constant_expr(offset).ptr)) @taichi_scope def chain_compare(comparators, ops): diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 42362f1b85abd..3ff629562c4a1 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -232,7 +232,7 @@ def linearize_entry_id(self, *args): def __call__(self, *args, **kwargs): _taichi_skip_traceback = 1 assert kwargs == {} - return self.entries[self.linearize_entry_id(*args)] + return ti.subscript_with_offset(self.entries[0], self.linearize_entry_id(*args)) def get_field_members(self): return self.entries diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 59d4b93354a1b..c63e770e0badb 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ 
b/taichi/codegen/codegen_llvm.cpp @@ -1378,6 +1378,16 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } } +void CodeGenLLVM::visit(GlobalTensorElementStmt *stmt) { + if (stmt->ret_type.ptr_removed()->is_primitive(PrimitiveTypeID::f32)) { + llvm_val[stmt] = + create_call("access_with_offset_f32", {llvm_val[stmt->origin], llvm_val[stmt->offset]}); + } else if (stmt->ret_type.ptr_removed()->is_primitive(PrimitiveTypeID::i32)) { + llvm_val[stmt] = + create_call("access_with_offset_i32", {llvm_val[stmt->origin], llvm_val[stmt->offset]}); + } +} + void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { TI_ASSERT(stmt->width() == 1); diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 255f677bcb5e2..f0bbb0f6fed72 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -212,6 +212,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(GlobalPtrStmt *stmt) override; + void visit(GlobalTensorElementStmt *stmt) override; + void store_custom_int(llvm::Value *bit_ptr, CustomIntType *cit, llvm::Value *value, diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index 6cd0a449bafb6..0870b3bb59919 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -29,6 +29,7 @@ PER_STATEMENT(ReturnStmt) PER_STATEMENT(ArgLoadStmt) PER_STATEMENT(ExternalPtrStmt) +PER_STATEMENT(GlobalTensorElementStmt) PER_STATEMENT(ConstStmt) PER_STATEMENT(AllocaStmt) PER_STATEMENT(UnaryOpStmt) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 833b699693f6f..8f7f1fdf90014 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -219,6 +219,15 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { + origin_expr->flatten(ctx); + Stmt *origin_stmt = ctx->back_stmt(); + offset->flatten(ctx); + ctx->push_back(std::make_unique( + origin_stmt, ctx->back_stmt())); + stmt 
= ctx->back_stmt(); +} + void RangeAssumptionExpression::flatten(FlattenContext *ctx) { input->flatten(ctx); base->flatten(ctx); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a02f3bdc21262..43fdd99111eb7 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -430,6 +430,26 @@ class GlobalPtrExpression : public Expression { } }; +class GlobalTensorElementExpression : public Expression { + public: + Expr origin_expr; + Expr offset; + + GlobalTensorElementExpression(const Expr &origin_expr, const Expr &offset) + : origin_expr(origin_expr), offset(offset) { + } + + std::string serialize() override { + return "@@"; + } + + void flatten(FlattenContext *ctx) override; + + bool is_lvalue() const override { + return true; + } +}; + class EvalExpression : public Expression { public: Stmt *stmt_ptr; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 408d28e0df400..e33ecfd8717e9 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -303,6 +303,30 @@ class GlobalPtrStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * An accessing tensor element operation. + */ +class GlobalTensorElementStmt : public Stmt { + public: + Stmt *origin, *offset; + bool is_bit_vectorized; // TODO: remove this field + + GlobalTensorElementStmt(Stmt *origin, Stmt *offset) + : origin(origin), + offset(offset), + is_bit_vectorized(is_bit_vectorized) { + element_type() = origin->cast()->ret_type; + TI_STMT_REG_FIELDS; + } + + bool has_global_side_effect() const override { + return false; + } + + TI_STMT_DEF_FIELDS(ret_type, origin, offset, is_bit_vectorized); + TI_DEFINE_ACCEPT_AND_CLONE +}; + /** * An operation to a SNode (not necessarily a leaf SNode). 
*/ diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index f67465cc5cc10..d3d562dea2d51 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -679,6 +679,10 @@ void export_lang(py::module &m) { return expr[expr_group]; }); + m.def("subscript_with_offset", [](const Expr &origin_expr, const Expr &offset) { + return Expr::make(origin_expr, offset); + }); + m.def("subscript", [](SNode *snode, const ExprGroup &indices) { return Expr::make(snode, indices.loaded()); }); diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 2888e7f19c363..ff14c73880d0c 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -365,6 +365,14 @@ STRUCT_FIELD(StructMeta, refine_coordinates); STRUCT_FIELD(StructMeta, is_active); STRUCT_FIELD(StructMeta, context); +f32* access_with_offset_f32(f32* origin, i32 offset) { + return (f32*)((i64)origin + offset * 8); +} + +i32* access_with_offset_i32(i32* origin, i32 offset) { + return (i32*)((i64)origin + offset * 8); +} + struct LLVMRuntime; constexpr bool enable_assert = true; diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 6c68cd3e8ab5f..21f53bf25e29d 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -383,6 +383,13 @@ class IRPrinter : public IRVisitor { print_raw(s); } + void visit(GlobalTensorElementStmt *stmt) override { + // TODO: do actual ir_printer + std::string s = + fmt::format("{}{} = global tensor element {} {}", stmt->type_hint(), stmt->name(), stmt->origin->name(), stmt->offset->name()); + print_raw(s); + } + void visit(ArgLoadStmt *stmt) override { print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id); } diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 299bfe53b77d7..77a95ba237f1c 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -153,6 +153,18 @@ 
class LowerAccess : public IRVisitor { } } + // TODO: this seems to be redundant + void visit(GlobalTensorElementStmt *stmt) override { + if (stmt->origin->is()) { + auto ptr = stmt->origin->as(); + // If ptr already has activate = false, no need to activate all the + // generated micro-access ops. Otherwise, activate the nodes. + auto lowered = lower_vector_ptr(ptr, ptr->activate); + stmt->origin = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); + } + } + void visit(GlobalStoreStmt *stmt) override { if (stmt->dest->is()) { auto ptr = stmt->dest->as(); diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 97edef1aa09c2..49b959c5b2803 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -361,6 +361,10 @@ class LowerAST : public IRVisitor { fctx.push_back( assign->parent->lookup_var(assign->lhs.cast()->id), expr->stmt); + } else if (assign->lhs.is()) { + auto global_ptr = assign->lhs.cast(); + global_ptr->flatten(&fctx); + fctx.push_back(fctx.back_stmt(), expr->stmt); } else { // global variable TI_ASSERT(assign->lhs.is()); auto global_ptr = assign->lhs.cast(); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 9672996659ef6..559d6d259dc4d 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -158,6 +158,11 @@ class TypeCheck : public IRVisitor { } } + void visit(GlobalTensorElementStmt *stmt) override { + // TODO: do actual type_check + stmt->ret_type.set_is_pointer(true); + } + void visit(GlobalStoreStmt *stmt) override { auto dst_value_type = stmt->dest->ret_type.ptr_removed(); if (dst_value_type->is() || From b0e3b112da1c978e002353a7da2f935218213d7f Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 19 Jul 2021 17:57:57 +0800 Subject: [PATCH 02/34] fix compile error for global load --- taichi/ir/expr.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/ir/expr.cpp 
b/taichi/ir/expr.cpp index 4ff574a749c4b..f996c486df2b1 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -163,6 +163,8 @@ void Expr::operator/=(const Expr &o) { Expr load_if_ptr(const Expr &ptr) { if (ptr.is()) { return load(ptr); + } else if (ptr.is()) { + return load(ptr); } else if (ptr.is()) { TI_ASSERT(ptr.cast()->snode->num_active_indices == 0); @@ -172,7 +174,7 @@ Expr load_if_ptr(const Expr &ptr) { } Expr load(const Expr &ptr) { - TI_ASSERT(ptr.is()); + TI_ASSERT(ptr.is() || ptr.is()); return Expr::make(ptr); } From a9d3dc2362daa1d57b79adb028f081dfa4082474 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 19 Jul 2021 18:40:01 +0800 Subject: [PATCH 03/34] fix cfg optimization --- taichi/ir/control_flow_graph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 8e6f33e6943d7..e8c856a633218 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -626,7 +626,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { auto stmt = nodes[i]->block->statements[j].get(); if (stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is()) { + stmt->is() || stmt->is()) { // TODO: unify them // A global pointer that may contain some data before this kernel. 
nodes[start_node]->reach_gen.insert(stmt); From 5aaf917bbdd858ba7db43039922a750927d99f1b Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 19 Jul 2021 18:55:57 +0800 Subject: [PATCH 04/34] fix offset=4 for i32 and f32 --- taichi/runtime/llvm/runtime.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index ff14c73880d0c..c3f1a14f095f5 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -366,11 +366,11 @@ STRUCT_FIELD(StructMeta, is_active); STRUCT_FIELD(StructMeta, context); f32* access_with_offset_f32(f32* origin, i32 offset) { - return (f32*)((i64)origin + offset * 8); + return (f32*)((i64)origin + offset * 4); } i32* access_with_offset_i32(i32* origin, i32 offset) { - return (i32*)((i64)origin + offset * 8); + return (i32*)((i64)origin + offset * 4); } struct LLVMRuntime; From 1e2ab62132dd24e06105b7a2d75421563efede54 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Wed, 21 Jul 2021 15:32:46 +0800 Subject: [PATCH 05/34] support matrix indices access, AOS SOA, and arbitrary data type --- python/taichi/lang/impl.py | 4 ++-- python/taichi/lang/matrix.py | 3 ++- taichi/codegen/codegen_llvm.cpp | 12 ++++------ taichi/ir/frontend_ir.cpp | 41 +++++++++++++++++++++++++++++---- taichi/ir/frontend_ir.h | 10 ++++---- taichi/python/export_lang.cpp | 4 ++-- taichi/runtime/llvm/runtime.cpp | 8 ------- 7 files changed, 54 insertions(+), 28 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index da3e69ec69180..bdaebef092957 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -146,8 +146,8 @@ def subscript(value, *indices): return value[indices] @taichi_scope -def subscript_with_offset(origin, offset): - return Expr(_ti_core.subscript_with_offset(origin.ptr, make_constant_expr(offset).ptr)) +def subscript_with_offset(var, indices, cols, is_AOS): + return Expr(_ti_core.subscript_with_offset(var.ptr, 
make_expr_group(*indices), cols, is_AOS)) @taichi_scope def chain_compare(comparators, ops): diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 3ff629562c4a1..a6dac4bc06212 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -232,7 +232,8 @@ def linearize_entry_id(self, *args): def __call__(self, *args, **kwargs): _taichi_skip_traceback = 1 assert kwargs == {} - return ti.subscript_with_offset(self.entries[0], self.linearize_entry_id(*args)) + # TODO: AOS hard-coded here + return ti.subscript_with_offset(self.entries[0], args, self.m, True) def get_field_members(self): return self.entries diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index c63e770e0badb..88ca46b8f8a7c 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1379,13 +1379,11 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } void CodeGenLLVM::visit(GlobalTensorElementStmt *stmt) { - if (stmt->ret_type.ptr_removed()->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = - create_call("access_with_offset_f32", {llvm_val[stmt->origin], llvm_val[stmt->offset]}); - } else if (stmt->ret_type.ptr_removed()->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - create_call("access_with_offset_i32", {llvm_val[stmt->origin], llvm_val[stmt->offset]}); - } + auto origin_address = builder->CreatePtrToInt(llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); + auto offset_address = builder->CreateSExt(llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); + auto target_address = builder->CreateAdd(origin_address, offset_address); + auto dt = stmt->ret_type.ptr_removed(); + llvm_val[stmt] = builder->CreateIntToPtr(target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); } void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 8f7f1fdf90014..8a7c07e59598f 100644 --- a/taichi/ir/frontend_ir.cpp 
+++ b/taichi/ir/frontend_ir.cpp @@ -220,11 +220,44 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { } void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { - origin_expr->flatten(ctx); - Stmt *origin_stmt = ctx->back_stmt(); - offset->flatten(ctx); + TI_ASSERT(var.is()) + var->flatten(ctx); + Stmt *var_stmt = ctx->back_stmt(); + SNode *snode = var.cast()->var.cast()->snode; + // Compute exact offset + // Type A[i, j][x, y] + // ^^^^ + TI_ASSERT(1 <= indices.size() && indices.size() <= 2) + if (indices.size() == 1) { + indices[0]->flatten(ctx); + } else { + indices[0]->flatten(ctx); + Stmt* i_stmt = ctx->back_stmt(); + Stmt* cols_stmt = ctx->push_back(Stmt::make(LaneAttribute(cols))); + Stmt* i_mul_cols_stmt = ctx->push_back(Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); + indices[1]->flatten(ctx); + Stmt* j_stmt = ctx->back_stmt(); + ctx->push_back(Stmt::make(BinaryOpType::add, i_mul_cols_stmt, j_stmt)); + } + // Type A[i, j][x, y] + // ^ ^ + if (!is_AOS) { + TI_ASSERT(snode->is_path_all_dense) + int size = 1; + for (int index = 0; index < taichi_max_num_indices; ++index) + size <<= snode->get_num_bits(index); + Stmt* offset_stmt = ctx->back_stmt(); + Stmt* field_size_stmt = ctx->push_back(Stmt::make(LaneAttribute(size))); + ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, field_size_stmt)); + } + // Type A[i, j][x, y] + // ^^^^ + Stmt* offset_stmt = ctx->back_stmt(); + Stmt* dt_size_stmt = ctx->push_back(Stmt::make(LaneAttribute(data_type_size(snode->dt)))); + ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); + ctx->push_back(std::make_unique( - origin_stmt, ctx->back_stmt())); + var_stmt, ctx->back_stmt())); stmt = ctx->back_stmt(); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 43fdd99111eb7..b3cc7301d592d 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -432,11 +432,13 @@ class GlobalPtrExpression : public Expression { class GlobalTensorElementExpression : 
public Expression { public: - Expr origin_expr; - Expr offset; + Expr var; + ExprGroup indices; + int cols; + bool is_AOS; - GlobalTensorElementExpression(const Expr &origin_expr, const Expr &offset) - : origin_expr(origin_expr), offset(offset) { + GlobalTensorElementExpression(const Expr &var, const ExprGroup &indices, int cols, bool is_AOS) + : var(var), indices(indices), cols(cols), is_AOS(is_AOS) { } std::string serialize() override { diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index d3d562dea2d51..b399e1ba6d7de 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -679,8 +679,8 @@ void export_lang(py::module &m) { return expr[expr_group]; }); - m.def("subscript_with_offset", [](const Expr &origin_expr, const Expr &offset) { - return Expr::make(origin_expr, offset); + m.def("subscript_with_offset", [](const Expr &var, const ExprGroup &indices, int cols, bool is_AOS) { + return Expr::make(var, indices, cols, is_AOS); }); m.def("subscript", [](SNode *snode, const ExprGroup &indices) { diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index c3f1a14f095f5..2888e7f19c363 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -365,14 +365,6 @@ STRUCT_FIELD(StructMeta, refine_coordinates); STRUCT_FIELD(StructMeta, is_active); STRUCT_FIELD(StructMeta, context); -f32* access_with_offset_f32(f32* origin, i32 offset) { - return (f32*)((i64)origin + offset * 4); -} - -i32* access_with_offset_i32(i32* origin, i32 offset) { - return (i32*)((i64)origin + offset * 4); -} - struct LLVMRuntime; constexpr bool enable_assert = true; From cbc795025292a58c50e39940f48f53954026a058 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Wed, 21 Jul 2021 07:35:06 +0000 Subject: [PATCH 06/34] Auto Format --- python/taichi/lang/impl.py | 6 ++- taichi/codegen/codegen_llvm.cpp | 9 ++-- taichi/ir/control_flow_graph.cpp | 3 +- taichi/ir/expr.cpp | 3 +- taichi/ir/frontend_ir.cpp 
| 87 ++++++++++++++++++-------------- taichi/ir/frontend_ir.h | 5 +- taichi/ir/statements.h | 4 +- taichi/python/export_lang.cpp | 8 +-- taichi/transforms/ir_printer.cpp | 3 +- taichi/transforms/type_check.cpp | 2 +- 10 files changed, 76 insertions(+), 54 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index bdaebef092957..90770a0b955cd 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -145,9 +145,13 @@ def subscript(value, *indices): else: return value[indices] + @taichi_scope def subscript_with_offset(var, indices, cols, is_AOS): - return Expr(_ti_core.subscript_with_offset(var.ptr, make_expr_group(*indices), cols, is_AOS)) + return Expr( + _ti_core.subscript_with_offset(var.ptr, make_expr_group(*indices), + cols, is_AOS)) + @taichi_scope def chain_compare(comparators, ops): diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 88ca46b8f8a7c..16102be19f9c7 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1379,11 +1379,14 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } void CodeGenLLVM::visit(GlobalTensorElementStmt *stmt) { - auto origin_address = builder->CreatePtrToInt(llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); - auto offset_address = builder->CreateSExt(llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); + auto origin_address = builder->CreatePtrToInt( + llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); + auto offset_address = builder->CreateSExt( + llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); auto target_address = builder->CreateAdd(origin_address, offset_address); auto dt = stmt->ret_type.ptr_removed(); - llvm_val[stmt] = builder->CreateIntToPtr(target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); + llvm_val[stmt] = builder->CreateIntToPtr( + target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); } void 
CodeGenLLVM::visit(ExternalPtrStmt *stmt) { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index e8c856a633218..27e09a45cb58f 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -626,7 +626,8 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { auto stmt = nodes[i]->block->statements[j].get(); if (stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is() || stmt->is()) { + stmt->is() || + stmt->is()) { // TODO: unify them // A global pointer that may contain some data before this kernel. nodes[start_node]->reach_gen.insert(stmt); diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index f996c486df2b1..b74372fb0073b 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -174,7 +174,8 @@ Expr load_if_ptr(const Expr &ptr) { } Expr load(const Expr &ptr) { - TI_ASSERT(ptr.is() || ptr.is()); + TI_ASSERT(ptr.is() || + ptr.is()); return Expr::make(ptr); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 8a7c07e59598f..0ba6e0c251a32 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -220,45 +220,54 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { } void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { - TI_ASSERT(var.is()) - var->flatten(ctx); - Stmt *var_stmt = ctx->back_stmt(); - SNode *snode = var.cast()->var.cast()->snode; - // Compute exact offset - // Type A[i, j][x, y] - // ^^^^ - TI_ASSERT(1 <= indices.size() && indices.size() <= 2) - if (indices.size() == 1) { - indices[0]->flatten(ctx); - } else { - indices[0]->flatten(ctx); - Stmt* i_stmt = ctx->back_stmt(); - Stmt* cols_stmt = ctx->push_back(Stmt::make(LaneAttribute(cols))); - Stmt* i_mul_cols_stmt = ctx->push_back(Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); - indices[1]->flatten(ctx); - Stmt* j_stmt = ctx->back_stmt(); - ctx->push_back(Stmt::make(BinaryOpType::add, i_mul_cols_stmt, j_stmt)); - } - // Type A[i, j][x, y] - 
// ^ ^ - if (!is_AOS) { - TI_ASSERT(snode->is_path_all_dense) - int size = 1; - for (int index = 0; index < taichi_max_num_indices; ++index) - size <<= snode->get_num_bits(index); - Stmt* offset_stmt = ctx->back_stmt(); - Stmt* field_size_stmt = ctx->push_back(Stmt::make(LaneAttribute(size))); - ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, field_size_stmt)); - } - // Type A[i, j][x, y] - // ^^^^ - Stmt* offset_stmt = ctx->back_stmt(); - Stmt* dt_size_stmt = ctx->push_back(Stmt::make(LaneAttribute(data_type_size(snode->dt)))); - ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); - - ctx->push_back(std::make_unique( - var_stmt, ctx->back_stmt())); - stmt = ctx->back_stmt(); + TI_ASSERT(var.is()) + var->flatten(ctx); + Stmt *var_stmt = ctx->back_stmt(); + SNode *snode = var.cast() + ->var.cast() + ->snode; + // Compute exact offset + // Type A[i, j][x, y] + // ^^^^ + TI_ASSERT(1 <= indices.size() && indices.size() <= 2) + if (indices.size() == 1) { + indices[0]->flatten(ctx); + } else { + indices[0]->flatten(ctx); + Stmt *i_stmt = ctx->back_stmt(); + Stmt *cols_stmt = ctx->push_back( + Stmt::make(LaneAttribute(cols))); + Stmt *i_mul_cols_stmt = ctx->push_back( + Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); + indices[1]->flatten(ctx); + Stmt *j_stmt = ctx->back_stmt(); + ctx->push_back( + Stmt::make(BinaryOpType::add, i_mul_cols_stmt, j_stmt)); + } + // Type A[i, j][x, y] + // ^ ^ + if (!is_AOS) { + TI_ASSERT(snode->is_path_all_dense) + int size = 1; + for (int index = 0; index < taichi_max_num_indices; ++index) + size <<= snode->get_num_bits(index); + Stmt *offset_stmt = ctx->back_stmt(); + Stmt *field_size_stmt = ctx->push_back( + Stmt::make(LaneAttribute(size))); + ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, + field_size_stmt)); + } + // Type A[i, j][x, y] + // ^^^^ + Stmt *offset_stmt = ctx->back_stmt(); + Stmt *dt_size_stmt = ctx->push_back(Stmt::make( + LaneAttribute(data_type_size(snode->dt)))); + 
ctx->push_back( + Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); + + ctx->push_back( + std::make_unique(var_stmt, ctx->back_stmt())); + stmt = ctx->back_stmt(); } void RangeAssumptionExpression::flatten(FlattenContext *ctx) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index b3cc7301d592d..b4d8de85f1979 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -437,7 +437,10 @@ class GlobalTensorElementExpression : public Expression { int cols; bool is_AOS; - GlobalTensorElementExpression(const Expr &var, const ExprGroup &indices, int cols, bool is_AOS) + GlobalTensorElementExpression(const Expr &var, + const ExprGroup &indices, + int cols, + bool is_AOS) : var(var), indices(indices), cols(cols), is_AOS(is_AOS) { } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index e33ecfd8717e9..9d29b61d4e9a1 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -312,9 +312,7 @@ class GlobalTensorElementStmt : public Stmt { bool is_bit_vectorized; // TODO: remove this field GlobalTensorElementStmt(Stmt *origin, Stmt *offset) - : origin(origin), - offset(offset), - is_bit_vectorized(is_bit_vectorized) { + : origin(origin), offset(offset), is_bit_vectorized(is_bit_vectorized) { element_type() = origin->cast()->ret_type; TI_STMT_REG_FIELDS; } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index b399e1ba6d7de..263fef85e9ba0 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -679,9 +679,11 @@ void export_lang(py::module &m) { return expr[expr_group]; }); - m.def("subscript_with_offset", [](const Expr &var, const ExprGroup &indices, int cols, bool is_AOS) { - return Expr::make(var, indices, cols, is_AOS); - }); + m.def("subscript_with_offset", + [](const Expr &var, const ExprGroup &indices, int cols, bool is_AOS) { + return Expr::make(var, indices, cols, + is_AOS); + }); m.def("subscript", [](SNode *snode, const ExprGroup &indices) { return 
Expr::make(snode, indices.loaded()); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 21f53bf25e29d..a1b09ffa90a3f 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -386,7 +386,8 @@ class IRPrinter : public IRVisitor { void visit(GlobalTensorElementStmt *stmt) override { // TODO: do actual ir_printer std::string s = - fmt::format("{}{} = global tensor element {} {}", stmt->type_hint(), stmt->name(), stmt->origin->name(), stmt->offset->name()); + fmt::format("{}{} = global tensor element {} {}", stmt->type_hint(), + stmt->name(), stmt->origin->name(), stmt->offset->name()); print_raw(s); } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 559d6d259dc4d..6650577c08a8a 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -160,7 +160,7 @@ class TypeCheck : public IRVisitor { void visit(GlobalTensorElementStmt *stmt) override { // TODO: do actual type_check - stmt->ret_type.set_is_pointer(true); + stmt->ret_type.set_is_pointer(true); } void visit(GlobalStoreStmt *stmt) override { From 7684a3ff90cb087809bb8eb814041e3ce07f96a6 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Wed, 21 Jul 2021 16:03:25 +0800 Subject: [PATCH 07/34] do better serialize() for GlobalTensorElementExpression and GlobalTensorElementStmt --- taichi/ir/frontend_ir.h | 10 +++++++++- taichi/transforms/ir_printer.cpp | 3 +-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index b4d8de85f1979..0c82fa4afddbe 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -445,7 +445,15 @@ class GlobalTensorElementExpression : public Expression { } std::string serialize() override { - return "@@"; + std::string s = fmt::format("{}[", var.serialize()); + for (int i = 0; i < (int)indices.size(); i++) { + s += indices.exprs[i]->serialize(); + if (i + 1 < (int)indices.size()) + s += ", "; + } + s 
+= "]"; + s += " (col=" + std::to_string(cols) + (is_AOS ? ", AOS)" : ", SOA)"); + return s; } void flatten(FlattenContext *ctx) override; diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index a1b09ffa90a3f..a2edcc0c49823 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -384,9 +384,8 @@ class IRPrinter : public IRVisitor { } void visit(GlobalTensorElementStmt *stmt) override { - // TODO: do actual ir_printer std::string s = - fmt::format("{}{} = global tensor element {} {}", stmt->type_hint(), + fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(), stmt->name(), stmt->origin->name(), stmt->offset->name()); print_raw(s); } From fbd0f1a454370de494a337fc3803b07d132d71df Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Wed, 21 Jul 2021 16:34:53 +0800 Subject: [PATCH 08/34] simplify make --- taichi/ir/frontend_ir.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 0ba6e0c251a32..9a6593c0b2b8a 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -236,7 +236,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { indices[0]->flatten(ctx); Stmt *i_stmt = ctx->back_stmt(); Stmt *cols_stmt = ctx->push_back( - Stmt::make(LaneAttribute(cols))); + Stmt::make(TypedConstant(cols))); Stmt *i_mul_cols_stmt = ctx->push_back( Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); indices[1]->flatten(ctx); @@ -253,7 +253,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { size <<= snode->get_num_bits(index); Stmt *offset_stmt = ctx->back_stmt(); Stmt *field_size_stmt = ctx->push_back( - Stmt::make(LaneAttribute(size))); + Stmt::make(TypedConstant(size))); ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, field_size_stmt)); } @@ -261,7 +261,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { // ^^^^ Stmt *offset_stmt = ctx->back_stmt(); Stmt 
*dt_size_stmt = ctx->push_back(Stmt::make( - LaneAttribute(data_type_size(snode->dt)))); + TypedConstant(data_type_size(snode->dt)))); ctx->push_back( Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); From 1120afbbe1b712dbdca88f5c4d2969656af8e026 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Thu, 22 Jul 2021 17:55:03 +0800 Subject: [PATCH 09/34] fix ti test --- python/taichi/lang/matrix.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index a6dac4bc06212..0ba5e00429b65 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -232,8 +232,7 @@ def linearize_entry_id(self, *args): def __call__(self, *args, **kwargs): _taichi_skip_traceback = 1 assert kwargs == {} - # TODO: AOS hard-coded here - return ti.subscript_with_offset(self.entries[0], args, self.m, True) + return self.entries[self.linearize_entry_id(*args)] def get_field_members(self): return self.entries @@ -269,7 +268,10 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - return self(i, j) + if is_taichi_class(self.entries[0]): + return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) + else: + return self(i, j) @property def x(self): From 11d89e0b94b9e23e2a26ddca54c1a364ecb7969e Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Thu, 22 Jul 2021 18:45:07 +0800 Subject: [PATCH 10/34] fix test and add load_if_ptr for dynamic indices --- python/taichi/lang/matrix.py | 2 +- taichi/ir/frontend_ir.cpp | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 0ba5e00429b65..cd1638c0e9308 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -268,7 +268,7 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - if is_taichi_class(self.entries[0]): + if 
is_taichi_class(self): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: return self(i, j) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 9a6593c0b2b8a..340f2952d70c3 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -231,14 +231,17 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { // ^^^^ TI_ASSERT(1 <= indices.size() && indices.size() <= 2) if (indices.size() == 1) { + indices[0].set(load_if_ptr(indices[0])); indices[0]->flatten(ctx); } else { + indices[0].set(load_if_ptr(indices[0])); indices[0]->flatten(ctx); Stmt *i_stmt = ctx->back_stmt(); Stmt *cols_stmt = ctx->push_back( Stmt::make(TypedConstant(cols))); Stmt *i_mul_cols_stmt = ctx->push_back( Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); + indices[1].set(load_if_ptr(indices[1])); indices[1]->flatten(ctx); Stmt *j_stmt = ctx->back_stmt(); ctx->push_back( From 2705249353eedf3a44f16aed29e2c73c93d2d586 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Thu, 22 Jul 2021 18:54:47 +0800 Subject: [PATCH 11/34] fix test --- python/taichi/lang/matrix.py | 2 +- tests/python/test_matrix.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index cd1638c0e9308..4dee900d282f5 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -268,7 +268,7 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - if is_taichi_class(self): + if isinstance(self.entries[0], ti.Expr): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: return self(i, j) diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 5b26574e19249..462029a6ac2b6 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -171,8 +171,7 @@ def func(): for j, k in ti.ndrange(2, 2): m[i][j, k] = 12 - with pytest.raises(ti.TaichiSyntaxError): 
- func() + func() @ti.test(arch=ti.cpu) From 587569b61e882767670c5ab0591c248e6a21e328 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Fri, 23 Jul 2021 16:52:09 +0800 Subject: [PATCH 12/34] fix more tests --- python/taichi/lang/matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 4dee900d282f5..0006ff915e248 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -268,7 +268,7 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - if isinstance(self.entries[0], ti.Expr): + if self.entries[0].ptr.is_global_var() and (ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: return self(i, j) From b009a9431152546a3d4b28bb6e8aed431c27c105 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Fri, 23 Jul 2021 18:12:49 +0800 Subject: [PATCH 13/34] fix more tests --- python/taichi/lang/matrix.py | 2 +- taichi/ir/frontend_ir.cpp | 4 ++++ taichi/python/export_lang.cpp | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 0006ff915e248..fc27ab2e00cd9 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -268,7 +268,7 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - if self.entries[0].ptr.is_global_var() and (ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): + if isinstance(self.entries[0], ti.Expr) and self.entries[0].ptr.is_global_ptr() and (ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: return self(i, j) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 340f2952d70c3..087bda68a8069 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -351,6 
+351,10 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { // emit local store stmt auto alloca = ctx->current_block->lookup_var(dest.cast()->id); ctx->push_back(op_type, alloca, expr->stmt); + } else if (dest.is()) { + auto global_ptr = dest.cast(); + global_ptr->flatten(ctx); + ctx->push_back(op_type, ctx->back_stmt(), expr->stmt); } else { // global variable TI_ASSERT(dest.is()); auto global_ptr = dest.cast(); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 263fef85e9ba0..113297d372906 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -360,6 +360,8 @@ void export_lang(py::module &m) { [](Expr *expr) { return expr->is(); }) .def("is_external_var", [](Expr *expr) { return expr->is(); }) + .def("is_global_ptr", + [](Expr *expr) { return expr->is(); }) .def("set_tb", &Expr::set_tb) .def("set_name", [&](Expr *expr, std::string na) { From 4d68187436c07ebc7c372a0fdcc4e664db6c6f80 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 11:57:21 +0800 Subject: [PATCH 14/34] fix test --- taichi/transforms/flag_access.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 02831b376b4dd..a26497e6c5698 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -54,6 +54,11 @@ class FlagAccess : public IRVisitor { if (stmt->dest->is()) { stmt->dest->as()->activate = true; } + if (stmt->dest->is()) { + if (stmt->dest->as()->origin->is()) { + stmt->dest->as()->origin->as()->activate = true; + } + } } void visit(AtomicOpStmt *stmt) { From e445596f9170c426128023ee72ccc4c457efb953 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 12:43:08 +0800 Subject: [PATCH 15/34] fix test --- taichi/program/async_utils.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 0ec4aefe33b5b..c4eff3ed1ae09 100644 --- 
a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -165,6 +165,14 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { ir_bank->get_async_state(snode, AsyncState::Type::value)); } } + if (auto global_tensor_element = global_store->dest->cast()) { + if (auto dest = global_tensor_element->origin->cast()) { + for (auto &snode : dest->snodes.data) { + meta.output_states.insert( + ir_bank->get_async_state(snode, AsyncState::Type::value)); + } + } + } } if (auto global_atomic = stmt->cast()) { if (auto dest = global_atomic->dest->cast()) { From cc6317d76eaaa1986ae2b145eaaf46a6bba2985e Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 13:35:34 +0800 Subject: [PATCH 16/34] update test --- tests/python/test_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 462029a6ac2b6..74572077e78b0 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -161,7 +161,7 @@ def run(): assert np.allclose(r2[None].value.to_numpy(), ops(a, c)) -@ti.test(arch=ti.cpu) +@ti.test(arch=[ti.cpu, ti.gpu]) def test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) From 9d3ecb1bd92c5014bca4c22b0d65bea6c8b88ea4 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 14:23:24 +0800 Subject: [PATCH 17/34] merge condition in load_if_ptr --- taichi/ir/expr.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index b74372fb0073b..7331e4305d6f4 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -161,9 +161,7 @@ void Expr::operator/=(const Expr &o) { } Expr load_if_ptr(const Expr &ptr) { - if (ptr.is()) { - return load(ptr); - } else if (ptr.is()) { + if (ptr.is() or ptr.is()) { return load(ptr); } else if (ptr.is()) { TI_ASSERT(ptr.cast()->snode->num_active_indices == From 46dd8a20ddde35b90cda21b34fcf06523ac99edb Mon Sep 17 00:00:00 2001 From: Yu Fang Date: 
Mon, 26 Jul 2021 14:27:56 +0800 Subject: [PATCH 18/34] replace is_AOS with is_aos --- python/taichi/lang/impl.py | 4 ++-- taichi/ir/frontend_ir.cpp | 2 +- taichi/ir/frontend_ir.h | 8 ++++---- taichi/python/export_lang.cpp | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 90770a0b955cd..25011084cbed0 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -147,10 +147,10 @@ def subscript(value, *indices): @taichi_scope -def subscript_with_offset(var, indices, cols, is_AOS): +def subscript_with_offset(var, indices, cols, is_aos): return Expr( _ti_core.subscript_with_offset(var.ptr, make_expr_group(*indices), - cols, is_AOS)) + cols, is_aos)) @taichi_scope diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 087bda68a8069..bf37461067a89 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -249,7 +249,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { } // Type A[i, j][x, y] // ^ ^ - if (!is_AOS) { + if (!is_aos) { TI_ASSERT(snode->is_path_all_dense) int size = 1; for (int index = 0; index < taichi_max_num_indices; ++index) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 0c82fa4afddbe..6a988989c71b5 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -435,13 +435,13 @@ class GlobalTensorElementExpression : public Expression { Expr var; ExprGroup indices; int cols; - bool is_AOS; + bool is_aos; GlobalTensorElementExpression(const Expr &var, const ExprGroup &indices, int cols, - bool is_AOS) - : var(var), indices(indices), cols(cols), is_AOS(is_AOS) { + bool is_aos) + : var(var), indices(indices), cols(cols), is_aos(is_aos) { } std::string serialize() override { @@ -452,7 +452,7 @@ class GlobalTensorElementExpression : public Expression { s += ", "; } s += "]"; - s += " (col=" + std::to_string(cols) + (is_AOS ? 
", AOS)" : ", SOA)"); + s += " (col=" + std::to_string(cols) + (is_aos ? ", AOS)" : ", SOA)"); return s; } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 113297d372906..ac16c3f2d1b00 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -682,9 +682,9 @@ void export_lang(py::module &m) { }); m.def("subscript_with_offset", - [](const Expr &var, const ExprGroup &indices, int cols, bool is_AOS) { + [](const Expr &var, const ExprGroup &indices, int cols, bool is_aos) { return Expr::make(var, indices, cols, - is_AOS); + is_aos); }); m.def("subscript", [](SNode *snode, const ExprGroup &indices) { From 6922a516e0dcccd02be02ec07f3a8fa6f5ad371c Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 15:07:45 +0800 Subject: [PATCH 19/34] remove is_bit_vectorized --- taichi/ir/frontend_ir.h | 2 +- taichi/ir/statements.h | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 6a988989c71b5..fcab21f9aa0eb 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -435,7 +435,7 @@ class GlobalTensorElementExpression : public Expression { Expr var; ExprGroup indices; int cols; - bool is_aos; + bool is_aos{false}; GlobalTensorElementExpression(const Expr &var, const ExprGroup &indices, diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 9d29b61d4e9a1..018881c1f1b2a 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -309,10 +309,9 @@ class GlobalPtrStmt : public Stmt { class GlobalTensorElementStmt : public Stmt { public: Stmt *origin, *offset; - bool is_bit_vectorized; // TODO: remove this field GlobalTensorElementStmt(Stmt *origin, Stmt *offset) - : origin(origin), offset(offset), is_bit_vectorized(is_bit_vectorized) { + : origin(origin), offset(offset) { element_type() = origin->cast()->ret_type; TI_STMT_REG_FIELDS; } @@ -321,7 +320,7 @@ class GlobalTensorElementStmt : public Stmt { return false; } - 
TI_STMT_DEF_FIELDS(ret_type, origin, offset, is_bit_vectorized); + TI_STMT_DEF_FIELDS(ret_type, origin, offset); TI_DEFINE_ACCEPT_AND_CLONE }; From d05919526184e7de285390bd1170127e859d7397 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 15:08:53 +0800 Subject: [PATCH 20/34] add default value for GlobalTensorElementStmt --- taichi/ir/statements.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 018881c1f1b2a..fd0d1c48cb31a 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -308,7 +308,8 @@ class GlobalPtrStmt : public Stmt { */ class GlobalTensorElementStmt : public Stmt { public: - Stmt *origin, *offset; + Stmt *origin{nullptr}; + Stmt *offset{nullptr}; GlobalTensorElementStmt(Stmt *origin, Stmt *offset) : origin(origin), offset(offset) { From e3cc8935789b669a016b9adfb5ac1e529b47212e Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 15:09:40 +0800 Subject: [PATCH 21/34] add default value for GlobalTensorElementExpression --- taichi/ir/frontend_ir.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index fcab21f9aa0eb..2b219ddfdeb05 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -434,7 +434,7 @@ class GlobalTensorElementExpression : public Expression { public: Expr var; ExprGroup indices; - int cols; + int cols{0}; bool is_aos{false}; GlobalTensorElementExpression(const Expr &var, From ff80ccdef11003b85b8c77efeb57d47a9b8cf161 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Mon, 26 Jul 2021 15:21:32 +0800 Subject: [PATCH 22/34] update computation for field size --- taichi/ir/frontend_ir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index bf37461067a89..560a90f40d848 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -252,8 +252,8 @@ void 
GlobalTensorElementExpression::flatten(FlattenContext *ctx) { if (!is_aos) { TI_ASSERT(snode->is_path_all_dense) int size = 1; - for (int index = 0; index < taichi_max_num_indices; ++index) - size <<= snode->get_num_bits(index); + for (auto *s = snode; s != nullptr; s = s->parent) + size *= (int)s->max_num_elements(); Stmt *offset_stmt = ctx->back_stmt(); Stmt *field_size_stmt = ctx->push_back( Stmt::make(TypedConstant(size))); From 99e4a187badf18ad9aeb8d60daf71498594a990b Mon Sep 17 00:00:00 2001 From: squarefk Date: Tue, 27 Jul 2021 10:04:51 +0800 Subject: [PATCH 23/34] Update taichi/ir/expr.cpp Co-authored-by: xumingkuan --- taichi/ir/expr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 7331e4305d6f4..47307a001be4a 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -161,7 +161,7 @@ void Expr::operator/=(const Expr &o) { } Expr load_if_ptr(const Expr &ptr) { - if (ptr.is() or ptr.is()) { + if (ptr.is() || ptr.is()) { return load(ptr); } else if (ptr.is()) { TI_ASSERT(ptr.cast()->snode->num_active_indices == From a68187fb4d0d4fedd36f713a78561143f1ef3e85 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Tue, 27 Jul 2021 10:12:17 +0800 Subject: [PATCH 24/34] comment is_global_ptr() --- python/taichi/lang/matrix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index b7c2b9b2aab1c..2a1b91c14973e 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -268,6 +268,7 @@ def subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] + # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). 
if isinstance(self.entries[0], ti.Expr) and self.entries[0].ptr.is_global_ptr() and (ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: From 1123d63e7ce115f058222f0af7e21d39add13d0e Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Tue, 27 Jul 2021 10:14:04 +0800 Subject: [PATCH 25/34] nit --- taichi/transforms/lower_access.cpp | 44 +++++++++++++++--------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 7f460702e1f3a..806e8bbf9ab46 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -148,35 +148,35 @@ class LowerAccess : public IRVisitor { } void visit(GlobalLoadStmt *stmt) override { - if (stmt->src->is()) { - // No need to activate for all read accesses - auto lowered = lower_vector_ptr(stmt->src->as(), false); - stmt->src = lowered.back().get(); - modifier.insert_before(stmt, std::move(lowered)); - } + if (!stmt->src->is()) + return; + // No need to activate for all read accesses + auto lowered = lower_vector_ptr(stmt->src->as(), false); + stmt->src = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); } // TODO: this seems to be redundant void visit(GlobalTensorElementStmt *stmt) override { - if (stmt->origin->is()) { - auto ptr = stmt->origin->as(); - // If ptr already has activate = false, no need to activate all the - // generated micro-access ops. Otherwise, activate the nodes. - auto lowered = lower_vector_ptr(ptr, ptr->activate); - stmt->origin = lowered.back().get(); - modifier.insert_before(stmt, std::move(lowered)); - } + if (!stmt->origin->is()) + return; + auto ptr = stmt->origin->as(); + // If ptr already has activate = false, no need to activate all the + // generated micro-access ops. Otherwise, activate the nodes. 
+ auto lowered = lower_vector_ptr(ptr, ptr->activate); + stmt->origin = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); } void visit(GlobalStoreStmt *stmt) override { - if (stmt->dest->is()) { - auto ptr = stmt->dest->as(); - // If ptr already has activate = false, no need to activate all the - // generated micro-access ops. Otherwise, activate the nodes. - auto lowered = lower_vector_ptr(ptr, ptr->activate); - stmt->dest = lowered.back().get(); - modifier.insert_before(stmt, std::move(lowered)); - } + if (!stmt->dest->is()) + return; + auto ptr = stmt->dest->as(); + // If ptr already has activate = false, no need to activate all the + // generated micro-access ops. Otherwise, activate the nodes. + auto lowered = lower_vector_ptr(ptr, ptr->activate); + stmt->dest = lowered.back().get(); + modifier.insert_before(stmt, std::move(lowered)); } void visit(SNodeOpStmt *stmt) override { From da3c394370e1e9157ac98e39e7d2cf00484342cf Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 27 Jul 2021 02:15:34 +0000 Subject: [PATCH 26/34] Auto Format --- python/taichi/lang/matrix.py | 8 ++++++-- taichi/ir/expr.cpp | 3 ++- taichi/ir/frontend_ir.cpp | 12 ++++++------ taichi/program/async_utils.cpp | 3 ++- taichi/transforms/flag_access.cpp | 7 +++++-- 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 2a1b91c14973e..c0afb4498d095 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -269,8 +269,12 @@ def subscript(self, *indices): i = indices[0] j = 0 if len(indices) == 1 else indices[1] # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). 
- if isinstance(self.entries[0], ti.Expr) and self.entries[0].ptr.is_global_ptr() and (ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): - return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) + if isinstance( + self.entries[0], + ti.Expr) and self.entries[0].ptr.is_global_ptr() and ( + ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): + return ti.subscript_with_offset(self.entries[0], (i, j), + self.m, True) else: return self(i, j) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 47307a001be4a..5eeefcd6eca41 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -161,7 +161,8 @@ void Expr::operator/=(const Expr &o) { } Expr load_if_ptr(const Expr &ptr) { - if (ptr.is() || ptr.is()) { + if (ptr.is() || + ptr.is()) { return load(ptr); } else if (ptr.is()) { TI_ASSERT(ptr.cast()->snode->num_active_indices == diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 560a90f40d848..efdba4910d0c8 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -237,8 +237,8 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { indices[0].set(load_if_ptr(indices[0])); indices[0]->flatten(ctx); Stmt *i_stmt = ctx->back_stmt(); - Stmt *cols_stmt = ctx->push_back( - Stmt::make(TypedConstant(cols))); + Stmt *cols_stmt = + ctx->push_back(Stmt::make(TypedConstant(cols))); Stmt *i_mul_cols_stmt = ctx->push_back( Stmt::make(BinaryOpType::mul, i_stmt, cols_stmt)); indices[1].set(load_if_ptr(indices[1])); @@ -255,16 +255,16 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { for (auto *s = snode; s != nullptr; s = s->parent) size *= (int)s->max_num_elements(); Stmt *offset_stmt = ctx->back_stmt(); - Stmt *field_size_stmt = ctx->push_back( - Stmt::make(TypedConstant(size))); + Stmt *field_size_stmt = + ctx->push_back(Stmt::make(TypedConstant(size))); ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, field_size_stmt)); } // Type A[i, j][x, y] // ^^^^ Stmt *offset_stmt = ctx->back_stmt(); 
- Stmt *dt_size_stmt = ctx->push_back(Stmt::make( - TypedConstant(data_type_size(snode->dt)))); + Stmt *dt_size_stmt = ctx->push_back( + Stmt::make(TypedConstant(data_type_size(snode->dt)))); ctx->push_back( Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index c4eff3ed1ae09..061c90d4613b6 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -165,7 +165,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { ir_bank->get_async_state(snode, AsyncState::Type::value)); } } - if (auto global_tensor_element = global_store->dest->cast()) { + if (auto global_tensor_element = + global_store->dest->cast()) { if (auto dest = global_tensor_element->origin->cast()) { for (auto &snode : dest->snodes.data) { meta.output_states.insert( diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index a26497e6c5698..4de5f9ac9cb63 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -55,8 +55,11 @@ class FlagAccess : public IRVisitor { stmt->dest->as()->activate = true; } if (stmt->dest->is()) { - if (stmt->dest->as()->origin->is()) { - stmt->dest->as()->origin->as()->activate = true; + if (stmt->dest->as() + ->origin->is()) { + stmt->dest->as() + ->origin->as() + ->activate = true; } } } From 848db37c89534a34f6dc86ab93f1590ad9ee2fcc Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Thu, 29 Jul 2021 16:40:37 +0800 Subject: [PATCH 27/34] add extension and fix tests --- python/taichi/lang/matrix.py | 3 +-- taichi/inc/extensions.inc.h | 1 + taichi/ir/statements.h | 3 +++ taichi/program/extension.cpp | 8 +++++--- tests/python/test_matrix.py | 24 +++++++++++++++++++----- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 70288029dde42..f52bfff75a6f3 100644 --- a/python/taichi/lang/matrix.py +++ 
b/python/taichi/lang/matrix.py @@ -300,8 +300,7 @@ def subscript(self, *indices): # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). if isinstance( self.entries[0], - ti.Expr) and self.entries[0].ptr.is_global_ptr() and ( - ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): + ti.Expr) and self.entries[0].ptr.is_global_ptr() and ti.is_extension_supported(ti.cfg.arch, ti.extension.dynamic_index): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index 5c0ba10343346..0dc106864c8a0 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -9,3 +9,4 @@ PER_EXTENSION(bls) // Block-local storage PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels PER_EXTENSION(extfunc) // Invoke external functions or backend source PER_EXTENSION(packed) // Shape will not be padded to a power of two +PER_EXTENSION(dynamic_index) // Dynamic index support for both global and local tensors diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index fd0d1c48cb31a..dfa04f51b24d0 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -318,6 +318,9 @@ class GlobalTensorElementStmt : public Stmt { } bool has_global_side_effect() const override { + // After access lowered, activate info will be recorded in SNodeLookupStmt's + // activate for AOS sparse data structure. We don't support SOA sparse data + // structure for now. 
return false; } diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index 7dbae325f3067..a33b098d3a701 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -11,15 +11,17 @@ bool is_extension_supported(Arch arch, Extension ext) { {Arch::x64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::assertion, Extension::extfunc, Extension::packed}}, + Extension::assertion, Extension::extfunc, Extension::packed, + Extension::dynamic_index}}, {Arch::arm64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::assertion, Extension::packed}}, + Extension::assertion, Extension::packed, Extension::dynamic_index}}, {Arch::cuda, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::bls, Extension::assertion, Extension::packed}}, + Extension::bls, Extension::assertion, Extension::packed, + Extension::dynamic_index}}, {Arch::metal, {Extension::adstack, Extension::assertion, Extension::quant_basic, Extension::async_mode, Extension::sparse}}, diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 74572077e78b0..c8257cd2a5076 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -161,18 +161,32 @@ def run(): assert np.allclose(r2[None].value.to_numpy(), ops(a, c)) -@ti.test(arch=[ti.cpu, ti.gpu]) +@ti.test(require=ti.extension.dynamic_index) def test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) + v = ti.Vector.field(10, ti.i32, 5) @ti.kernel - def func(): + def func1(): for i in range(5): for j, k in ti.ndrange(2, 2): - m[i][j, k] = 12 - - func() + m[i][j, k] = j * j + k * k + assert m[1][0, 1] == 1 + assert m[2][1, 0] == 1 + assert m[3][1, 1] == 2 + func1() + assert m[4][0, 1] == 1 + @ti.kernel + def func2(): + for i in 
range(5): + for j in range(4): + v[i][j * j] = j * j + assert v[1][0] == 0 + assert v[1][1] == 1 + assert v[1][4] == 4 + func2() + assert v[1][9] == 9 @ti.test(arch=ti.cpu) def test_matrix_constant_index(): From 6a43da792bb72b7b78fe5950a8ce948c88126f64 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 29 Jul 2021 08:51:30 +0000 Subject: [PATCH 28/34] Auto Format --- python/taichi/lang/matrix.py | 7 ++++--- taichi/inc/extensions.inc.h | 3 ++- tests/python/test_matrix.py | 3 +++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index f52bfff75a6f3..1548a549a1d01 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -298,9 +298,10 @@ def subscript(self, *indices): i = indices[0] j = 0 if len(indices) == 1 else indices[1] # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). - if isinstance( - self.entries[0], - ti.Expr) and self.entries[0].ptr.is_global_ptr() and ti.is_extension_supported(ti.cfg.arch, ti.extension.dynamic_index): + if isinstance(self.entries[0], + ti.Expr) and self.entries[0].ptr.is_global_ptr( + ) and ti.is_extension_supported( + ti.cfg.arch, ti.extension.dynamic_index): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index 0dc106864c8a0..dae6b432b30a2 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -9,4 +9,5 @@ PER_EXTENSION(bls) // Block-local storage PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels PER_EXTENSION(extfunc) // Invoke external functions or backend source PER_EXTENSION(packed) // Shape will not be padded to a power of two -PER_EXTENSION(dynamic_index) // Dynamic index support for both global and local tensors +PER_EXTENSION( + dynamic_index) // Dynamic index support for both global and local tensors diff --git 
a/tests/python/test_matrix.py b/tests/python/test_matrix.py index c8257cd2a5076..d12fc69ffb1ea 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -174,6 +174,7 @@ def func1(): assert m[1][0, 1] == 1 assert m[2][1, 0] == 1 assert m[3][1, 1] == 2 + func1() assert m[4][0, 1] == 1 @@ -185,9 +186,11 @@ def func2(): assert v[1][0] == 0 assert v[1][1] == 1 assert v[1][4] == 4 + func2() assert v[1][9] == 9 + @ti.test(arch=ti.cpu) def test_matrix_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) From e37ed4700c0114cdfc4596d7e5522eb420808fee Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Thu, 29 Jul 2021 16:53:46 +0800 Subject: [PATCH 29/34] do actual type check --- taichi/transforms/type_check.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 6650577c08a8a..c29b32deabed6 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -159,7 +159,7 @@ class TypeCheck : public IRVisitor { } void visit(GlobalTensorElementStmt *stmt) override { - // TODO: do actual type_check + TI_ASSERT(stmt->offset->ret_type->is_primitive(PrimitiveTypeID::i32)); stmt->ret_type.set_is_pointer(true); } From 6c2ea50462f4e93c1d97268fcd8f4945e3e8542c Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Thu, 29 Jul 2021 16:58:58 +0800 Subject: [PATCH 30/34] rename GlobalTensorElementStmt into ShiftGlobalPtrStmt --- taichi/codegen/codegen_llvm.cpp | 2 +- taichi/codegen/codegen_llvm.h | 2 +- taichi/inc/statements.inc.h | 2 +- taichi/ir/control_flow_graph.cpp | 2 +- taichi/ir/frontend_ir.cpp | 2 +- taichi/ir/statements.h | 4 ++-- taichi/program/async_utils.cpp | 2 +- taichi/transforms/flag_access.cpp | 6 +++--- taichi/transforms/ir_printer.cpp | 2 +- taichi/transforms/lower_access.cpp | 2 +- taichi/transforms/type_check.cpp | 2 +- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp 
index e117043a8ec0f..6c10c344dbdbd 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1378,7 +1378,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } } -void CodeGenLLVM::visit(GlobalTensorElementStmt *stmt) { +void CodeGenLLVM::visit(ShiftGlobalPtrStmt *stmt) { auto origin_address = builder->CreatePtrToInt( llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); auto offset_address = builder->CreateSExt( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index f0bbb0f6fed72..b987af5f0ac33 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -212,7 +212,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(GlobalPtrStmt *stmt) override; - void visit(GlobalTensorElementStmt *stmt) override; + void visit(ShiftGlobalPtrStmt *stmt) override; void store_custom_int(llvm::Value *bit_ptr, CustomIntType *cit, diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index 0870b3bb59919..45021bb5cb5de 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -29,7 +29,7 @@ PER_STATEMENT(ReturnStmt) PER_STATEMENT(ArgLoadStmt) PER_STATEMENT(ExternalPtrStmt) -PER_STATEMENT(GlobalTensorElementStmt) +PER_STATEMENT(ShiftGlobalPtrStmt) PER_STATEMENT(ConstStmt) PER_STATEMENT(AllocaStmt) PER_STATEMENT(UnaryOpStmt) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 27e09a45cb58f..cbd5f7f047182 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -627,7 +627,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { if (stmt->is() || stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is()) { + stmt->is()) { // TODO: unify them // A global pointer that may contain some data before this kernel. 
nodes[start_node]->reach_gen.insert(stmt); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index efdba4910d0c8..9de1cb7d0b79a 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -269,7 +269,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); ctx->push_back( - std::make_unique(var_stmt, ctx->back_stmt())); + std::make_unique(var_stmt, ctx->back_stmt())); stmt = ctx->back_stmt(); } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index dfa04f51b24d0..f2a9e595873b8 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -306,12 +306,12 @@ class GlobalPtrStmt : public Stmt { /** * An accessing tensor element operation. */ -class GlobalTensorElementStmt : public Stmt { +class ShiftGlobalPtrStmt : public Stmt { public: Stmt *origin{nullptr}; Stmt *offset{nullptr}; - GlobalTensorElementStmt(Stmt *origin, Stmt *offset) + ShiftGlobalPtrStmt(Stmt *origin, Stmt *offset) : origin(origin), offset(offset) { element_type() = origin->cast()->ret_type; TI_STMT_REG_FIELDS; diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 061c90d4613b6..0a63fb9888c77 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -166,7 +166,7 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { } } if (auto global_tensor_element = - global_store->dest->cast()) { + global_store->dest->cast()) { if (auto dest = global_tensor_element->origin->cast()) { for (auto &snode : dest->snodes.data) { meta.output_states.insert( diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 4de5f9ac9cb63..e3e7c74e86fc3 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -54,10 +54,10 @@ class FlagAccess : public IRVisitor { if (stmt->dest->is()) { stmt->dest->as()->activate = true; } - if (stmt->dest->is()) { - if (stmt->dest->as() + if 
(stmt->dest->is()) { + if (stmt->dest->as() ->origin->is()) { - stmt->dest->as() + stmt->dest->as() ->origin->as() ->activate = true; } diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index a2edcc0c49823..4bb72f6b59ac4 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -383,7 +383,7 @@ class IRPrinter : public IRVisitor { print_raw(s); } - void visit(GlobalTensorElementStmt *stmt) override { + void visit(ShiftGlobalPtrStmt *stmt) override { std::string s = fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(), stmt->name(), stmt->origin->name(), stmt->offset->name()); diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 806e8bbf9ab46..4e679d20264c3 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -157,7 +157,7 @@ class LowerAccess : public IRVisitor { } // TODO: this seems to be redundant - void visit(GlobalTensorElementStmt *stmt) override { + void visit(ShiftGlobalPtrStmt *stmt) override { if (!stmt->origin->is()) return; auto ptr = stmt->origin->as(); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index c29b32deabed6..0a21b18c0785f 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -158,7 +158,7 @@ class TypeCheck : public IRVisitor { } } - void visit(GlobalTensorElementStmt *stmt) override { + void visit(ShiftGlobalPtrStmt *stmt) override { TI_ASSERT(stmt->offset->ret_type->is_primitive(PrimitiveTypeID::i32)); stmt->ret_type.set_is_pointer(true); } From cccd8821045183e55447c7728823fd039dd7cff5 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 29 Jul 2021 09:24:27 +0000 Subject: [PATCH 31/34] Auto Format --- taichi/ir/control_flow_graph.cpp | 3 +-- taichi/transforms/flag_access.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 
cbd5f7f047182..1e8d0cdc3dd52 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -626,8 +626,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { auto stmt = nodes[i]->block->statements[j].get(); if (stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is() || - stmt->is()) { + stmt->is() || stmt->is()) { // TODO: unify them // A global pointer that may contain some data before this kernel. nodes[start_node]->reach_gen.insert(stmt); diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index e3e7c74e86fc3..29db75ffb9a86 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -55,8 +55,7 @@ class FlagAccess : public IRVisitor { stmt->dest->as()->activate = true; } if (stmt->dest->is()) { - if (stmt->dest->as() - ->origin->is()) { + if (stmt->dest->as()->origin->is()) { stmt->dest->as() ->origin->as() ->activate = true; From d9bdf429f4eef8994cba10e1fc85442b01a7e781 Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Fri, 30 Jul 2021 15:20:36 +0800 Subject: [PATCH 32/34] nit address_offset --- taichi/codegen/codegen_llvm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 6c10c344dbdbd..1a4b99d1aee2b 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1381,9 +1381,9 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { void CodeGenLLVM::visit(ShiftGlobalPtrStmt *stmt) { auto origin_address = builder->CreatePtrToInt( llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); - auto offset_address = builder->CreateSExt( + auto address_offset = builder->CreateSExt( llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); - auto target_address = builder->CreateAdd(origin_address, offset_address); + auto target_address = builder->CreateAdd(origin_address, address_offset); auto dt = stmt->ret_type.ptr_removed(); 
llvm_val[stmt] = builder->CreateIntToPtr( target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); From 3c20d4828a45cdb6e1396564fb49835e69eb86ab Mon Sep 17 00:00:00 2001 From: Yu Fang Date: Fri, 30 Jul 2021 15:25:10 +0800 Subject: [PATCH 33/34] rename ShiftGlobalPtrStmt into PtrOffsetStmt --- taichi/codegen/codegen_llvm.cpp | 2 +- taichi/codegen/codegen_llvm.h | 2 +- taichi/inc/statements.inc.h | 2 +- taichi/ir/control_flow_graph.cpp | 2 +- taichi/ir/frontend_ir.cpp | 2 +- taichi/ir/statements.h | 4 ++-- taichi/program/async_utils.cpp | 2 +- taichi/transforms/flag_access.cpp | 6 +++--- taichi/transforms/ir_printer.cpp | 2 +- taichi/transforms/lower_access.cpp | 2 +- taichi/transforms/type_check.cpp | 2 +- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 1a4b99d1aee2b..cd31437b1d392 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1378,7 +1378,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } } -void CodeGenLLVM::visit(ShiftGlobalPtrStmt *stmt) { +void CodeGenLLVM::visit(PtrOffsetStmt *stmt) { auto origin_address = builder->CreatePtrToInt( llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); auto address_offset = builder->CreateSExt( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index b987af5f0ac33..d2b99f6bde66e 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -212,7 +212,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(GlobalPtrStmt *stmt) override; - void visit(ShiftGlobalPtrStmt *stmt) override; + void visit(PtrOffsetStmt *stmt) override; void store_custom_int(llvm::Value *bit_ptr, CustomIntType *cit, diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index 45021bb5cb5de..a631fe23884c7 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -29,7 +29,7 @@ PER_STATEMENT(ReturnStmt) 
PER_STATEMENT(ArgLoadStmt) PER_STATEMENT(ExternalPtrStmt) -PER_STATEMENT(ShiftGlobalPtrStmt) +PER_STATEMENT(PtrOffsetStmt) PER_STATEMENT(ConstStmt) PER_STATEMENT(AllocaStmt) PER_STATEMENT(UnaryOpStmt) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 1e8d0cdc3dd52..609daebacb50d 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -626,7 +626,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { auto stmt = nodes[i]->block->statements[j].get(); if (stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is() || stmt->is()) { + stmt->is() || stmt->is()) { // TODO: unify them // A global pointer that may contain some data before this kernel. nodes[start_node]->reach_gen.insert(stmt); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 9de1cb7d0b79a..67246ca1d8093 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -269,7 +269,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); ctx->push_back( - std::make_unique(var_stmt, ctx->back_stmt())); + std::make_unique(var_stmt, ctx->back_stmt())); stmt = ctx->back_stmt(); } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index f2a9e595873b8..d45efe27e2337 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -306,12 +306,12 @@ class GlobalPtrStmt : public Stmt { /** * An accessing tensor element operation. 
*/ -class ShiftGlobalPtrStmt : public Stmt { +class PtrOffsetStmt : public Stmt { public: Stmt *origin{nullptr}; Stmt *offset{nullptr}; - ShiftGlobalPtrStmt(Stmt *origin, Stmt *offset) + PtrOffsetStmt(Stmt *origin, Stmt *offset) : origin(origin), offset(offset) { element_type() = origin->cast()->ret_type; TI_STMT_REG_FIELDS; diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 0a63fb9888c77..4d1b6205396b5 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -166,7 +166,7 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { } } if (auto global_tensor_element = - global_store->dest->cast()) { + global_store->dest->cast()) { if (auto dest = global_tensor_element->origin->cast()) { for (auto &snode : dest->snodes.data) { meta.output_states.insert( diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 29db75ffb9a86..1f8c7bd24ec54 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -54,9 +54,9 @@ class FlagAccess : public IRVisitor { if (stmt->dest->is()) { stmt->dest->as()->activate = true; } - if (stmt->dest->is()) { - if (stmt->dest->as()->origin->is()) { - stmt->dest->as() + if (stmt->dest->is()) { + if (stmt->dest->as()->origin->is()) { + stmt->dest->as() ->origin->as() ->activate = true; } diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 4bb72f6b59ac4..4e44a9bedf5aa 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -383,7 +383,7 @@ class IRPrinter : public IRVisitor { print_raw(s); } - void visit(ShiftGlobalPtrStmt *stmt) override { + void visit(PtrOffsetStmt *stmt) override { std::string s = fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(), stmt->name(), stmt->origin->name(), stmt->offset->name()); diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 4e679d20264c3..b83b5bd458de5 100644 --- 
a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -157,7 +157,7 @@ class LowerAccess : public IRVisitor { } // TODO: this seems to be redundant - void visit(ShiftGlobalPtrStmt *stmt) override { + void visit(PtrOffsetStmt *stmt) override { if (!stmt->origin->is()) return; auto ptr = stmt->origin->as(); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 0a21b18c0785f..943787b584464 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -158,7 +158,7 @@ class TypeCheck : public IRVisitor { } } - void visit(ShiftGlobalPtrStmt *stmt) override { + void visit(PtrOffsetStmt *stmt) override { TI_ASSERT(stmt->offset->ret_type->is_primitive(PrimitiveTypeID::i32)); stmt->ret_type.set_is_pointer(true); } From 3da4b3edffcc428270486d9c8a2d99cc85e48f4b Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Fri, 30 Jul 2021 07:27:32 +0000 Subject: [PATCH 34/34] Auto Format --- taichi/ir/frontend_ir.cpp | 3 +-- taichi/ir/statements.h | 3 +-- taichi/transforms/flag_access.cpp | 5 ++--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 67246ca1d8093..5cc1f1e6ae8f8 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -268,8 +268,7 @@ void GlobalTensorElementExpression::flatten(FlattenContext *ctx) { ctx->push_back( Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); - ctx->push_back( - std::make_unique(var_stmt, ctx->back_stmt())); + ctx->push_back(std::make_unique(var_stmt, ctx->back_stmt())); stmt = ctx->back_stmt(); } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index d45efe27e2337..31cbaf14fd937 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -311,8 +311,7 @@ class PtrOffsetStmt : public Stmt { Stmt *origin{nullptr}; Stmt *offset{nullptr}; - PtrOffsetStmt(Stmt *origin, Stmt *offset) - : origin(origin), offset(offset) { + PtrOffsetStmt(Stmt *origin, 
Stmt *offset) : origin(origin), offset(offset) { element_type() = origin->cast()->ret_type; TI_STMT_REG_FIELDS; } diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 1f8c7bd24ec54..d4cea6eac92f5 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -56,9 +56,8 @@ class FlagAccess : public IRVisitor { } if (stmt->dest->is()) { if (stmt->dest->as()->origin->is()) { - stmt->dest->as() - ->origin->as() - ->activate = true; + stmt->dest->as()->origin->as()->activate = + true; } } }