Skip to content

Commit

Permalink
[Lang] [bug] Let nested data classes have methods (#7909)
Browse files Browse the repository at this point in the history
Issue: fixes #7908 
Added the methods in the string representation of the class as well.
  • Loading branch information
lin-hitonami authored Apr 27, 2023
1 parent e2dd30b commit 23ac5da
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __str__(self):
"""Python scope struct array print support."""
if impl.inside_kernel():
item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items])
item_str += f", struct_methods={self.methods}"
return f"<ti.Struct {item_str}>"
return str(self.to_dict())

Expand All @@ -196,7 +197,11 @@ def to_dict(self, include_methods=False, include_ndim=False):
Dict: The result dictionary.
"""
res_dict = {
k: v.to_dict() if isinstance(v, Struct) else v.to_list() if isinstance(v, Matrix) else v
k: v.to_dict(include_methods=include_methods, include_ndim=include_ndim)
if isinstance(v, Struct)
else v.to_list()
if isinstance(v, Matrix)
else v
for k, v in self.entries.items()
}
if include_methods:
Expand Down
26 changes: 26 additions & 0 deletions tests/python/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,29 @@ def foo(x: Foo) -> Foo:
c = b.add_python(b)
assert c.x == 4
assert c.y == 8


@test_utils.test()
def test_nested_data_class_func():
@ti.dataclass
class Foo:
a: int

@ti.func
def foo(self):
return self.a

@ti.dataclass
class Nested:
f: Foo

@ti.func
def testme(self) -> int:
return self.f.foo()

@ti.kernel
def k() -> int:
x = Nested(Foo(42))
return x.testme()

assert k() == 42

0 comments on commit 23ac5da

Please sign in to comment.