From 9fcf4f059f84f2354ed0b002723717f9a6357750 Mon Sep 17 00:00:00 2001 From: Mingrui Zhang <33411325+erizmr@users.noreply.github.com> Date: Fri, 19 Apr 2024 19:14:41 +0800 Subject: [PATCH] [bug] Fix offline cache emit dependencies (#8510) Issue: # Fix SNodeTree roots being stored in an unordered container when generating the kernel offline cache key. The traversal order of an unordered container may differ from run to run, which would generate different offline cache keys for the same kernel. This breaks the offline cache, i.e., it triggers re-compilation each time. ### Brief Summary copilot:summary ### Walkthrough copilot:walkthrough --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/analysis/gen_offline_cache_key.cpp | 6 +- tests/python/test_offline_cache.py | 68 +++++++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 1d099b025c236..62eca26d378ba 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -532,7 +532,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(static_cast(snode->get_snode_tree_id())); emit(static_cast(snode->id)); const auto *root = snode->get_root(); - snode_tree_roots_.insert(root); + snode_tree_roots_.push_back(root); } else { emit(std::numeric_limits::max()); emit(std::numeric_limits::max()); @@ -654,8 +654,8 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { #undef DEFINE_EMIT_ENUM std::ostream *os_{nullptr}; - std::unordered_set snode_tree_roots_; - std::unordered_map real_funcs_; + std::vector snode_tree_roots_; + std::map real_funcs_; std::vector string_pool_; }; diff --git a/tests/python/test_offline_cache.py b/tests/python/test_offline_cache.py index e61cae7d64c07..bcc7310bde768 100644 --- a/tests/python/test_offline_cache.py +++ b/tests/python/test_offline_cache.py @@ -384,6 +384,74 @@
def helper(): assert added_files() == expected_num_cache_files(len(simple_kernels_to_test)) +@pytest.mark.parametrize("curr_arch", supported_archs_offline_cache) +@_test_offline_cache_dec +def test_offline_cache_with_different_snode_trees(curr_arch): + count_of_cache_file = cache_files_cnt() + + def added_files(): + return cache_files_cnt() - count_of_cache_file + + def helper(): + x = ti.field(float, shape=5) + + @ti.kernel + def trigger_compile(): + x[0] += 1 + + # This case tests that the order in which SNodeTrees are stored matters (i.e., use an ordered container such as vector instead of unordered_map or unordered_set) when generating the kernel offline cache key + # The multiple `trigger_compile` calls are equivalent to allocating each field in a different SNodeTree + # i.e., + # x = ti.field(float) + # fb.dense(ti.i, 5).place(x) + # fb.finalize() + + trigger_compile() + a = ti.field(float, shape=5) + trigger_compile() + b = ti.field(float, shape=10) + trigger_compile() + c = ti.field(float, shape=5) + trigger_compile() + d = ti.field(float, shape=10) + trigger_compile() + e = ti.field(float, shape=5) + trigger_compile() + f = ti.field(float, shape=10) + trigger_compile() + g = ti.field(float, shape=5) + trigger_compile() + h = ti.field(float, shape=10) + + @ti.kernel + def kernel_forward(): + for i in range(5): + a[i] += i + b[i] += i + c[i] += i + d[i] += i + e[i] += i + f[i] += i + g[i] += i + h[i] += i + + kernel_forward() + + assert added_files() == expected_num_cache_files(0) + ti.init(arch=curr_arch, enable_fallback=False, **current_thread_ext_options()) + helper() + + ti.init(arch=curr_arch, enable_fallback=False, **current_thread_ext_options()) + assert added_files() == expected_num_cache_files(2) + helper() + + # The number of cache files should not change + for _ in range(5): + ti.init(arch=curr_arch, enable_fallback=False, **current_thread_ext_options()) + assert added_files() == expected_num_cache_files(2) + helper() + + @pytest.mark.parametrize("curr_arch", 
supported_archs_offline_cache) @_test_offline_cache_dec def test_offline_cache_with_changing_compile_config(curr_arch):