diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index de1339730a111..222c32773017d 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -279,6 +279,11 @@ def cell_size_bytes(self): runtime.materialize() return self.ptr.cell_size_bytes + @property + def offset_bytes_in_parent_cell(self): + impl.get_runtime().materialize() + return self.ptr.offset_bytes_in_parent_cell + def deactivate_all(self): """Recursively deactivate all children components of `self`.""" ch = self.get_children() diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index f50fca4d0b6a0..5fe6b1799d717 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -123,6 +123,7 @@ class SNode { int total_bit_start{0}; int chunk_size{0}; std::size_t cell_size_bytes{0}; + std::size_t offset_bytes_in_parent_cell{0}; PrimitiveType *physical_type{nullptr}; // for bit_struct and bit_array only DataType dt; bool has_ambient{false}; diff --git a/taichi/llvm/llvm_context.cpp b/taichi/llvm/llvm_context.cpp index b5a82bb308763..6624533c9e6ca 100644 --- a/taichi/llvm/llvm_context.cpp +++ b/taichi/llvm/llvm_context.cpp @@ -521,6 +521,11 @@ std::size_t TaichiLLVMContext::get_type_size(llvm::Type *type) { return get_data_layout().getTypeAllocSize(type); } +std::size_t TaichiLLVMContext::get_struct_element_offset(llvm::StructType *type, + int idx) { + return get_data_layout().getStructLayout(type)->getElementOffset(idx); +} + void TaichiLLVMContext::mark_inline(llvm::Function *f) { for (auto &B : *f) for (auto &I : B) { diff --git a/taichi/llvm/llvm_context.h b/taichi/llvm/llvm_context.h index 1b000870d96c8..5c8426eb6cf7b 100644 --- a/taichi/llvm/llvm_context.h +++ b/taichi/llvm/llvm_context.h @@ -104,6 +104,8 @@ class TaichiLLVMContext { std::size_t get_type_size(llvm::Type *type); + std::size_t get_struct_element_offset(llvm::StructType *type, int idx); + template llvm::Value *get_constant(T t); diff --git a/taichi/llvm/llvm_fwd.h b/taichi/llvm/llvm_fwd.h index 70621ae6a8bd7..dae04dde079a5 100644 --- a/taichi/llvm/llvm_fwd.h +++ b/taichi/llvm/llvm_fwd.h @@ -7,6 +7,7 @@ class Value; class Module; class Function; class DataLayout; +class StructType; class JITSymbol; class ExitOnError; namespace orc { diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 1413ce330c726..a715c06cb8894 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -403,6 +403,8 @@ void export_lang(py::module &m) { .def("num_active_indices", [](SNode *snode) { return snode->num_active_indices; }) .def_readonly("cell_size_bytes", &SNode::cell_size_bytes) + .def_readonly("offset_bytes_in_parent_cell", + &SNode::offset_bytes_in_parent_cell) .def("begin_shared_exp_placement", &SNode::begin_shared_exp_placement) .def("end_shared_exp_placement", &SNode::end_shared_exp_placement); diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index f09a5b20d8e41..816aefd977750 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -57,6 +57,13 @@ void StructCompilerLLVM::generate_types(SNode &snode) { snode.cell_size_bytes = tlctx_->get_type_size(ch_type); + for (int i = 0; i < snode.ch.size(); i++) { + if (!snode.ch[i]->is_bit_level) { + snode.ch[i]->offset_bytes_in_parent_cell = + tlctx_->get_struct_element_offset(ch_type, i); + } + } + llvm::Type *body_type = nullptr, *aux_type = nullptr; if (type == SNodeType::dense || type == SNodeType::bitmasked) { TI_ASSERT(snode._morton == false); diff --git a/tests/python/test_cell_size_inspection.py b/tests/python/test_snode_layout_inspection.py similarity index 62% rename from tests/python/test_cell_size_inspection.py rename to tests/python/test_snode_layout_inspection.py index 39b07d1987439..64d5ee7f3ea10 100644 --- a/tests/python/test_cell_size_inspection.py +++ b/tests/python/test_snode_layout_inspection.py @@ -21,9 +21,22 @@ def test_primitives(): n3.place(p, q, r) assert n1.cell_size_bytes == 2 - assert 12 <= n2.cell_size_bytes <= 16 + assert n2.cell_size_bytes in [12, 16] assert n3.cell_size_bytes == 16 + assert n1.offset_bytes_in_parent_cell == 0 + assert n2.offset_bytes_in_parent_cell == 2 * 32 + assert n3.offset_bytes_in_parent_cell in [ + 2 * 32 + 12 * 32, 2 * 32 + 16 * 32 + ] + + assert x.snode.offset_bytes_in_parent_cell == 0 + assert y.snode.offset_bytes_in_parent_cell == 0 + assert z.snode.offset_bytes_in_parent_cell in [4, 8] + assert p.snode.offset_bytes_in_parent_cell == 0 + assert q.snode.offset_bytes_in_parent_cell == 4 + assert r.snode.offset_bytes_in_parent_cell == 8 + @ti.test(arch=ti.cpu) def test_bit_struct():