diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index aa2197c854360..bc2649fa271bc 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -499,12 +499,8 @@ class Stmt : public IRNode { Stmt(); Stmt(const Stmt &stmt); - int &width() { - return ret_type.width; - } - - const int &width() const { - return ret_type.width; + int width() const { + return ret_type->vector_width(); } virtual bool is_container_statement() const { @@ -558,7 +554,8 @@ class Stmt : public IRNode { IRNode *get_parent() const override; virtual void repeat(int factor) { - ret_type.width *= factor; + TI_ASSERT(factor == 1); + // ret_type.width *= factor; } // returns the inserted stmt diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 29b859359456b..e921481ae4bc5 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -45,7 +45,7 @@ ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, TI_ASSERT(base_ptrs[i] != nullptr); TI_ASSERT(base_ptrs[i]->is()); } - width() = base_ptrs.size(); + TI_ASSERT(base_ptrs.size() == 1); element_type() = dt; TI_STMT_REG_FIELDS; } @@ -58,7 +58,7 @@ GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute &snodes, TI_ASSERT(snodes[i] != nullptr); TI_ASSERT(snodes[0]->dt == snodes[i]->dt); } - width() = snodes.size(); + TI_ASSERT(snodes.size() == 1); element_type() = snodes[0]->dt; TI_STMT_REG_FIELDS; } @@ -90,7 +90,6 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, Stmt *ptr, Stmt *val) : op_type(op_type), snode(snode), ptr(ptr), val(val) { - width() = 1; element_type() = PrimitiveType::i32; TI_STMT_REG_FIELDS; } @@ -104,7 +103,6 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, TI_ASSERT(op_type == SNodeOpType::is_active || op_type == SNodeOpType::deactivate || op_type == SNodeOpType::activate); - width() = 1; element_type() = PrimitiveType::i32; TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 0ecf3c666073f..685d94890dd2c 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -485,9 +485,9 @@ class ConstStmt : public Stmt { LaneAttribute val; ConstStmt(const LaneAttribute &val) : val(val) { - width() = val.size(); - element_type() = val[0].dt; - for (int i = 0; i < ret_type.width; i++) { + TI_ASSERT(val.size() == 1); // TODO: support vectorized case + ret_type = val[0].dt; + for (int i = 0; i < val.size(); i++) { TI_ASSERT(val[0].dt == val[i].dt); } TI_STMT_REG_FIELDS; @@ -660,8 +660,8 @@ class ElementShuffleStmt : public Stmt { ElementShuffleStmt(const LaneAttribute &elements, bool pointer = false) : elements(elements), pointer(pointer) { - width() = elements.size(); - element_type() = elements[0].stmt->element_type(); + TI_ASSERT(elements.size() == 1); // TODO: support vectorized cases + ret_type = elements[0].stmt->element_type(); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index fd65cbf1cbe40..c06341586573a 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/type.h" + #include "taichi/program/program.h" TLANG_NAMESPACE_BEGIN @@ -67,6 +68,14 @@ std::string PrimitiveType::to_string() const { return data_type_name(DataType(const_cast(this))); } +int Type::vector_width() const { + if (auto vec = cast()) { + return vec->get_num_elements(); + } else { + return 1; + } +} + DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) { TI_ASSERT(width == 1); if (is_pointer) { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index efcb14323c015..ba558a30b39f4 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -30,6 +30,8 @@ class Type { return p; } + int vector_width() const; + virtual ~Type() { } }; @@ -66,7 +68,6 @@ class DataType { // Temporary API and members // for LegacyVectorType-compatibility - int width{1}; Type *operator->() const { return ptr_; diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 1f7300857c3b3..2c77eac504cb8 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -361,11 +361,14 @@ class TypePromotionMapping { TI_WARN("promoted_type got a pointer input."); } + if (d->is()) { + d = d->as()->get_element_type(); + TI_WARN("promoted_type got a vector input."); + } + auto primitive = d->cast(); - TI_ASSERT_INFO( - primitive, - "Failed to get primitive type! " - "Consider adding `ti.init()` to the first line of your program."); + TI_ASSERT_INFO(primitive, "Failed to get primitive type from {}", + d->to_string()); return primitive->type; }; }; diff --git a/taichi/transforms/loop_vectorize.cpp b/taichi/transforms/loop_vectorize.cpp index e2cf8f6f9f227..85ce94e4f0136 100644 --- a/taichi/transforms/loop_vectorize.cpp +++ b/taichi/transforms/loop_vectorize.cpp @@ -1,6 +1,8 @@ // The loop vectorizer +#include "taichi/program/program.h" #include "taichi/ir/ir.h" +#include "taichi/ir/type_factory.h" #include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" @@ -21,13 +23,19 @@ class LoopVectorize : public IRVisitor { vectorize = 1; } + static void widen_type(DataType &type, int width) { + if (width != 1) { + type = Program::get_type_factory().get_vector_type(width, type.get_ptr()); + } + } + void visit(Stmt *stmt) override { - stmt->ret_type.width *= vectorize; + widen_type(stmt->ret_type, vectorize); } void visit(ConstStmt *stmt) override { stmt->val.repeat(vectorize); - stmt->ret_type.width *= vectorize; + widen_type(stmt->ret_type, vectorize); } void visit(Block *stmt_list) override { @@ -42,11 +50,11 @@ class LoopVectorize : public IRVisitor { void visit(GlobalPtrStmt *ptr) override { ptr->snodes.repeat(vectorize); - ptr->width() *= vectorize; + widen_type(ptr->ret_type, vectorize); } void visit(AllocaStmt *alloca) override { - alloca->ret_type.width *= vectorize; + widen_type(alloca->ret_type, vectorize); } void visit(SNodeOpStmt *stmt) override { @@ -63,7 +71,7 @@ class LoopVectorize : public IRVisitor { if (vectorize == 1) return; int original_width = stmt->width(); - stmt->ret_type.width *= vectorize; + widen_type(stmt->ret_type, vectorize); stmt->elements.repeat(vectorize); // TODO: this can be buggy int stride = stmt->elements[original_width - 1].index + 1; @@ -80,7 +88,7 @@ class LoopVectorize : public IRVisitor { if (vectorize == 1) return; int original_width = stmt->width(); - stmt->ret_type.width *= vectorize; + widen_type(stmt->ret_type, vectorize); stmt->ptr.repeat(vectorize); // TODO: this can be buggy int stride = stmt->ptr[original_width - 1].offset + 1; diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 461ed943bbae4..5478a8ef55c3d 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -259,7 +259,7 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { } std::size_t allocate_global(DataType type) { - TI_ASSERT(type.width == 1); + TI_ASSERT(type->vector_width() == 1); auto ret = global_offset; global_offset += data_type_size(type); TI_ASSERT(global_offset < taichi_global_tmp_buffer_size); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index dabf894133c52..df167361d41ae 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -142,7 +142,7 @@ class TypeCheck : public IRVisitor { stmt->indices[i] = insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32); } - TI_ASSERT(stmt->indices[i]->ret_type.width == stmt->snodes.size()); + TI_ASSERT(stmt->indices[i]->width() == stmt->snodes.size()); } } @@ -272,8 +272,7 @@ class TypeCheck : public IRVisitor { } } bool matching = true; - matching = - matching && (stmt->lhs->ret_type.width == stmt->rhs->ret_type.width); + matching = matching && (stmt->lhs->width() == stmt->rhs->width()); matching = matching && (stmt->lhs->ret_type != PrimitiveType::unknown); matching = matching && (stmt->rhs->ret_type != PrimitiveType::unknown); matching = matching && (stmt->lhs->ret_type == stmt->rhs->ret_type); @@ -286,8 +285,7 @@ class TypeCheck : public IRVisitor { } } if (is_comparison(stmt->op_type)) { - stmt->ret_type = - LegacyVectorType(stmt->lhs->ret_type.width, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(stmt->lhs->width(), PrimitiveType::i32); } else { stmt->ret_type = stmt->lhs->ret_type; } @@ -297,8 +295,10 @@ class TypeCheck : public IRVisitor { if (stmt->op_type == TernaryOpType::select) { auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type); TI_ASSERT(stmt->op1->ret_type == PrimitiveType::i32) - TI_ASSERT(stmt->op1->ret_type.width == stmt->op2->ret_type.width); - TI_ASSERT(stmt->op2->ret_type.width == stmt->op3->ret_type.width); + TI_ASSERT(stmt->op1->ret_type->vector_width() == + stmt->op2->ret_type->vector_width()); + TI_ASSERT(stmt->op2->ret_type->vector_width() == + stmt->op3->ret_type->vector_width()); if (ret_type != stmt->op2->ret_type) { auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type); stmt->op2 = cast_stmt; @@ -329,7 +329,7 @@ class TypeCheck : public IRVisitor { // defined by the kernel. After that, type_check() pass will purely do // verification, without modifying any types. TI_ASSERT(rt != PrimitiveType::unknown); - TI_ASSERT(rt.width == 1); + TI_ASSERT(rt->vector_width() == 1); stmt->ret_type.set_is_pointer(stmt->is_ptr); } @@ -337,7 +337,7 @@ class TypeCheck : public IRVisitor { // TODO: Support stmt->ret_id? const auto &rt = stmt->ret_type; TI_ASSERT(stmt->value->element_type() == rt); - TI_ASSERT(rt.width == 1); + TI_ASSERT(rt->vector_width() == 1); } void visit(ExternalPtrStmt *stmt) { diff --git a/taichi/transforms/vector_split.cpp b/taichi/transforms/vector_split.cpp index 07da12cbb73e1..4a93f1938b6df 100644 --- a/taichi/transforms/vector_split.cpp +++ b/taichi/transforms/vector_split.cpp @@ -4,6 +4,8 @@ #include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" +#include "taichi/program/program.h" + #include TLANG_NAMESPACE_BEGIN @@ -59,8 +61,9 @@ class BasicBlockVectorSplit : public IRVisitor { stmt->accept(this); origin2split[stmt] = std::vector(current_split_factor, nullptr); for (int j = 0; j < current_split_factor; j++) { - current_split[j]->element_type() = stmt->element_type(); - current_split[j]->width() = max_width; + current_split[j]->ret_type = + Program::get_type_factory().get_vector_type( + max_width, stmt->element_type().get_ptr()); origin2split[stmt][j] = current_split[j].get(); } splits.push_back(std::move(current_split)); @@ -71,8 +74,10 @@ class BasicBlockVectorSplit : public IRVisitor { need_split = false; stmt->accept(this); origin2split[stmt] = std::vector(1, nullptr); - current_split[0]->width() = stmt->width(); current_split[0]->element_type() = stmt->element_type(); + current_split[0]->ret_type = + Program::get_type_factory().get_vector_type( + stmt->width(), stmt->element_type().get_ptr()); origin2split[stmt][0] = current_split[0].get(); std::vector split; split.push_back(std::move(current_split[0]));