diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 89aab933afb46..c721714bb14b6 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -179,6 +179,7 @@ def __str__(self): """Python scope struct array print support.""" if impl.inside_kernel(): item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items]) + item_str += f", struct_methods={self.methods}" return f"" return str(self.to_dict()) @@ -196,7 +197,11 @@ def to_dict(self, include_methods=False, include_ndim=False): Dict: The result dictionary. """ res_dict = { - k: v.to_dict() if isinstance(v, Struct) else v.to_list() if isinstance(v, Matrix) else v + k: v.to_dict(include_methods=include_methods, include_ndim=include_ndim) + if isinstance(v, Struct) + else v.to_list() + if isinstance(v, Matrix) + else v for k, v in self.entries.items() } if include_methods: diff --git a/tests/python/test_struct.py b/tests/python/test_struct.py index 39933db5e64bf..a5584fc0c49e5 100644 --- a/tests/python/test_struct.py +++ b/tests/python/test_struct.py @@ -128,3 +128,29 @@ def foo(x: Foo) -> Foo: c = b.add_python(b) assert c.x == 4 assert c.y == 8 + + +@test_utils.test() +def test_nested_data_class_func(): + @ti.dataclass + class Foo: + a: int + + @ti.func + def foo(self): + return self.a + + @ti.dataclass + class Nested: + f: Foo + + @ti.func + def testme(self) -> int: + return self.f.foo() + + @ti.kernel + def k() -> int: + x = Nested(Foo(42)) + return x.testme() + + assert k() == 42