Skip to content

Commit

Permalink
[ir] [llvm] Add offset_bytes_in_parent_cell to SNode (#3793)
Browse files Browse the repository at this point in the history
* [ir] [llvm] Add offset_bytes_in_parent_cell to SNode

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
strongoier and taichi-gardener authored Dec 14, 2021
1 parent 9f0d9f5 commit cd6d4f8
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 1 deletion.
5 changes: 5 additions & 0 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ def cell_size_bytes(self):
runtime.materialize()
return self.ptr.cell_size_bytes

@property
def offset_bytes_in_parent_cell(self):
impl.get_runtime().materialize()
return self.ptr.offset_bytes_in_parent_cell

def deactivate_all(self):
"""Recursively deactivate all children components of `self`."""
ch = self.get_children()
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class SNode {
int total_bit_start{0};
int chunk_size{0};
std::size_t cell_size_bytes{0};
std::size_t offset_bytes_in_parent_cell{0};
PrimitiveType *physical_type{nullptr}; // for bit_struct and bit_array only
DataType dt;
bool has_ambient{false};
Expand Down
5 changes: 5 additions & 0 deletions taichi/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,11 @@ std::size_t TaichiLLVMContext::get_type_size(llvm::Type *type) {
return get_data_layout().getTypeAllocSize(type);
}

std::size_t TaichiLLVMContext::get_struct_element_offset(llvm::StructType *type,
int idx) {
return get_data_layout().getStructLayout(type)->getElementOffset(idx);
}

void TaichiLLVMContext::mark_inline(llvm::Function *f) {
for (auto &B : *f)
for (auto &I : B) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class TaichiLLVMContext {

std::size_t get_type_size(llvm::Type *type);

std::size_t get_struct_element_offset(llvm::StructType *type, int idx);

template <typename T>
llvm::Value *get_constant(T t);

Expand Down
1 change: 1 addition & 0 deletions taichi/llvm/llvm_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Value;
class Module;
class Function;
class DataLayout;
class StructType;
class JITSymbol;
class ExitOnError;
namespace orc {
Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ void export_lang(py::module &m) {
.def("num_active_indices",
[](SNode *snode) { return snode->num_active_indices; })
.def_readonly("cell_size_bytes", &SNode::cell_size_bytes)
.def_readonly("offset_bytes_in_parent_cell",
&SNode::offset_bytes_in_parent_cell)
.def("begin_shared_exp_placement", &SNode::begin_shared_exp_placement)
.def("end_shared_exp_placement", &SNode::end_shared_exp_placement);

Expand Down
7 changes: 7 additions & 0 deletions taichi/struct/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ void StructCompilerLLVM::generate_types(SNode &snode) {

snode.cell_size_bytes = tlctx_->get_type_size(ch_type);

for (int i = 0; i < snode.ch.size(); i++) {
if (!snode.ch[i]->is_bit_level) {
snode.ch[i]->offset_bytes_in_parent_cell =
tlctx_->get_struct_element_offset(ch_type, i);
}
}

llvm::Type *body_type = nullptr, *aux_type = nullptr;
if (type == SNodeType::dense || type == SNodeType::bitmasked) {
TI_ASSERT(snode._morton == false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,22 @@ def test_primitives():
n3.place(p, q, r)

assert n1.cell_size_bytes == 2
assert 12 <= n2.cell_size_bytes <= 16
assert n2.cell_size_bytes in [12, 16]
assert n3.cell_size_bytes == 16

assert n1.offset_bytes_in_parent_cell == 0
assert n2.offset_bytes_in_parent_cell == 2 * 32
assert n3.offset_bytes_in_parent_cell in [
2 * 32 + 12 * 32, 2 * 32 + 16 * 32
]

assert x.snode.offset_bytes_in_parent_cell == 0
assert y.snode.offset_bytes_in_parent_cell == 0
assert z.snode.offset_bytes_in_parent_cell in [4, 8]
assert p.snode.offset_bytes_in_parent_cell == 0
assert q.snode.offset_bytes_in_parent_cell == 4
assert r.snode.offset_bytes_in_parent_cell == 8


@ti.test(arch=ti.cpu)
def test_bit_struct():
Expand Down

0 comments on commit cd6d4f8

Please sign in to comment.