Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Init GlobalTensorElementExpression and PtrOffsetStmt #2543

Merged
merged 36 commits into from
Jul 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
25227b0
init GlobalTensorElementExpression and GlobalTensorElementExpr
squarefk Jul 19, 2021
b0e3b11
fix compile error for global load
squarefk Jul 19, 2021
a9d3dc2
fix cfg optimization
squarefk Jul 19, 2021
5aaf917
fix offset=4 for i32 and f32
squarefk Jul 19, 2021
1e2ab62
support matrix indices access, AOS SOA, and arbitrary date type
squarefk Jul 21, 2021
cbc7950
Auto Format
taichi-gardener Jul 21, 2021
7684a3f
do better serialize() for GlobalTensorElementExpression and GlobalTen…
squarefk Jul 21, 2021
fbd0f1a
simplify make<ConstStmt>
squarefk Jul 21, 2021
1120afb
fix ti test
squarefk Jul 22, 2021
11d89e0
fix test and add load_if_ptr for dynamic indices
squarefk Jul 22, 2021
2705249
fix test
squarefk Jul 22, 2021
587569b
fix more tests
squarefk Jul 23, 2021
b009a94
fix more tests
squarefk Jul 23, 2021
4d68187
fix test
squarefk Jul 26, 2021
e445596
fix test
squarefk Jul 26, 2021
cc6317d
update test
squarefk Jul 26, 2021
9d3ecb1
merge condition in load_if_ptr
squarefk Jul 26, 2021
46dd8a2
replace is_AOS with is_aos
squarefk Jul 26, 2021
6922a51
remove is_bit_vectorized
squarefk Jul 26, 2021
d059195
add default value for GlobalTensorElementStmt
squarefk Jul 26, 2021
e3cc893
add default value for GlobalTensorElementExpression
squarefk Jul 26, 2021
915941e
Merge branch 'master' of https://github.com/taichi-dev/taichi
squarefk Jul 26, 2021
ff80ccd
update computation for field size
squarefk Jul 26, 2021
99e4a18
Update taichi/ir/expr.cpp
squarefk Jul 27, 2021
a68187f
comment is_global_ptr()
squarefk Jul 27, 2021
1123d63
nit
squarefk Jul 27, 2021
da3c394
Auto Format
taichi-gardener Jul 27, 2021
9985bac
Merge branch 'master' of https://github.com/taichi-dev/taichi
squarefk Jul 28, 2021
848db37
add extension and fix tests
squarefk Jul 29, 2021
6a43da7
Auto Format
taichi-gardener Jul 29, 2021
e37ed47
do actual type check
squarefk Jul 29, 2021
6c2ea50
rename GlobalTensorElementStmt into ShiftGlobalPtrStmt
squarefk Jul 29, 2021
cccd882
Auto Format
taichi-gardener Jul 29, 2021
d9bdf42
nit address_offset
squarefk Jul 30, 2021
3c20d48
rename ShiftGlobalPtrStmt into PtrOffsetStmt
squarefk Jul 30, 2021
3da4b3e
Auto Format
taichi-gardener Jul 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def subscript(value, *indices):
return value[indices]


@taichi_scope
def subscript_with_offset(var, indices, cols, is_aos):
return Expr(
_ti_core.subscript_with_offset(var.ptr, make_expr_group(*indices),
cols, is_aos))


@taichi_scope
def chain_compare(comparators, ops):
_taichi_skip_traceback = 1
Expand Down
10 changes: 9 additions & 1 deletion python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,15 @@ def subscript(self, *indices):
assert len(indices) in [1, 2]
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
return self(i, j)
# ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()).
if isinstance(self.entries[0],
ti.Expr) and self.entries[0].ptr.is_global_ptr(
) and ti.is_extension_supported(
ti.cfg.arch, ti.extension.dynamic_index):
return ti.subscript_with_offset(self.entries[0], (i, j),
self.m, True)
else:
return self(i, j)

@property
def x(self):
Expand Down
11 changes: 11 additions & 0 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,17 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
}
}

void CodeGenLLVM::visit(PtrOffsetStmt *stmt) {
auto origin_address = builder->CreatePtrToInt(
llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context));
auto address_offset = builder->CreateSExt(
llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context));
auto target_address = builder->CreateAdd(origin_address, address_offset);
auto dt = stmt->ret_type.ptr_removed();
llvm_val[stmt] = builder->CreateIntToPtr(
target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0));
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
TI_ASSERT(stmt->width() == 1);

Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(GlobalPtrStmt *stmt) override;

void visit(PtrOffsetStmt *stmt) override;

