Skip to content

Commit

Permalink
[opt] Add pass eliminate_immutable_local_vars (#6926)
Browse files Browse the repository at this point in the history
Issue: #6933

### Brief Summary

There are many redundant copies of local vars in the initial IR:
```
  <[Tensor (3, 3) f32]> $128 = [$103, $106, $109, $112, $115, $118, $121, $124, $127]
  $129 : local store [$100 <- $128]
  <[Tensor (3, 3) f32]> $130 = alloca
  $131 = local load [$100]
  $132 : local store [$130 <- $131]
  <[Tensor (3, 3) f32]> $133 = alloca
  $134 = local load [$130]
  $135 : local store [$133 <- $134]
  <[Tensor (3, 3) f32]> $136 = alloca
  $137 = local load [$133]
  $138 : local store [$136 <- $137]
// In fact, `$128` can be used wherever `$136` is loaded.
```

These can come from many places; one of the main sources is the
pass-by-value convention of `ti.func`. The consequence is that the
number of instructions is unnecessarily large, which significantly slows
down compilation.

My solution here is to identify and eliminate such redundant
instructions in the first place so all later passes can take a much
smaller number of instructions as input. These redundant local vars are
essentially immutable ones - they are assigned only once and only loaded
after the assignment. In this PR, I add an optimization pass
`eliminate_immutable_local_vars` as the first pass.

(P.S. The type check processes of `MatrixExpression` and `LocalLoadStmt`
are fixed by the way to make the pass work properly.)

Let's study the effects in two cases: #6933 and
[voxel-rt2](https://github.com/taichi-dev/voxel-rt2/blob/main/example7.py).

First, let's compare the number of instructions after `scalarization`
pass (which happens immediately after the first pass).

| Kernel | Before this PR | After this PR | Rate of decrease |
| ------ | ------ | ------ | ------ |
| `test` (#6933) | 45859  | 26452 | 42% |
| `spatial_GRIS` (voxel-rt2) | 48519 | 17713 | 63% |

Then, let's compare the total time of `compile()`.

| Case | Before this PR | After this PR | Rate of decrease |
| ------ | ------ | ------ | ------ |
| #6933 | 20.622s | 8.550s | 59% |
| voxel-rt2  | 27.676s  | 9.495s | 66% |

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Dec 21, 2022
1 parent 8501dcf commit 19fce81
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 3 deletions.
92 changes: 92 additions & 0 deletions taichi/analysis/gather_immutable_local_vars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"

namespace taichi::lang {

// The GatherImmutableLocalVars pass gathers all immutable local vars as input
// to the EliminateImmutableLocalVars pass. An immutable local var is an alloca
// which is stored only once (in the same block) and only loaded after that
// store.
class GatherImmutableLocalVars : public BasicStmtVisitor {
private:
using BasicStmtVisitor::visit;

enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 };
std::unordered_map<Stmt *, AllocaStatus> alloca_status_;

public:
explicit GatherImmutableLocalVars() {
invoke_default_visitor = true;
}

void visit(AllocaStmt *stmt) override {
TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end());
alloca_status_[stmt] = AllocaStatus::kCreated;
}

void visit(LocalLoadStmt *stmt) override {
if (stmt->src->is<AllocaStmt>()) {
auto status_iter = alloca_status_.find(stmt->src);
TI_ASSERT(status_iter != alloca_status_.end());
if (status_iter->second == AllocaStatus::kCreated) {
status_iter->second = AllocaStatus::kInvalid;
}
}
}

void visit(LocalStoreStmt *stmt) override {
if (stmt->dest->is<AllocaStmt>()) {
auto status_iter = alloca_status_.find(stmt->dest);
TI_ASSERT(status_iter != alloca_status_.end());
if (stmt->parent != stmt->dest->parent ||
status_iter->second == AllocaStatus::kStoredOnce ||
stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) {
status_iter->second = AllocaStatus::kInvalid;
} else if (status_iter->second == AllocaStatus::kCreated) {
status_iter->second = AllocaStatus::kStoredOnce;
}
}
}

void default_visit(Stmt *stmt) {
for (auto &op : stmt->get_operands()) {
if (op != nullptr && op->is<AllocaStmt>()) {
auto status_iter = alloca_status_.find(op);
TI_ASSERT(status_iter != alloca_status_.end());
status_iter->second = AllocaStatus::kInvalid;
}
}
}

void visit(Stmt *stmt) override {
default_visit(stmt);
}

void preprocess_container_stmt(Stmt *stmt) override {
default_visit(stmt);
}

static std::unordered_set<Stmt *> run(IRNode *node) {
GatherImmutableLocalVars pass;
node->accept(&pass);
std::unordered_set<Stmt *> result;
for (auto &[k, v] : pass.alloca_status_) {
if (v == AllocaStatus::kStoredOnce) {
result.insert(k);
}
}
return result;
}
};

namespace irpass::analysis {

std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root) {
return GatherImmutableLocalVars::run(root);
}

} // namespace irpass::analysis

} // namespace taichi::lang
1 change: 1 addition & 0 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ bool definitely_same_address(Stmt *var1, Stmt *var2);

