[IR] Init GlobalTensorElementExpression and PtrOffsetStmt (#2543)
* init GlobalTensorElementExpression and GlobalTensorElementExpr

* fix compile error for global load

* fix cfg optimization

* fix offset=4 for i32 and f32

* support matrix index access, AOS/SOA, and arbitrary data types

* Auto Format

* do better serialize() for GlobalTensorElementExpression and GlobalTensorElementStmt

* simplify make<ConstStmt>

* fix ti test

* fix test and add load_if_ptr for dynamic indices

* fix test

* fix more tests

* fix more tests

* fix test

* fix test

* update test

* merge condition in load_if_ptr

* replace is_AOS with is_aos

* remove is_bit_vectorized

* add default value for GlobalTensorElementStmt

* add default value for GlobalTensorElementExpression

* update computation for field size

* Update taichi/ir/expr.cpp

Co-authored-by: xumingkuan <xumingkuan0721@126.com>

* comment is_global_ptr()

* nit

* Auto Format

* add extension and fix tests

* Auto Format

* do actual type check

* rename GlobalTensorElementStmt into ShiftGlobalPtrStmt

* Auto Format

* nit address_offset

* rename ShiftGlobalPtrStmt into PtrOffsetStmt

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
Co-authored-by: xumingkuan <xumingkuan0721@126.com>
3 people authored Jul 31, 2021
1 parent 7a2d5d1 commit e71f2eb
Showing 20 changed files with 242 additions and 26 deletions.
7 changes: 7 additions & 0 deletions python/taichi/lang/impl.py
@@ -154,6 +154,13 @@ def subscript(value, *indices):
        return value[indices]


@taichi_scope
def subscript_with_offset(var, indices, cols, is_aos):
    return Expr(
        _ti_core.subscript_with_offset(var.ptr, make_expr_group(*indices),
                                       cols, is_aos))


@taichi_scope
def chain_compare(comparators, ops):
    _taichi_skip_traceback = 1
10 changes: 9 additions & 1 deletion python/taichi/lang/matrix.py
@@ -297,7 +297,15 @@ def subscript(self, *indices):
         assert len(indices) in [1, 2]
         i = indices[0]
         j = 0 if len(indices) == 1 else indices[1]
-        return self(i, j)
+        # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()).
+        if isinstance(self.entries[0],
+                      ti.Expr) and self.entries[0].ptr.is_global_ptr(
+                      ) and ti.is_extension_supported(
+                          ti.cfg.arch, ti.extension.dynamic_index):
+            return ti.subscript_with_offset(self.entries[0], (i, j),
+                                            self.m, True)
+        else:
+            return self(i, j)

     @property
     def x(self):
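A minimal sketch of the user-facing behavior this path enables (illustrative only, not part of the commit; the field and kernel names are made up, and it assumes a backend that supports the dynamic_index extension, per extension.cpp below):

    import taichi as ti

    ti.init(arch=ti.cpu)  # x64 lists dynamic_index below; arm64/cuda do too

    mat = ti.Matrix.field(2, 2, dtype=ti.f32, shape=8)

    @ti.kernel
    def fill(k: ti.i32):
        for i in range(8):
            # k is only known at run time, so this subscript cannot be
            # unrolled at compile time; it takes the subscript_with_offset
            # path added above instead.
            mat[i][k, k] = i * 1.0

    fill(1)
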
11 changes: 11 additions & 0 deletions taichi/codegen/codegen_llvm.cpp
@@ -1378,6 +1378,17 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
}
}

void CodeGenLLVM::visit(PtrOffsetStmt *stmt) {
auto origin_address = builder->CreatePtrToInt(
llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context));
auto address_offset = builder->CreateSExt(
llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context));
auto target_address = builder->CreateAdd(origin_address, address_offset);
auto dt = stmt->ret_type.ptr_removed();
llvm_val[stmt] = builder->CreateIntToPtr(
target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0));
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
TI_ASSERT(stmt->width() == 1);

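The visitor above computes target = inttoptr(ptrtoint(origin) + sext(offset)). A plain-Python sketch of that arithmetic (illustrative only; ptr_offset is a made-up name, not code from this commit):

    def ptr_offset(origin_addr: int, offset_i32: int) -> int:
        # CreateSExt: interpret the 32-bit byte offset as signed before
        # widening it to i64.
        if offset_i32 >= 1 << 31:
            offset_i32 -= 1 << 32
        # CreateAdd on i64 wraps modulo 2**64, like LLVM integer adds.
        return (origin_addr + offset_i32) % (1 << 64)
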
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
@@ -212,6 +212,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(GlobalPtrStmt *stmt) override;

void visit(PtrOffsetStmt *stmt) override;

void store_custom_int(llvm::Value *bit_ptr,
CustomIntType *cit,
llvm::Value *value,
2 changes: 2 additions & 0 deletions taichi/inc/extensions.inc.h
@@ -9,3 +9,5 @@ PER_EXTENSION(bls) // Block-local storage
PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels
PER_EXTENSION(extfunc) // Invoke external functions or backend source
PER_EXTENSION(packed) // Shape will not be padded to a power of two
PER_EXTENSION(
dynamic_index) // Dynamic index support for both global and local tensors
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
@@ -29,6 +29,7 @@ PER_STATEMENT(ReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ExternalPtrStmt)
PER_STATEMENT(PtrOffsetStmt)
PER_STATEMENT(ConstStmt)
PER_STATEMENT(AllocaStmt)
PER_STATEMENT(UnaryOpStmt)
2 changes: 1 addition & 1 deletion taichi/ir/control_flow_graph.cpp
@@ -626,7 +626,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
       auto stmt = nodes[i]->block->statements[j].get();
       if (stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
           stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
-          stmt->is<GlobalTemporaryStmt>()) {
+          stmt->is<GlobalTemporaryStmt>() || stmt->is<PtrOffsetStmt>()) {
         // TODO: unify them
         // A global pointer that may contain some data before this kernel.
         nodes[start_node]->reach_gen.insert(stmt);
6 changes: 4 additions & 2 deletions taichi/ir/expr.cpp
@@ -161,7 +161,8 @@ void Expr::operator/=(const Expr &o) {
 }

 Expr load_if_ptr(const Expr &ptr) {
-  if (ptr.is<GlobalPtrExpression>()) {
+  if (ptr.is<GlobalPtrExpression>() ||
+      ptr.is<GlobalTensorElementExpression>()) {
     return load(ptr);
   } else if (ptr.is<GlobalVariableExpression>()) {
     TI_ASSERT(ptr.cast<GlobalVariableExpression>()->snode->num_active_indices ==
@@ -172,7 +173,8 @@ Expr load_if_ptr(const Expr &ptr) {
 }

 Expr load(const Expr &ptr) {
-  TI_ASSERT(ptr.is<GlobalPtrExpression>());
+  TI_ASSERT(ptr.is<GlobalPtrExpression>() ||
+            ptr.is<GlobalTensorElementExpression>());
   return Expr::make<GlobalLoadExpression>(ptr);
 }

57 changes: 57 additions & 0 deletions taichi/ir/frontend_ir.cpp
@@ -219,6 +219,59 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void GlobalTensorElementExpression::flatten(FlattenContext *ctx) {
TI_ASSERT(var.is<GlobalPtrExpression>())
var->flatten(ctx);
Stmt *var_stmt = ctx->back_stmt();
SNode *snode = var.cast<GlobalPtrExpression>()
->var.cast<GlobalVariableExpression>()
->snode;
// Compute exact offset
// Type A[i, j][x, y]
// ^^^^
TI_ASSERT(1 <= indices.size() && indices.size() <= 2)
if (indices.size() == 1) {
indices[0].set(load_if_ptr(indices[0]));
indices[0]->flatten(ctx);
} else {
indices[0].set(load_if_ptr(indices[0]));
indices[0]->flatten(ctx);
Stmt *i_stmt = ctx->back_stmt();
Stmt *cols_stmt =
ctx->push_back(Stmt::make<ConstStmt>(TypedConstant(cols)));
Stmt *i_mul_cols_stmt = ctx->push_back(
Stmt::make<BinaryOpStmt>(BinaryOpType::mul, i_stmt, cols_stmt));
indices[1].set(load_if_ptr(indices[1]));
indices[1]->flatten(ctx);
Stmt *j_stmt = ctx->back_stmt();
ctx->push_back(
Stmt::make<BinaryOpStmt>(BinaryOpType::add, i_mul_cols_stmt, j_stmt));
}
// Type A[i, j][x, y]
// ^ ^
if (!is_aos) {
TI_ASSERT(snode->is_path_all_dense)
int size = 1;
for (auto *s = snode; s != nullptr; s = s->parent)
size *= (int)s->max_num_elements();
Stmt *offset_stmt = ctx->back_stmt();
Stmt *field_size_stmt =
ctx->push_back(Stmt::make<ConstStmt>(TypedConstant(size)));
ctx->push_back(Stmt::make<BinaryOpStmt>(BinaryOpType::mul, offset_stmt,
field_size_stmt));
}
// Type A[i, j][x, y]
// ^^^^
Stmt *offset_stmt = ctx->back_stmt();
Stmt *dt_size_stmt = ctx->push_back(
Stmt::make<ConstStmt>(TypedConstant(data_type_size(snode->dt))));
ctx->push_back(
Stmt::make<BinaryOpStmt>(BinaryOpType::mul, offset_stmt, dt_size_stmt));

ctx->push_back(std::make_unique<PtrOffsetStmt>(var_stmt, ctx->back_stmt()));
stmt = ctx->back_stmt();
}

void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
input->flatten(ctx);
base->flatten(ctx);
@@ -297,6 +350,10 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
// emit local store stmt
auto alloca = ctx->current_block->lookup_var(dest.cast<IdExpression>()->id);
ctx->push_back<AtomicOpStmt>(op_type, alloca, expr->stmt);
} else if (dest.is<GlobalTensorElementExpression>()) {
auto global_ptr = dest.cast<GlobalTensorElementExpression>();
global_ptr->flatten(ctx);
ctx->push_back<AtomicOpStmt>(op_type, ctx->back_stmt(), expr->stmt);
} else { // global variable
TI_ASSERT(dest.is<GlobalPtrExpression>());
auto global_ptr = dest.cast<GlobalPtrExpression>();
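The flatten() above builds the byte offset in three steps: linearize the in-matrix index as x * cols + y, stride it by the field size for SOA layouts (the product of max_num_elements along the SNode path, i.e. the number of elements in the field), and scale by the element size. The same computation in plain Python (an illustrative sketch; the function and parameter names are made up, not code from this commit):

    def element_byte_offset(x, y, cols, dtype_size, field_size, is_aos):
        linear = x * cols + y        # Type A[i, j][x, y]: linearize (x, y)
        if not is_aos:
            linear *= field_size     # SOA: components strided by field size
        return linear * dtype_size   # scale by sizeof(element type)

    # Element (1, 1) of a 2x2 f32 matrix (dtype_size = 4) in a field of 8:
    assert element_byte_offset(1, 1, 2, 4, 8, is_aos=True) == 12
    assert element_byte_offset(1, 1, 2, 4, 8, is_aos=False) == 96
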
33 changes: 33 additions & 0 deletions taichi/ir/frontend_ir.h
@@ -430,6 +430,39 @@ class GlobalPtrExpression : public Expression {
}
};

class GlobalTensorElementExpression : public Expression {
public:
Expr var;
ExprGroup indices;
int cols{0};
bool is_aos{false};

GlobalTensorElementExpression(const Expr &var,
const ExprGroup &indices,
int cols,
bool is_aos)
: var(var), indices(indices), cols(cols), is_aos(is_aos) {
}

std::string serialize() override {
std::string s = fmt::format("{}[", var.serialize());
for (int i = 0; i < (int)indices.size(); i++) {
s += indices.exprs[i]->serialize();
if (i + 1 < (int)indices.size())
s += ", ";
}
s += "]";
s += " (col=" + std::to_string(cols) + (is_aos ? ", AOS)" : ", SOA)");
return s;
}

void flatten(FlattenContext *ctx) override;

bool is_lvalue() const override {
return true;
}
};

class EvalExpression : public Expression {
public:
Stmt *stmt_ptr;
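With this serialize(), an AOS access to element (x, y) of a two-column matrix in field A prints schematically as A[i, j][x, y] (col=2, AOS).
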
24 changes: 24 additions & 0 deletions taichi/ir/statements.h
@@ -303,6 +303,30 @@ class GlobalPtrStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* An operation that accesses a tensor element: a pointer shifted by a byte offset.
*/
class PtrOffsetStmt : public Stmt {
public:
Stmt *origin{nullptr};
Stmt *offset{nullptr};

PtrOffsetStmt(Stmt *origin, Stmt *offset) : origin(origin), offset(offset) {
element_type() = origin->cast<GlobalPtrStmt>()->ret_type;
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
// After access lowering, activation info is recorded in SNodeLookupStmt's
// activate flag for AOS sparse data structures. We don't support SOA
// sparse data structures for now.
return false;
}

TI_STMT_DEF_FIELDS(ret_type, origin, offset);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* An operation to a SNode (not necessarily a leaf SNode).
*/
9 changes: 9 additions & 0 deletions taichi/program/async_utils.cpp
@@ -165,6 +165,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
if (auto global_tensor_element =
global_store->dest->cast<PtrOffsetStmt>()) {
if (auto dest = global_tensor_element->origin->cast<GlobalPtrStmt>()) {
for (auto &snode : dest->snodes.data) {
meta.output_states.insert(
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
}
}
if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
if (auto dest = global_atomic->dest->cast<GlobalPtrStmt>()) {
8 changes: 5 additions & 3 deletions taichi/program/extension.cpp
@@ -11,15 +11,17 @@ bool is_extension_supported(Arch arch, Extension ext) {
       {Arch::x64,
        {Extension::sparse, Extension::async_mode, Extension::quant,
         Extension::quant_basic, Extension::data64, Extension::adstack,
-        Extension::assertion, Extension::extfunc, Extension::packed}},
+        Extension::assertion, Extension::extfunc, Extension::packed,
+        Extension::dynamic_index}},
       {Arch::arm64,
        {Extension::sparse, Extension::async_mode, Extension::quant,
         Extension::quant_basic, Extension::data64, Extension::adstack,
-        Extension::assertion, Extension::packed}},
+        Extension::assertion, Extension::packed, Extension::dynamic_index}},
       {Arch::cuda,
        {Extension::sparse, Extension::async_mode, Extension::quant,
         Extension::quant_basic, Extension::data64, Extension::adstack,
-        Extension::bls, Extension::assertion, Extension::packed}},
+        Extension::bls, Extension::assertion, Extension::packed,
+        Extension::dynamic_index}},
       {Arch::metal,
        {Extension::adstack, Extension::assertion, Extension::quant_basic,
         Extension::async_mode, Extension::sparse}},
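The guard used in Matrix.subscript above can be reproduced directly from Python to probe a backend (an illustrative sketch, not part of the commit):

    import taichi as ti

    ti.init(arch=ti.cpu)
    # True on x64/arm64/cuda after this change; False on backends not
    # listed with dynamic_index, e.g. metal.
    print(ti.is_extension_supported(ti.cfg.arch, ti.extension.dynamic_index))
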
8 changes: 8 additions & 0 deletions taichi/python/export_lang.cpp
@@ -371,6 +371,8 @@ void export_lang(py::module &m) {
[](Expr *expr) { return expr->is<GlobalVariableExpression>(); })
.def("is_external_var",
[](Expr *expr) { return expr->is<ExternalTensorExpression>(); })
.def("is_global_ptr",
[](Expr *expr) { return expr->is<GlobalPtrExpression>(); })
.def("is_primal",
[](Expr *expr) {
return expr->cast<GlobalVariableExpression>()->is_primal;
@@ -698,6 +700,12 @@ void export_lang(py::module &m) {
return expr[expr_group];
});

m.def("subscript_with_offset",
[](const Expr &var, const ExprGroup &indices, int cols, bool is_aos) {
return Expr::make<GlobalTensorElementExpression>(var, indices, cols,
is_aos);
});

m.def("subscript", [](SNode *snode, const ExprGroup &indices) {
return Expr::make<GlobalPtrExpression>(snode, indices.loaded());
});
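This binding is what the Python-side subscript_with_offset wrapper in impl.py (at the top of this diff) calls through _ti_core.
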
6 changes: 6 additions & 0 deletions taichi/transforms/flag_access.cpp
@@ -54,6 +54,12 @@ class FlagAccess : public IRVisitor {
if (stmt->dest->is<GlobalPtrStmt>()) {
stmt->dest->as<GlobalPtrStmt>()->activate = true;
}
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->as<PtrOffsetStmt>()->origin->is<GlobalPtrStmt>()) {
stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>()->activate =
true;
}
}
}

void visit(AtomicOpStmt *stmt) {
7 changes: 7 additions & 0 deletions taichi/transforms/ir_printer.cpp
@@ -383,6 +383,13 @@ class IRPrinter : public IRVisitor {
print_raw(s);
}

void visit(PtrOffsetStmt *stmt) override {
std::string s =
fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(),
stmt->name(), stmt->origin->name(), stmt->offset->name());
print_raw(s);
}

void visit(ArgLoadStmt *stmt) override {
print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
}
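With this printer, a PtrOffsetStmt appears in IR dumps as, e.g., $41 = shift ptr [$7 + $40], prefixed by its return type via type_hint() (the statement numbers here are illustrative).
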
40 changes: 26 additions & 14 deletions taichi/transforms/lower_access.cpp
@@ -148,23 +148,35 @@ class LowerAccess : public IRVisitor {
   }

   void visit(GlobalLoadStmt *stmt) override {
-    if (stmt->src->is<GlobalPtrStmt>()) {
-      // No need to activate for all read accesses
-      auto lowered = lower_vector_ptr(stmt->src->as<GlobalPtrStmt>(), false);
-      stmt->src = lowered.back().get();
-      modifier.insert_before(stmt, std::move(lowered));
-    }
+    if (!stmt->src->is<GlobalPtrStmt>())
+      return;
+    // No need to activate for all read accesses
+    auto lowered = lower_vector_ptr(stmt->src->as<GlobalPtrStmt>(), false);
+    stmt->src = lowered.back().get();
+    modifier.insert_before(stmt, std::move(lowered));
   }

+  // TODO: this seems to be redundant
+  void visit(PtrOffsetStmt *stmt) override {
+    if (!stmt->origin->is<GlobalPtrStmt>())
+      return;
+    auto ptr = stmt->origin->as<GlobalPtrStmt>();
+    // If ptr already has activate = false, no need to activate all the
+    // generated micro-access ops. Otherwise, activate the nodes.
+    auto lowered = lower_vector_ptr(ptr, ptr->activate);
+    stmt->origin = lowered.back().get();
+    modifier.insert_before(stmt, std::move(lowered));
+  }
+
   void visit(GlobalStoreStmt *stmt) override {
-    if (stmt->dest->is<GlobalPtrStmt>()) {
-      auto ptr = stmt->dest->as<GlobalPtrStmt>();
-      // If ptr already has activate = false, no need to activate all the
-      // generated micro-access ops. Otherwise, activate the nodes.
-      auto lowered = lower_vector_ptr(ptr, ptr->activate);
-      stmt->dest = lowered.back().get();
-      modifier.insert_before(stmt, std::move(lowered));
-    }
+    if (!stmt->dest->is<GlobalPtrStmt>())
+      return;
+    auto ptr = stmt->dest->as<GlobalPtrStmt>();
+    // If ptr already has activate = false, no need to activate all the
+    // generated micro-access ops. Otherwise, activate the nodes.
+    auto lowered = lower_vector_ptr(ptr, ptr->activate);
+    stmt->dest = lowered.back().get();
+    modifier.insert_before(stmt, std::move(lowered));
   }

   void visit(SNodeOpStmt *stmt) override {
4 changes: 4 additions & 0 deletions taichi/transforms/lower_ast.cpp
@@ -361,6 +361,10 @@ class LowerAST : public IRVisitor {
fctx.push_back<LocalStoreStmt>(
assign->parent->lookup_var(assign->lhs.cast<IdExpression>()->id),
expr->stmt);
} else if (assign->lhs.is<GlobalTensorElementExpression>()) {
auto global_ptr = assign->lhs.cast<GlobalTensorElementExpression>();
global_ptr->flatten(&fctx);
fctx.push_back<GlobalStoreStmt>(fctx.back_stmt(), expr->stmt);
} else { // global variable
TI_ASSERT(assign->lhs.is<GlobalPtrExpression>());
auto global_ptr = assign->lhs.cast<GlobalPtrExpression>();
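Together with the frontend changes above, an assignment such as A[i][x, y] = v now lowers to the flattened GlobalTensorElementExpression followed by a GlobalStoreStmt whose destination is the resulting PtrOffsetStmt.
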