From ae0882c021cf01eed9a051f353f4895b47319734 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 5 Jan 2023 10:14:58 +0800 Subject: [PATCH] [opt] Add ExtractPointers pass for dynamic index (#7051) Issue: #2590 ### Brief Summary Under pure `dynamic_index` setting, `MatrixPtrStmt`s are not scalarized. It actually produces `2n` more instructions (`n` `ConstStmt`s and n `MatrixPtrStmt`s) than the scalarized setting, where `n` is the number of usages of `MatrixPtrStmt`s. This PR adds `ExtractPointers` pass to eliminate all the redundant instructions. See comments in the code for details. After this PR, the number of instructions after the `scalarize()` pass of the script in #6933 under dynamic index reduces from 49589 to 26581, and the compilation time reduces from 20.02s to 7.82s. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/transforms/scalarize.cpp | 73 +++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 2c2ecfddee2d9..e8f96c339a52c 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -652,6 +652,77 @@ class ScalarizePointers : public BasicStmtVisitor { using BasicStmtVisitor::visit; }; +// The ExtractPointers pass aims at removing redundant ConstStmts and +// MatrixPtrStmts generated for any (AllocaStmt, integer) pair by extracting +// a unique copy for any future usage. +// +// Example for redundant stmts: +// $0 = const 0 +// $1 = const 1 +// ... +// <[Tensor (3, 3) f32]> $47738 = alloca +// $47739 = const 0 [REDUNDANT] +// <*f32> $47740 = shift ptr [$47738 + $47739] +// $47741 : local store [$47740 <- $47713] +// $47742 = const 1 [REDUNDANT] +// <*f32> $47743 = shift ptr [$47738 + $47742] +// $47744 : local store [$47743 <- $47716] +// ... +// $47812 = const 1 [REDUNDANT] +// <*f32> $47813 = shift ptr [$47738 + $47812] [REDUNDANT] +// $47814 = local load [$47813] +class ExtractPointers : public BasicStmtVisitor { + public: + ImmediateIRModifier immediate_modifier_; + DelayedIRModifier delayed_modifier_; + + std::unordered_map, + Stmt *, + hashing::Hasher>> + first_matrix_ptr_; // mapping an (AllocaStmt, integer) pair to the first + // MatrixPtrStmt representing it + std::unordered_map + first_const_; // mapping an integer to the first ConstStmt representing + // it + Block *top_level_; + + explicit ExtractPointers(IRNode *root) : immediate_modifier_(root) { + TI_ASSERT(root->is()); + top_level_ = root->as(); + root->accept(this); + delayed_modifier_.modify_ir(); + } + + void visit(MatrixPtrStmt *stmt) override { + if (stmt->origin->is()) { + auto alloca_stmt = stmt->origin->cast(); + auto tensor_type = + alloca_stmt->ret_type.ptr_removed()->cast(); + TI_ASSERT(tensor_type != nullptr); + if (stmt->offset->is()) { + int offset = stmt->offset->cast()->val.val_int32(); + if (first_const_.count(offset) == 0) { + first_const_[offset] = stmt->offset; + delayed_modifier_.extract_to_block_front(stmt->offset, top_level_); + } + auto key = std::make_pair(alloca_stmt, offset); + if (first_matrix_ptr_.count(key) == 0) { + auto extracted = std::make_unique( + alloca_stmt, first_const_[offset]); + first_matrix_ptr_[key] = extracted.get(); + delayed_modifier_.insert_after(alloca_stmt, std::move(extracted)); + } + auto new_stmt = first_matrix_ptr_[key]; + immediate_modifier_.replace_usages_with(stmt, new_stmt); + delayed_modifier_.erase(stmt); + } + } + } + + private: + using BasicStmtVisitor::visit; +}; + namespace irpass { void scalarize(IRNode *root, const CompileConfig &config) { @@ -659,6 +730,8 @@ void scalarize(IRNode *root, const CompileConfig &config) { Scalarize scalarize_pass(root); if (!config.dynamic_index) { ScalarizePointers scalarize_pointers_pass(root); + } else { + ExtractPointers extract_pointers_pass(root); } }