void store_custom_int(llvm::Value *bit_ptr,
CustomIntType *cit,
llvm::Value *value,
Expand Down
2 changes: 2 additions & 0 deletions taichi/inc/extensions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ PER_EXTENSION(bls) // Block-local storage
PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels
PER_EXTENSION(extfunc) // Invoke external functions or backend source
PER_EXTENSION(packed) // Shape will not be padded to a power of two
PER_EXTENSION(
dynamic_index) // Dynamic index support for both global and local tensors
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ PER_STATEMENT(ReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ExternalPtrStmt)
PER_STATEMENT(PtrOffsetStmt)
PER_STATEMENT(ConstStmt)
PER_STATEMENT(AllocaStmt)
PER_STATEMENT(UnaryOpStmt)
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
auto stmt = nodes[i]->block->statements[j].get();
if (stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
stmt->is<GlobalTemporaryStmt>()) {
stmt->is<GlobalTemporaryStmt>() || stmt->is<PtrOffsetStmt>()) {
// TODO: unify them
// A global pointer that may contain some data before this kernel.
nodes[start_node]->reach_gen.insert(stmt);
Expand Down
6 changes: 4 additions & 2 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ void Expr::operator/=(const Expr &o) {
}

Expr load_if_ptr(const Expr &ptr) {
if (ptr.is<GlobalPtrExpression>()) {
if (ptr.is<GlobalPtrExpression>() ||
ptr.is<GlobalTensorElementExpression>()) {
return load(ptr);
} else if (ptr.is<GlobalVariableExpression>()) {
TI_ASSERT(ptr.cast<GlobalVariableExpression>()->snode->num_active_indices ==
Expand All @@ -172,7 +173,8 @@ Expr load_if_ptr(const Expr &ptr) {
}

Expr load(const Expr &ptr) {
TI_ASSERT(ptr.is<GlobalPtrExpression>());
TI_ASSERT(ptr.is<GlobalPtrExpression>() ||
ptr.is<GlobalTensorElementExpression>());
return Expr::make<GlobalLoadExpression>(ptr);
}

Expand Down
57 changes: 57 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,59 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void GlobalTensorElementExpression::flatten(FlattenContext *ctx) {
TI_ASSERT(var.is<GlobalPtrExpression>())
var->flatten(ctx);
Stmt *var_stmt = ctx->back_stmt();
SNode *snode = var.cast<GlobalPtrExpression>()
->var.cast<GlobalVariableExpression>()
->snode;
// Compute exact offset
// Type A[i, j][x, y]
// ^^^^
TI_ASSERT(1 <= indices.size() && indices.size() <= 2)
if (indices.size() == 1) {
indices[0].set(load_if_ptr(indices[0]));
indices[0]->flatten(ctx);
} else {
indices[0].set(load_if_ptr(indices[0]));
indices[0]->flatten(ctx);
Stmt *i_stmt = ctx->back_stmt();
Stmt *cols_stmt =
ctx->push_back(Stmt::make<ConstStmt>(TypedConstant(cols)));
Stmt *i_mul_cols_stmt = ctx->push_back(
Stmt::make<BinaryOpStmt>(BinaryOpType::mul, i_stmt, cols_stmt));
indices[1].set(load_if_ptr(indices[1]));
indices[1]->flatten(ctx);
Stmt *j_stmt = ctx->back_stmt();
ctx->push_back(
Stmt::make<BinaryOpStmt>(BinaryOpType::add, i_mul_cols_stmt, j_stmt));
}
// Type A[i, j][x, y]
// ^ ^
if (!is_aos) {
TI_ASSERT(snode->is_path_all_dense)
int size = 1;
for (auto *s = snode; s != nullptr; s = s->parent)
size *= (int)s->max_num_elements();
Stmt *offset_stmt = ctx->back_stmt();
Stmt *field_size_stmt =
ctx->push_back(Stmt::make<ConstStmt>(TypedConstant(size)));
ctx->push_back(Stmt::make<BinaryOpStmt>(BinaryOpType::mul, offset_stmt,
field_size_stmt));
}
// Type A[i, j][x, y]
// ^^^^
Stmt *offset_stmt = ctx->back_stmt();
Stmt *dt_size_stmt = ctx->push_back(
Stmt::make<ConstStmt>(TypedConstant(data_type_size(snode->dt))));
ctx->push_back(
Stmt::make<BinaryOpStmt>(BinaryOpType::mul, offset_stmt, dt_size_stmt));

ctx->push_back(std::make_unique<PtrOffsetStmt>(var_stmt, ctx->back_stmt()));
stmt = ctx->back_stmt();
}

void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
input->flatten(ctx);
base->flatten(ctx);
Expand Down Expand Up @@ -297,6 +350,10 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
// emit local store stmt
auto alloca = ctx->current_block->lookup_var(dest.cast<IdExpression>()->id);
ctx->push_back<AtomicOpStmt>(op_type, alloca, expr->stmt);
} else if (dest.is<GlobalTensorElementExpression>()) {
auto global_ptr = dest.cast<GlobalTensorElementExpression>();
global_ptr->flatten(ctx);
ctx->push_back<AtomicOpStmt>(op_type, ctx->back_stmt(), expr->stmt);
} else { // global variable
TI_ASSERT(dest.is<GlobalPtrExpression>());
auto global_ptr = dest.cast<GlobalPtrExpression>();
Expand Down
33 changes: 33 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,39 @@ class GlobalPtrExpression : public Expression {
}
};

class GlobalTensorElementExpression : public Expression {
public:
Expr var;
ExprGroup indices;
int cols{0};
bool is_aos{false};

GlobalTensorElementExpression(const Expr &var,
const ExprGroup &indices,
int cols,
bool is_aos)
: var(var), indices(indices), cols(cols), is_aos(is_aos) {
}

std::string serialize() override {
std::string s = fmt::format("{}[", var.serialize());
for (int i = 0; i < (int)indices.size(); i++) {
s += indices.exprs[i]->serialize();
if (i + 1 < (int)indices.size())
s += ", ";
}
s += "]";
s += " (col=" + std::to_string(cols) + (is_aos ? ", AOS)" : ", SOA)");
return s;
}

void flatten(FlattenContext *ctx) override;

bool is_lvalue() const override {
return true;
}
};

class EvalExpression : public Expression {
public:
Stmt *stmt_ptr;
Expand Down
24 changes: 24 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,30 @@ class GlobalPtrStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* An accessing tensor element operation.
squarefk marked this conversation as resolved.
Show resolved Hide resolved
*/
class PtrOffsetStmt : public Stmt {
public:
Stmt *origin{nullptr};
Stmt *offset{nullptr};

PtrOffsetStmt(Stmt *origin, Stmt *offset) : origin(origin), offset(offset) {
element_type() = origin->cast<GlobalPtrStmt>()->ret_type;
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
// After access lowered, activate info will be recorded in SNodeLookupStmt's
// activate for AOS sparse data structure. We don't support SOA sparse data
// structure for now.
return false;
squarefk marked this conversation as resolved.
Show resolved Hide resolved
}

TI_STMT_DEF_FIELDS(ret_type, origin, offset);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* An operation to a SNode (not necessarily a leaf SNode).
*/
Expand Down
9 changes: 9 additions & 0 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
if (auto global_tensor_element =
global_store->dest->cast<PtrOffsetStmt>()) {
if (auto dest = global_tensor_element->origin->cast<GlobalPtrStmt>()) {
for (auto &snode : dest->snodes.data) {
meta.output_states.insert(
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
}
}
if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
if (auto dest = global_atomic->dest->cast<GlobalPtrStmt>()) {
Expand Down
8 changes: 5 additions & 3 deletions taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ bool is_extension_supported(Arch arch, Extension ext) {
{Arch::x64,
{Extension::sparse, Extension::async_mode, Extension::quant,
Extension::quant_basic, Extension::data64, Extension::adstack,
Extension::assertion, Extension::extfunc, Extension::packed}},
Extension::assertion, Extension::extfunc, Extension::packed,
Extension::dynamic_index}},
{Arch::arm64,
{Extension::sparse, Extension::async_mode, Extension::quant,
Extension::quant_basic, Extension::data64, Extension::adstack,
Extension::assertion, Extension::packed}},
Extension::assertion, Extension::packed, Extension::dynamic_index}},
{Arch::cuda,
{Extension::sparse, Extension::async_mode, Extension::quant,
Extension::quant_basic, Extension::data64, Extension::adstack,
Extension::bls, Extension::assertion, Extension::packed}},
Extension::bls, Extension::assertion, Extension::packed,
Extension::dynamic_index}},
{Arch::metal,
{Extension::adstack, Extension::assertion, Extension::quant_basic,
Extension::async_mode, Extension::sparse}},
Expand Down
8 changes: 8 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ void export_lang(py::module &m) {
[](Expr *expr) { return expr->is<GlobalVariableExpression>(); })
.def("is_external_var",
[](Expr *expr) { return expr->is<ExternalTensorExpression>(); })
.def("is_global_ptr",
[](Expr *expr) { return expr->is<GlobalPtrExpression>(); })
.def("is_primal",
[](Expr *expr) {
return expr->cast<GlobalVariableExpression>()->is_primal;
Expand Down Expand Up @@ -690,6 +692,12 @@ void export_lang(py::module &m) {
return expr[expr_group];
});

m.def("subscript_with_offset",
[](const Expr &var, const ExprGroup &indices, int cols, bool is_aos) {
return Expr::make<GlobalTensorElementExpression>(var, indices, cols,
is_aos);
});

m.def("subscript", [](SNode *snode, const ExprGroup &indices) {
return Expr::make<GlobalPtrExpression>(snode, indices.loaded());
});
Expand Down
6 changes: 6 additions & 0 deletions taichi/transforms/flag_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class FlagAccess : public IRVisitor {
if (stmt->dest->is<GlobalPtrStmt>()) {
stmt->dest->as<GlobalPtrStmt>()->activate = true;
}
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->as<PtrOffsetStmt>()->origin->is<GlobalPtrStmt>()) {
stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>()->activate =
true;
}
}
}

void visit(AtomicOpStmt *stmt) {
Expand Down
7 changes: 7 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,13 @@ class IRPrinter : public IRVisitor {
print_raw(s);
}

void visit(PtrOffsetStmt *stmt) override {
std::string s =
fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(),
stmt->name(), stmt->origin->name(), stmt->offset->name());
print_raw(s);
}

void visit(ArgLoadStmt *stmt) override {
print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
}
Expand Down
40 changes: 26 additions & 14 deletions taichi/transforms/lower_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,35 @@ class LowerAccess : public IRVisitor {
}

void visit(GlobalLoadStmt *stmt) override {
if (stmt->src->is<GlobalPtrStmt>()) {
// No need to activate for all read accesses
auto lowered = lower_vector_ptr(stmt->src->as<GlobalPtrStmt>(), false);
stmt->src = lowered.back().get();
modifier.insert_before(stmt, std::move(lowered));
}
if (!stmt->src->is<GlobalPtrStmt>())
return;
// No need to activate for all read accesses
auto lowered = lower_vector_ptr(stmt->src->as<GlobalPtrStmt>(), false);
stmt->src = lowered.back().get();
modifier.insert_before(stmt, std::move(lowered));
}

// TODO: this seems to be redundant
void visit(PtrOffsetStmt *stmt) override {
if (!stmt->origin->is<GlobalPtrStmt>())
return;
auto ptr = stmt->origin->as<GlobalPtrStmt>();
// If ptr already has activate = false, no need to activate all the
// generated micro-access ops. Otherwise, activate the nodes.
auto lowered = lower_vector_ptr(ptr, ptr->activate);
stmt->origin = lowered.back().get();
modifier.insert_before(stmt, std::move(lowered));
}

void visit(GlobalStoreStmt *stmt) override {
if (stmt->dest->is<GlobalPtrStmt>()) {
auto ptr = stmt->dest->as<GlobalPtrStmt>();
// If ptr already has activate = false, no need to activate all the
// generated micro-access ops. Otherwise, activate the nodes.
auto lowered = lower_vector_ptr(ptr, ptr->activate);
stmt->dest = lowered.back().get();
modifier.insert_before(stmt, std::move(lowered));
}
if (!stmt->dest->is<GlobalPtrStmt>())
return;
auto ptr = stmt->dest->as<GlobalPtrStmt>();
// If ptr already has activate = false, no need to activate all the
// generated micro-access ops. Otherwise, activate the nodes.
auto lowered = lower_vector_ptr(ptr, ptr->activate);
stmt->dest = lowered.back().get();
modifier.insert_before(stmt, std::move(lowered));
}

void visit(SNodeOpStmt *stmt) override {
Expand Down
4 changes: 4 additions & 0 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,10 @@ class LowerAST : public IRVisitor {
fctx.push_back<LocalStoreStmt>(
assign->parent->lookup_var(assign->lhs.cast<IdExpression>()->id),
expr->stmt);
} else if (assign->lhs.is<GlobalTensorElementExpression>()) {
auto global_ptr = assign->lhs.cast<GlobalTensorElementExpression>();
global_ptr->flatten(&fctx);
fctx.push_back<GlobalStoreStmt>(fctx.back_stmt(), expr->stmt);
} else { // global variable
TI_ASSERT(assign->lhs.is<GlobalPtrExpression>());
auto global_ptr = assign->lhs.cast<GlobalPtrExpression>();
Expand Down
Loading