Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt] Add pass eliminate_immutable_local_vars #6926

Merged
merged 6 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions taichi/analysis/gather_immutable_local_vars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"

namespace taichi::lang {

// The GatherImmutableLocalVars pass gathers all immutable local vars as input
// to the EliminateImmutableLocalVars pass. An immutable local var is an alloca
// which is stored only once (in the same block) and only loaded after that
strongoier marked this conversation as resolved.
Show resolved Hide resolved
// store.
class GatherImmutableLocalVars : public BasicStmtVisitor {
private:
using BasicStmtVisitor::visit;

enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 };
std::unordered_map<Stmt *, AllocaStatus> alloca_status_;
strongoier marked this conversation as resolved.
Show resolved Hide resolved

public:
explicit GatherImmutableLocalVars() {
invoke_default_visitor = true;
}

// Registers a newly encountered alloca as kCreated. Each alloca must be
// visited at most once; a repeated visit trips the assertion.
void visit(AllocaStmt *stmt) override {
  TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end());
  alloca_status_.insert_or_assign(stmt, AllocaStatus::kCreated);
}

// Handles a load from a local variable. Loads whose source is not an alloca
// are ignored. A load that is visited while the alloca is still kCreated
// (i.e. before any store was seen) disqualifies the alloca.
void visit(LocalLoadStmt *stmt) override {
  if (!stmt->src->is<AllocaStmt>()) {
    return;
  }
  auto status_iter = alloca_status_.find(stmt->src);
  TI_ASSERT(status_iter != alloca_status_.end());
  AllocaStatus &status = status_iter->second;
  if (status == AllocaStatus::kCreated) {
    status = AllocaStatus::kInvalid;
  }
}

// Handles a store to a local variable. Stores whose destination is not an
// alloca are ignored. The alloca is disqualified when the store
//   - lives in a different block than the alloca,
//   - is a second store (the alloca is already kStoredOnce), or
//   - writes a value whose ret_type differs from the alloca's ret_type with
//     the pointer removed.
// Otherwise the first store moves the alloca from kCreated to kStoredOnce.
// An alloca that is already kInvalid stays kInvalid.
void visit(LocalStoreStmt *stmt) override {
  if (!stmt->dest->is<AllocaStmt>()) {
    return;
  }
  auto status_iter = alloca_status_.find(stmt->dest);
  TI_ASSERT(status_iter != alloca_status_.end());
  AllocaStatus &status = status_iter->second;

  const bool disqualifies =
      stmt->parent != stmt->dest->parent ||
      status == AllocaStatus::kStoredOnce ||
      stmt->val->ret_type != stmt->dest->ret_type.ptr_removed();

  if (disqualifies) {
    status = AllocaStatus::kInvalid;
    return;
  }
  if (status == AllocaStatus::kCreated) {
    status = AllocaStatus::kStoredOnce;
  }
}

// Fallback handling shared by visit(Stmt *) and preprocess_container_stmt:
// any alloca that appears as an operand of `stmt` is marked kInvalid, since
// it is being used by something other than a plain local load/store.
void default_visit(Stmt *stmt) {
  for (const auto &operand : stmt->get_operands()) {
    if (operand == nullptr || !operand->is<AllocaStmt>()) {
      continue;
    }
    auto status_iter = alloca_status_.find(operand);
    TI_ASSERT(status_iter != alloca_status_.end());
    status_iter->second = AllocaStatus::kInvalid;
  }
}

// Generic statement hook: invalidates every alloca used as an operand of
// `stmt` (see default_visit).
void visit(Stmt *stmt) override {
  this->default_visit(stmt);
}

// Container statement hook: applies the same operand check as the generic
// statement hook, invalidating every alloca used as an operand of `stmt`.
void preprocess_container_stmt(Stmt *stmt) override {
  this->default_visit(stmt);
}

// Runs the analysis over `node` and returns every alloca that ended the
// traversal in the kStoredOnce state.
static std::unordered_set<Stmt *> run(IRNode *node) {
  GatherImmutableLocalVars gatherer;
  node->accept(&gatherer);

  std::unordered_set<Stmt *> immutable_vars;
  for (const auto &[alloca, status] : gatherer.alloca_status_) {
    if (status != AllocaStatus::kStoredOnce) {
      continue;
    }
    immutable_vars.insert(alloca);
  }
  return immutable_vars;
}
};

