Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt] Add ExtractPointers pass for dynamic index #7051

Merged
merged 2 commits into from
Jan 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,86 @@ class ScalarizePointers : public BasicStmtVisitor {
using BasicStmtVisitor::visit;
};

// The ExtractPointers pass removes the redundant ConstStmts and MatrixPtrStmts
// generated for each (AllocaStmt, integer offset) pair by extracting a single
// shared copy that all later usages refer to.
//
// Example for redundant stmts:
// <i32> $0 = const 0
// <i32> $1 = const 1
// ...
// <[Tensor (3, 3) f32]> $47738 = alloca
// <i32> $47739 = const 0 [REDUNDANT]
// <*f32> $47740 = shift ptr [$47738 + $47739]
// $47741 : local store [$47740 <- $47713]
// <i32> $47742 = const 1 [REDUNDANT]
// <*f32> $47743 = shift ptr [$47738 + $47742]
// $47744 : local store [$47743 <- $47716]
// ...
// <i32> $47812 = const 1 [REDUNDANT]
// <*f32> $47813 = shift ptr [$47738 + $47812] [REDUNDANT]
// <f32> $47814 = local load [$47813]
class ExtractPointers : public BasicStmtVisitor {
public:
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;

std::unordered_map<std::pair<Stmt *, int>,
Stmt *,
hashing::Hasher<std::pair<Stmt *, int>>>
first_matrix_ptr_; // mapping an (AllocaStmt, integer) pair to the first
// MatrixPtrStmt representing it
std::unordered_map<int, Stmt *>
first_const_; // mapping an integer to the first ConstStmt representing
// it
Block *top_level_;

explicit ExtractPointers(IRNode *root) : immediate_modifier_(root) {
TI_ASSERT(root->is<Block>());
top_level_ = root->as<Block>();
root->accept(this);
delayed_modifier_.modify_ir();
}

void visit(MatrixPtrStmt *stmt) override {
if (stmt->origin->is<AllocaStmt>()) {
auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
auto tensor_type =
alloca_stmt->ret_type.ptr_removed()->cast<TensorType>();
TI_ASSERT(tensor_type != nullptr);
if (stmt->offset->is<ConstStmt>()) {
int offset = stmt->offset->cast<ConstStmt>()->val.val_int32();
if (first_const_.count(offset) == 0) {
first_const_[offset] = stmt->offset;
delayed_modifier_.extract_to_block_front(stmt->offset, top_level_);
}
auto key = std::make_pair(alloca_stmt, offset);
if (first_matrix_ptr_.count(key) == 0) {
auto extracted = std::make_unique<MatrixPtrStmt>(
alloca_stmt, first_const_[offset]);
first_matrix_ptr_[key] = extracted.get();
delayed_modifier_.insert_after(alloca_stmt, std::move(extracted));
}
auto new_stmt = first_matrix_ptr_[key];
immediate_modifier_.replace_usages_with(stmt, new_stmt);
delayed_modifier_.erase(stmt);
}
}
}

private:
using BasicStmtVisitor::visit;
};

namespace irpass {

// Entry point of the scalarize transform. Always constructs a Scalarize pass
// on `root`; afterwards constructs ExtractPointers when
// `config.dynamic_index` is set, and ScalarizePointers otherwise.
// ExtractPointers performs its work inside its constructor; the other two
// pass objects are used the same way here.
void scalarize(IRNode *root, const CompileConfig &config) {
  TI_AUTO_PROF;
  Scalarize tensor_pass(root);
  if (config.dynamic_index) {
    ExtractPointers extract_pass(root);
  } else {
    ScalarizePointers pointer_pass(root);
  }
}

Expand Down