From e2ac5622d1956dab957ace49690b5657f6cdf530 Mon Sep 17 00:00:00 2001 From: Mingrui Zhang Date: Sun, 1 Dec 2024 14:09:22 +0800 Subject: [PATCH 1/4] accelerate compile --- taichi/analysis/gen_offline_cache_key.cpp | 44 ++++++++++++++++++++++- taichi/analysis/offline_cache_util.cpp | 15 +++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 62eca26d378ba..2fcf904e5124b 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -65,6 +65,10 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { return this->os_; } + void print_time_cost(){ + TI_TRACE("Emit pod time cost {} ms", emit_time_cost * 1000); + } + void visit(Expression *expr) override { this->ExpressionVisitor::visit(expr); } @@ -426,46 +430,82 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { static void run(IRNode *ast, std::ostream *os) { ASTSerializer serializer(os); + auto t = Time::get_time(); ast->accept(&serializer); + t = Time::get_time() - t; + TI_TRACE("Traversal and emit {} ms", t * 1000); + + t = Time::get_time(); serializer.emit_dependencies(); + t = Time::get_time() - t; + TI_TRACE("emit_dependencies cost {} ms", t * 1000); + serializer.print_time_cost(); } private: void emit_dependencies() { // Serialize dependent real-functions emit(real_funcs_.size()); + + auto t = Time::get_time(); for (auto &[func, id] : real_funcs_) { if (auto &ast_str = func->try_get_ast_serialization_data(); ast_str.has_value()) { emit_bytes(ast_str->c_str(), ast_str->size()); } } + t = Time::get_time() - t; + TI_TRACE("[emit_dependencies] serialize real func cost {} ms", t * 1000); + + t = Time::get_time(); // Serialize snode_trees(Temporary: using offline-cache-key of SNode) // Note: The result of serializing snode_tree_roots_ is not parsable now emit(static_cast(snode_tree_roots_.size())); for (const auto *snode : snode_tree_roots_) { - auto key = get_hashed_offline_cache_key_of_snode(snode); + std::string key; + if(snode_key_cache_.find(snode) == snode_key_cache_.end()){ + key = get_hashed_offline_cache_key_of_snode(snode); + snode_key_cache_[snode] = key; + }else{ + key = snode_key_cache_[snode]; + } + // key = get_hashed_offline_cache_key_of_snode(snode); + snode_key_cache_[snode] = key; emit_bytes(key.c_str(), key.size()); } + t = Time::get_time() - t; + TI_TRACE("[emit_dependencies] serialize snode tree cost {} ms", t * 1000); + + t = Time::get_time(); // Dump string-pool emit(static_cast(string_pool_.size())); emit_bytes(string_pool_.data(), string_pool_.size()); + t = Time::get_time() - t; + TI_TRACE("[emit_dependencies] dump string pool cost {} ms", t * 1000); } template void emit_pod(const T &val) { static_assert(std::is_pod::value); TI_ASSERT(os_); + auto t = Time::get_time(); os_->write((const char *)&val, sizeof(T)); + t = Time::get_time() - t; + // TI_TRACE("[{}] gen_offline_cache_key costs {} ms", kernel->name, t * 1000); + emit_time_cost += t; } void emit_bytes(const char *bytes, std::size_t len) { TI_ASSERT(os_); if (!bytes) return; + auto t = Time::get_time(); os_->write(bytes, len); + t = Time::get_time() - t; + // TI_TRACE("[{}] gen_offline_cache_key costs {} ms", kernel->name, t * 1000); + emit_time_cost += t; } template @@ -655,8 +695,10 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { std::ostream *os_{nullptr}; std::vector snode_tree_roots_; + std::unordered_map snode_key_cache_; std::map real_funcs_; std::vector string_pool_; + double emit_time_cost = 0.0; }; } // namespace diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 2ea4d0df979da..6acfddb76c664 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -153,15 +153,23 @@ std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) { TI_ASSERT(snode); BinaryOutputSerializer serializer; + + auto t = Time::get_time(); serializer.initialize(); { - std::unordered_set visited; + std::unordered_set visited; get_offline_cache_key_of_snode_impl(snode, serializer, visited); } serializer.finalize(); + t = Time::get_time() - t; + TI_TRACE("[emit_dependencies] get_hashed_offline_cache_key_of_snode cost {} ms", t * 1000); + picosha2::hash256_one_by_one hasher; + t = Time::get_time(); hasher.process(serializer.data.begin(), serializer.data.end()); + t = Time::get_time() - t; + TI_TRACE("[emit_dependencies] string len {}, hasher.process cost {} ms", serializer.data.size(), t * 1000); hasher.finish(); return picosha2::get_hash_hex_string(hasher); @@ -177,7 +185,12 @@ std::string get_hashed_offline_cache_key(const CompileConfig &config, get_offline_cache_key_of_parameter_list(kernel->parameter_list); kernel_rets_string = get_offline_cache_key_of_rets(kernel->rets); std::ostringstream oss; + + auto t = Time::get_time(); gen_offline_cache_key(kernel->ir.get(), &oss); + t = Time::get_time() - t; + TI_TRACE("[{}] gen_offline_cache_key costs {} ms", kernel->name, t * 1000); + kernel_body_string = oss.str(); } From 670f5a6d156de43a8c1f7705bec14dc84db443fd Mon Sep 17 00:00:00 2001 From: Mingrui Zhang Date: Fri, 6 Dec 2024 21:18:26 +0800 Subject: [PATCH 2/4] clean debug info --- taichi/analysis/gen_offline_cache_key.cpp | 30 ----------------------- taichi/analysis/offline_cache_util.cpp | 11 --------- 2 files changed, 41 deletions(-) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 2fcf904e5124b..e2a31a2c434fd 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -65,10 +65,6 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { return this->os_; } - void print_time_cost(){ - TI_TRACE("Emit pod time cost {} ms", emit_time_cost * 1000); - } - void visit(Expression *expr) override { this->ExpressionVisitor::visit(expr); } @@ -430,16 +426,8 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { static void run(IRNode *ast, std::ostream *os) { ASTSerializer serializer(os); - auto t = Time::get_time(); ast->accept(&serializer); - t = Time::get_time() - t; - TI_TRACE("Traversal and emit {} ms", t * 1000); - - t = Time::get_time(); serializer.emit_dependencies(); - t = Time::get_time() - t; - TI_TRACE("emit_dependencies cost {} ms", t * 1000); - serializer.print_time_cost(); } private: @@ -447,18 +435,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { // Serialize dependent real-functions emit(real_funcs_.size()); - auto t = Time::get_time(); for (auto &[func, id] : real_funcs_) { if (auto &ast_str = func->try_get_ast_serialization_data(); ast_str.has_value()) { emit_bytes(ast_str->c_str(), ast_str->size()); } } - t = Time::get_time() - t; - TI_TRACE("[emit_dependencies] serialize real func cost {} ms", t * 1000); - t = Time::get_time(); // Serialize snode_trees(Temporary: using offline-cache-key of SNode) // Note: The result of serializing snode_tree_roots_ is not parsable now emit(static_cast(snode_tree_roots_.size())); @@ -475,37 +459,23 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit_bytes(key.c_str(), key.size()); } - t = Time::get_time() - t; - TI_TRACE("[emit_dependencies] serialize snode tree cost {} ms", t * 1000); - - t = Time::get_time(); // Dump string-pool emit(static_cast(string_pool_.size())); emit_bytes(string_pool_.data(), string_pool_.size()); - t = Time::get_time() - t; - TI_TRACE("[emit_dependencies] dump string pool cost {} ms", t * 1000); } template void emit_pod(const T &val) { static_assert(std::is_pod::value); TI_ASSERT(os_); - auto t = Time::get_time(); os_->write((const char *)&val, sizeof(T)); - t = Time::get_time() - t; - // TI_TRACE("[{}] gen_offline_cache_key costs {} ms", kernel->name, t * 1000); - emit_time_cost += t; } void emit_bytes(const char *bytes, std::size_t len) { TI_ASSERT(os_); if (!bytes) return; - auto t = Time::get_time(); os_->write(bytes, len); - t = Time::get_time() - t; - // TI_TRACE("[{}] gen_offline_cache_key costs {} ms", kernel->name, t * 1000); - emit_time_cost += t; } template diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 6acfddb76c664..c2a49015c5c5e 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -154,22 +154,15 @@ std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) { BinaryOutputSerializer serializer; - auto t = Time::get_time(); serializer.initialize(); { std::unordered_set visited; get_offline_cache_key_of_snode_impl(snode, serializer, visited); } serializer.finalize(); - t = Time::get_time() - t; - TI_TRACE("[emit_dependencies] get_hashed_offline_cache_key_of_snode cost {} ms", t * 1000); - picosha2::hash256_one_by_one hasher; - t = Time::get_time(); hasher.process(serializer.data.begin(), serializer.data.end()); - t = Time::get_time() - t; - TI_TRACE("[emit_dependencies] string len {}, hasher.process cost {} ms", serializer.data.size(), t * 1000); hasher.finish(); return picosha2::get_hash_hex_string(hasher); @@ -186,11 +179,7 @@ std::string get_hashed_offline_cache_key(const CompileConfig &config, kernel_rets_string = get_offline_cache_key_of_rets(kernel->rets); std::ostringstream oss; - auto t = Time::get_time(); gen_offline_cache_key(kernel->ir.get(), &oss); - t = Time::get_time() - t; - TI_TRACE("[{}] gen_offline_cache_key costs {} ms", kernel->name, t * 1000); - kernel_body_string = oss.str(); } From 5f47e7e88dd07348ba3d91883b73b25e10db3830 Mon Sep 17 00:00:00 2001 From: Mingrui Zhang Date: Fri, 6 Dec 2024 21:22:23 +0800 Subject: [PATCH 3/4] update --- taichi/analysis/gen_offline_cache_key.cpp | 4 ---- taichi/analysis/offline_cache_util.cpp | 2 -- 2 files changed, 6 deletions(-) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index e2a31a2c434fd..bd0a1f1a36fb0 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -434,7 +434,6 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { void emit_dependencies() { // Serialize dependent real-functions emit(real_funcs_.size()); - for (auto &[func, id] : real_funcs_) { if (auto &ast_str = func->try_get_ast_serialization_data(); ast_str.has_value()) { @@ -442,7 +441,6 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { } } - // Serialize snode_trees(Temporary: using offline-cache-key of SNode) // Note: The result of serializing snode_tree_roots_ is not parsable now emit(static_cast(snode_tree_roots_.size())); @@ -454,7 +452,6 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { }else{ key = snode_key_cache_[snode]; } - // key = get_hashed_offline_cache_key_of_snode(snode); snode_key_cache_[snode] = key; emit_bytes(key.c_str(), key.size()); } @@ -668,7 +665,6 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { std::unordered_map snode_key_cache_; std::map real_funcs_; std::vector string_pool_; - double emit_time_cost = 0.0; }; } // namespace diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index c2a49015c5c5e..66f51121ea392 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -153,7 +153,6 @@ std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) { TI_ASSERT(snode); BinaryOutputSerializer serializer; - serializer.initialize(); { std::unordered_set visited; @@ -178,7 +177,6 @@ std::string get_hashed_offline_cache_key(const CompileConfig &config, get_offline_cache_key_of_parameter_list(kernel->parameter_list); kernel_rets_string = get_offline_cache_key_of_rets(kernel->rets); std::ostringstream oss; - gen_offline_cache_key(kernel->ir.get(), &oss); kernel_body_string = oss.str(); } From 980ffc96ee51086ffdef7d9bf1af92e80ea4e503 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:27:13 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/analysis/gen_offline_cache_key.cpp | 4 ++-- taichi/analysis/offline_cache_util.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index bd0a1f1a36fb0..592908f6f239f 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -446,10 +446,10 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(static_cast(snode_tree_roots_.size())); for (const auto *snode : snode_tree_roots_) { std::string key; - if(snode_key_cache_.find(snode) == snode_key_cache_.end()){ + if (snode_key_cache_.find(snode) == snode_key_cache_.end()) { key = get_hashed_offline_cache_key_of_snode(snode); snode_key_cache_[snode] = key; - }else{ + } else { key = snode_key_cache_[snode]; } snode_key_cache_[snode] = key; diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 66f51121ea392..2ea4d0df979da 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -155,7 +155,7 @@ std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) { BinaryOutputSerializer serializer; serializer.initialize(); { - std::unordered_set visited; + std::unordered_set visited; get_offline_cache_key_of_snode_impl(snode, serializer, visited); } serializer.finalize();