std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root);
std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root);
std::unordered_set<SNode *> gather_deactivations(IRNode *root);
std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
gather_snode_read_writes(IRNode *root);
Expand Down
10 changes: 7 additions & 3 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,12 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
}

void MatrixExpression::type_check(CompileConfig *config) {
// TODO: typecheck matrix
for (auto &arg : elements) {
TI_ASSERT_TYPE_CHECKED(arg);
if (arg->ret_type != dt.get_element_type()) {
arg = cast(arg, dt.get_element_type());
arg->type_check(config);
}
}
ret_type = dt;
}
Expand Down Expand Up @@ -1569,8 +1572,9 @@ Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
}

Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
ctx->push_back<LocalLoadStmt>(ptr_stmt);
return ctx->back_stmt();
auto local_load = ctx->push_back<LocalLoadStmt>(ptr_stmt);
local_load->ret_type = local_load->src->ret_type.ptr_removed();
return local_load;
}

Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace irpass {

void re_id(IRNode *root);
void flag_access(IRNode *root);
void eliminate_immutable_local_vars(IRNode *root);
void scalarize(IRNode *root, const CompileConfig &config);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ void compile_to_offloads(IRNode *ir,
print("Lowered");
}

irpass::eliminate_immutable_local_vars(ir);
print("Immutable local vars eliminated");

if (config.real_matrix_scalarize) {
irpass::scalarize(ir, config);

Expand Down
67 changes: 67 additions & 0 deletions taichi/transforms/eliminate_immutable_local_vars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"

namespace taichi::lang {

// The EliminateImmutableLocalVars pass eliminates all immutable local vars
// calculated from the GatherImmutableLocalVars pass. An immutable local var
// can be eliminated by forwarding the value of its only store to all loads
// after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the
// background of this optimization.
class EliminateImmutableLocalVars : public BasicStmtVisitor {
private:
using BasicStmtVisitor::visit;

DelayedIRModifier modifier_;
std::unordered_set<Stmt *> immutable_local_vars_;
std::unordered_map<Stmt *, Stmt *> immutable_local_var_to_value_;

public:
explicit EliminateImmutableLocalVars(
const std::unordered_set<Stmt *> &immutable_local_vars)
: immutable_local_vars_(immutable_local_vars) {
}

void visit(AllocaStmt *stmt) override {
if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) {
modifier_.erase(stmt);
}
}

void visit(LocalLoadStmt *stmt) override {
if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) {
stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]);
modifier_.erase(stmt);
}
}

void visit(LocalStoreStmt *stmt) override {
if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) {
TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) ==
immutable_local_var_to_value_.end());
immutable_local_var_to_value_[stmt->dest] = stmt->val;
modifier_.erase(stmt);
}
}

static void run(IRNode *node) {
EliminateImmutableLocalVars pass(
irpass::analysis::gather_immutable_local_vars(node));
node->accept(&pass);
pass.modifier_.modify_ir();
}
};

namespace irpass {

void eliminate_immutable_local_vars(IRNode *root) {
TI_AUTO_PROF;
EliminateImmutableLocalVars::run(root);
}

} // namespace irpass

} // namespace taichi::lang

0 comments on commit 19fce81

Please sign in to comment.