namespace irpass::analysis {

// Public entry point of the analysis: returns the set of allocas under `root`
// that GatherImmutableLocalVars classifies as immutable local vars.
std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root) {
  auto immutable_local_vars = GatherImmutableLocalVars::run(root);
  return immutable_local_vars;
}

} // namespace irpass::analysis

} // namespace taichi::lang
1 change: 1 addition & 0 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ bool definitely_same_address(Stmt *var1, Stmt *var2);

std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root);
std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
// Returns the allocas under `root` that qualify as immutable local vars
// (implemented in taichi/analysis/gather_immutable_local_vars.cpp).
std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root);
std::unordered_set<SNode *> gather_deactivations(IRNode *root);
std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
gather_snode_read_writes(IRNode *root);
Expand Down
10 changes: 7 additions & 3 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,12 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
}

// Type-checks a matrix expression. Every element must already be
// type-checked; an element whose ret_type differs from the element type of
// `dt` is replaced by a cast to that element type, and the cast is then
// type-checked itself. Finally the expression's ret_type is set to `dt`.
void MatrixExpression::type_check(CompileConfig *config) {
  for (auto &arg : elements) {
    TI_ASSERT_TYPE_CHECKED(arg);
    const auto element_type = dt.get_element_type();
    if (arg->ret_type != element_type) {
      arg = cast(arg, element_type);
      arg->type_check(config);
    }
  }
  ret_type = dt;
}
Expand Down Expand Up @@ -1569,8 +1572,9 @@ Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
}

// Appends a LocalLoadStmt that reads from `ptr_stmt` to `ctx` and returns it.
// The load's ret_type is set to its source's ret_type with the pointer
// removed, so the returned statement already carries a result type.
Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
  auto local_load = ctx->push_back<LocalLoadStmt>(ptr_stmt);
  local_load->ret_type = local_load->src->ret_type.ptr_removed();
  return local_load;
}

Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace irpass {

void re_id(IRNode *root);
void flag_access(IRNode *root);
// Eliminates immutable local vars under `root` by forwarding the value of
// their only store to their loads (implemented in
// taichi/transforms/eliminate_immutable_local_vars.cpp).
void eliminate_immutable_local_vars(IRNode *root);
void scalarize(IRNode *root, const CompileConfig &config);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ void compile_to_offloads(IRNode *ir,
print("Lowered");
}

irpass::eliminate_immutable_local_vars(ir);
print("Immutable local vars eliminated");

if (config.real_matrix_scalarize) {
irpass::scalarize(ir, config);

Expand Down
67 changes: 67 additions & 0 deletions taichi/transforms/eliminate_immutable_local_vars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"

namespace taichi::lang {

// The EliminateImmutableLocalVars pass eliminates all immutable local vars
// calculated from the GatherImmutableLocalVars pass. An immutable local var
// can be eliminated by forwarding the value of its only store to all loads
// after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the
// background of this optimization.
class EliminateImmutableLocalVars : public BasicStmtVisitor {
private:
using BasicStmtVisitor::visit;

DelayedIRModifier modifier_;
std::unordered_set<Stmt *> immutable_local_vars_;
std::unordered_map<Stmt *, Stmt *> immutable_local_var_to_value_;

public:
explicit EliminateImmutableLocalVars(
const std::unordered_set<Stmt *> &immutable_local_vars)
: immutable_local_vars_(immutable_local_vars) {
}

void visit(AllocaStmt *stmt) override {
if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) {
modifier_.erase(stmt);
}
}

void visit(LocalLoadStmt *stmt) override {
if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) {
stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]);
modifier_.erase(stmt);
}
}

void visit(LocalStoreStmt *stmt) override {
if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) {
TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) ==
immutable_local_var_to_value_.end());
immutable_local_var_to_value_[stmt->dest] = stmt->val;
modifier_.erase(stmt);
}
}

static void run(IRNode *node) {
EliminateImmutableLocalVars pass(
irpass::analysis::gather_immutable_local_vars(node));
node->accept(&pass);
pass.modifier_.modify_ir();
}
};

namespace irpass {

// Public entry point of the pass: runs EliminateImmutableLocalVars on `root`.
void eliminate_immutable_local_vars(IRNode *root) {
// Profiling macro; kept as the first statement of the function.
TI_AUTO_PROF;
EliminateImmutableLocalVars::run(root);
}

} // namespace irpass

} // namespace taichi::lang