Skip to content

Commit

Permalink
[type] Remove DataType::width (#1962)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Oct 16, 2020
1 parent 6fea4c2 commit 6cbe537
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 40 deletions.
11 changes: 4 additions & 7 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,8 @@ class Stmt : public IRNode {
Stmt();
Stmt(const Stmt &stmt);

int &width() {
return ret_type.width;
}

const int &width() const {
return ret_type.width;
int width() const {
return ret_type->vector_width();
}

virtual bool is_container_statement() const {
Expand Down Expand Up @@ -558,7 +554,8 @@ class Stmt : public IRNode {
IRNode *get_parent() const override;

virtual void repeat(int factor) {
ret_type.width *= factor;
TI_ASSERT(factor == 1);
// ret_type.width *= factor;
}

// returns the inserted stmt
Expand Down
6 changes: 2 additions & 4 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
TI_ASSERT(base_ptrs[i] != nullptr);
TI_ASSERT(base_ptrs[i]->is<ArgLoadStmt>());
}
width() = base_ptrs.size();
TI_ASSERT(base_ptrs.size() == 1);
element_type() = dt;
TI_STMT_REG_FIELDS;
}
Expand All @@ -58,7 +58,7 @@ GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
TI_ASSERT(snodes[i] != nullptr);
TI_ASSERT(snodes[0]->dt == snodes[i]->dt);
}
width() = snodes.size();
TI_ASSERT(snodes.size() == 1);
element_type() = snodes[0]->dt;
TI_STMT_REG_FIELDS;
}
Expand Down Expand Up @@ -90,7 +90,6 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type,
Stmt *ptr,
Stmt *val)
: op_type(op_type), snode(snode), ptr(ptr), val(val) {
width() = 1;
element_type() = PrimitiveType::i32;
TI_STMT_REG_FIELDS;
}
Expand All @@ -104,7 +103,6 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type,
TI_ASSERT(op_type == SNodeOpType::is_active ||
op_type == SNodeOpType::deactivate ||
op_type == SNodeOpType::activate);
width() = 1;
element_type() = PrimitiveType::i32;
TI_STMT_REG_FIELDS;
}
Expand Down
10 changes: 5 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,9 @@ class ConstStmt : public Stmt {
LaneAttribute<TypedConstant> val;

ConstStmt(const LaneAttribute<TypedConstant> &val) : val(val) {
width() = val.size();
element_type() = val[0].dt;
for (int i = 0; i < ret_type.width; i++) {
TI_ASSERT(val.size() == 1); // TODO: support vectorized case
ret_type = val[0].dt;
for (int i = 0; i < val.size(); i++) {
TI_ASSERT(val[0].dt == val[i].dt);
}
TI_STMT_REG_FIELDS;
Expand Down Expand Up @@ -660,8 +660,8 @@ class ElementShuffleStmt : public Stmt {
ElementShuffleStmt(const LaneAttribute<VectorElement> &elements,
bool pointer = false)
: elements(elements), pointer(pointer) {
width() = elements.size();
element_type() = elements[0].stmt->element_type();
TI_ASSERT(elements.size() == 1); // TODO: support vectorized cases
ret_type = elements[0].stmt->element_type();
TI_STMT_REG_FIELDS;
}

Expand Down
9 changes: 9 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "taichi/ir/type.h"

#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -67,6 +68,14 @@ std::string PrimitiveType::to_string() const {
return data_type_name(DataType(const_cast<PrimitiveType *>(this)));
}

int Type::vector_width() const {
if (auto vec = cast<VectorType>()) {
return vec->get_num_elements();
} else {
return 1;
}
}

DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) {
TI_ASSERT(width == 1);
if (is_pointer) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class Type {
return p;
}

int vector_width() const;

virtual ~Type() {
}
};
Expand Down Expand Up @@ -66,7 +68,6 @@ class DataType {

// Temporary API and members
// for LegacyVectorType-compatibility
int width{1};

Type *operator->() const {
return ptr_;
Expand Down
11 changes: 7 additions & 4 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,14 @@ class TypePromotionMapping {
TI_WARN("promoted_type got a pointer input.");
}

if (d->is<VectorType>()) {
d = d->as<VectorType>()->get_element_type();
TI_WARN("promoted_type got a vector input.");
}

auto primitive = d->cast<PrimitiveType>();
TI_ASSERT_INFO(
primitive,
"Failed to get primitive type! "
"Consider adding `ti.init()` to the first line of your program.");
TI_ASSERT_INFO(primitive, "Failed to get primitive type from {}",
d->to_string());
return primitive->type;
};
};
Expand Down
20 changes: 14 additions & 6 deletions taichi/transforms/loop_vectorize.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// The loop vectorizer

#include "taichi/program/program.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/type_factory.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
Expand All @@ -21,13 +23,19 @@ class LoopVectorize : public IRVisitor {
vectorize = 1;
}

static void widen_type(DataType &type, int width) {
if (width != 1) {
type = Program::get_type_factory().get_vector_type(width, type.get_ptr());
}
}

void visit(Stmt *stmt) override {
stmt->ret_type.width *= vectorize;
widen_type(stmt->ret_type, vectorize);
}

void visit(ConstStmt *stmt) override {
stmt->val.repeat(vectorize);
stmt->ret_type.width *= vectorize;
widen_type(stmt->ret_type, vectorize);
}

void visit(Block *stmt_list) override {
Expand All @@ -42,11 +50,11 @@ class LoopVectorize : public IRVisitor {

void visit(GlobalPtrStmt *ptr) override {
ptr->snodes.repeat(vectorize);
ptr->width() *= vectorize;
widen_type(ptr->ret_type, vectorize);
}

void visit(AllocaStmt *alloca) override {
alloca->ret_type.width *= vectorize;
widen_type(alloca->ret_type, vectorize);
}

void visit(SNodeOpStmt *stmt) override {
Expand All @@ -63,7 +71,7 @@ class LoopVectorize : public IRVisitor {
if (vectorize == 1)
return;
int original_width = stmt->width();
stmt->ret_type.width *= vectorize;
widen_type(stmt->ret_type, vectorize);
stmt->elements.repeat(vectorize);
// TODO: this can be buggy
int stride = stmt->elements[original_width - 1].index + 1;
Expand All @@ -80,7 +88,7 @@ class LoopVectorize : public IRVisitor {
if (vectorize == 1)
return;
int original_width = stmt->width();
stmt->ret_type.width *= vectorize;
widen_type(stmt->ret_type, vectorize);
stmt->ptr.repeat(vectorize);
// TODO: this can be buggy
int stride = stmt->ptr[original_width - 1].offset + 1;
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
}

std::size_t allocate_global(DataType type) {
TI_ASSERT(type.width == 1);
TI_ASSERT(type->vector_width() == 1);
auto ret = global_offset;
global_offset += data_type_size(type);
TI_ASSERT(global_offset < taichi_global_tmp_buffer_size);
Expand Down
18 changes: 9 additions & 9 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class TypeCheck : public IRVisitor {
stmt->indices[i] =
insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32);
}
TI_ASSERT(stmt->indices[i]->ret_type.width == stmt->snodes.size());
TI_ASSERT(stmt->indices[i]->width() == stmt->snodes.size());
}
}

Expand Down Expand Up @@ -272,8 +272,7 @@ class TypeCheck : public IRVisitor {
}
}
bool matching = true;
matching =
matching && (stmt->lhs->ret_type.width == stmt->rhs->ret_type.width);
matching = matching && (stmt->lhs->width() == stmt->rhs->width());
matching = matching && (stmt->lhs->ret_type != PrimitiveType::unknown);
matching = matching && (stmt->rhs->ret_type != PrimitiveType::unknown);
matching = matching && (stmt->lhs->ret_type == stmt->rhs->ret_type);
Expand All @@ -286,8 +285,7 @@ class TypeCheck : public IRVisitor {
}
}
if (is_comparison(stmt->op_type)) {
stmt->ret_type =
LegacyVectorType(stmt->lhs->ret_type.width, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(stmt->lhs->width(), PrimitiveType::i32);
} else {
stmt->ret_type = stmt->lhs->ret_type;
}
Expand All @@ -297,8 +295,10 @@ class TypeCheck : public IRVisitor {
if (stmt->op_type == TernaryOpType::select) {
auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type);
TI_ASSERT(stmt->op1->ret_type == PrimitiveType::i32)
TI_ASSERT(stmt->op1->ret_type.width == stmt->op2->ret_type.width);
TI_ASSERT(stmt->op2->ret_type.width == stmt->op3->ret_type.width);
TI_ASSERT(stmt->op1->ret_type->vector_width() ==
stmt->op2->ret_type->vector_width());
TI_ASSERT(stmt->op2->ret_type->vector_width() ==
stmt->op3->ret_type->vector_width());
if (ret_type != stmt->op2->ret_type) {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type);
stmt->op2 = cast_stmt;
Expand Down Expand Up @@ -329,15 +329,15 @@ class TypeCheck : public IRVisitor {
// defined by the kernel. After that, type_check() pass will purely do
// verification, without modifying any types.
TI_ASSERT(rt != PrimitiveType::unknown);
TI_ASSERT(rt.width == 1);
TI_ASSERT(rt->vector_width() == 1);
stmt->ret_type.set_is_pointer(stmt->is_ptr);
}

void visit(KernelReturnStmt *stmt) {
// TODO: Support stmt->ret_id?
const auto &rt = stmt->ret_type;
TI_ASSERT(stmt->value->element_type() == rt);
TI_ASSERT(rt.width == 1);
TI_ASSERT(rt->vector_width() == 1);
}

void visit(ExternalPtrStmt *stmt) {
Expand Down
11 changes: 8 additions & 3 deletions taichi/transforms/vector_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"

#include <algorithm>

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -59,8 +61,9 @@ class BasicBlockVectorSplit : public IRVisitor {
stmt->accept(this);
origin2split[stmt] = std::vector<Stmt *>(current_split_factor, nullptr);
for (int j = 0; j < current_split_factor; j++) {
current_split[j]->element_type() = stmt->element_type();
current_split[j]->width() = max_width;
current_split[j]->ret_type =
Program::get_type_factory().get_vector_type(
max_width, stmt->element_type().get_ptr());
origin2split[stmt][j] = current_split[j].get();
}
splits.push_back(std::move(current_split));
Expand All @@ -71,8 +74,10 @@ class BasicBlockVectorSplit : public IRVisitor {
need_split = false;
stmt->accept(this);
origin2split[stmt] = std::vector<Stmt *>(1, nullptr);
current_split[0]->width() = stmt->width();
current_split[0]->element_type() = stmt->element_type();
current_split[0]->ret_type =
Program::get_type_factory().get_vector_type(
stmt->width(), stmt->element_type().get_ptr());
origin2split[stmt][0] = current_split[0].get();
std::vector<pStmt> split;
split.push_back(std::move(current_split[0]));
Expand Down

0 comments on commit 6cbe537

Please sign in to comment.