From 71b983f7be2a50492d3fc6eb2184101a6731dd64 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Wed, 11 Jan 2023 15:01:03 +0800 Subject: [PATCH] [Bug] Fix num_splits in parallel_struct_for (#7121) Issue: fix #7112 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/codegen/llvm/codegen_llvm.cpp | 3 +- tests/python/test_struct_for_non_pot.py | 37 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index aa1b9a6d3eb74..370bc6705a97e 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2265,7 +2265,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, int list_element_size = std::min(leaf_block->max_num_elements(), (int64)taichi_listgen_max_element_size); - int num_splits = std::max(1, list_element_size / stmt->block_dim); + int num_splits = std::max(1, list_element_size / stmt->block_dim + + (list_element_size % stmt->block_dim != 0)); auto struct_for_func = get_runtime_function("parallel_struct_for"); diff --git a/tests/python/test_struct_for_non_pot.py b/tests/python/test_struct_for_non_pot.py index 9632fde3383e8..3840e82797997 100644 --- a/tests/python/test_struct_for_non_pot.py +++ b/tests/python/test_struct_for_non_pot.py @@ -66,3 +66,40 @@ def test_2d(): @test_utils.test(packed=True) def test_2d_packed(): _test_2d() + + +def _test_2d_pointer(): + block_size, leaf_size = 3, 8 + x = ti.field(ti.i32) + block = ti.root.pointer(ti.ij, (block_size, block_size)) + block.dense(ti.ij, (leaf_size, leaf_size)).place(x) + + @ti.kernel + def activate(): + x[7, 7] = 1 + + activate() + + @ti.kernel + def test() -> ti.i32: + res = 0 + for I in ti.grouped(x): + res += I[0] + I[1] * 2 + return res + + ans = 0 + for i in range(leaf_size): + for j in range(leaf_size): + ans += i + j * 2 + + assert ans == test() + + +@test_utils.test(require=ti.extension.sparse, packed=False) +def test_2d_pointer(): + _test_2d_pointer() + + +@test_utils.test(require=ti.extension.sparse, packed=True) +def test_2d_pointer_packed(): + _test_2d_pointer()