From 6a84c60e96f4c8b6391916ff82209347e48570de Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Sun, 23 Apr 2023 14:40:35 +0800 Subject: [PATCH] [Lang] Support the functions of dataclass as kernel argument and return value (#7865) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: fixes #7822 ### ๐Ÿค– Generated by Copilot at 4ae87d7 ### Summary ๐Ÿงช๐Ÿ› ๏ธ๐Ÿš€ This pull request enables data classes defined by `ti.dataclass` to have methods decorated by `ti.func` and `ti.kernel`. It also adds tests to check the functionality and correctness of this feature. > _`StructType` creates the data class of doom_ > _With `__struct_methods` to unleash its power_ > _Call them from Python, or pass them to `ti.kernel`_ > _`ti.func` returns them, the ultimate metal_ ### Walkthrough * Add `__struct_methods` attribute to data class dictionaries to enable access of methods defined in `ti.dataclass` decorator ([link](https://github.com/taichi-dev/taichi/pull/7865/files?diff=unified&w=0#diff-3154e0533b9fd63e663c16c2e87c65068c3b002a65cf80529b6577d173bdb5feR633), [link](https://github.com/taichi-dev/taichi/pull/7865/files?diff=unified&w=0#diff-3154e0533b9fd63e663c16c2e87c65068c3b002a65cf80529b6577d173bdb5feR654)) * Add test cases for using `ti.func` and `ti.kernel` as methods of data classes, both as kernel arguments and kernel returns, in `test_struct.py` ([link](https://github.com/taichi-dev/taichi/pull/7865/files?diff=unified&w=0#diff-e87bf5cb1cd09e10b5cfa001ab2ef18f31a242db3a7b66ee98a76d60b1615e71R79-R130)) --- python/taichi/lang/struct.py | 2 ++ tests/python/test_struct.py | 52 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 62f68171a6b5f..89aab933afb46 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -630,6 +630,7 @@ def from_taichi_object(self, func_ret, ret_index=()): d[name] = dtype.from_taichi_object(func_ret, ret_index + (index,)) else: d[name] = expr.Expr(_ti_core.make_get_element_expr(func_ret.ptr, ret_index + (index,))) + d["__struct_methods"] = self.methods return Struct(d) @@ -650,6 +651,7 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()): d[name] = launch_ctx.get_struct_ret_float(ret_index + (index,)) else: raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}") + d["__struct_methods"] = self.methods return Struct(d) diff --git a/tests/python/test_struct.py b/tests/python/test_struct.py index e499f41a190b7..39933db5e64bf 100644 --- a/tests/python/test_struct.py +++ b/tests/python/test_struct.py @@ -76,3 +76,55 @@ def test_2d_nested(): for i in range(n * 2): for j in range(n): assert x[i, j] == i + j * 10 + + +@test_utils.test() +def test_func_of_data_class_as_kernel_arg(): + @ti.dataclass + class Foo: + x: ti.f32 + y: ti.f32 + + @ti.func + def add(self, other: ti.template()): + return Foo(self.x + other.x, self.y + other.y) + + @ti.kernel + def foo_x(x: Foo) -> ti.f32: + return x.add(x).x + + assert foo_x(Foo(1, 2)) == 2 + + @ti.kernel + def foo_y(x: Foo) -> ti.f32: + return x.add(x).y + + assert foo_y(Foo(1, 2)) == 4 + + +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.amdgpu]) +def test_func_of_data_class_as_kernel_return(): + # TODO: enable this test in SPIR-V based backends after SPIR-V based backends can return structs. + @ti.dataclass + class Foo: + x: ti.f32 + y: ti.f32 + + @ti.func + def add(self, other: ti.template()): + return Foo(self.x + other.x, self.y + other.y) + + def add_python(self, other): + return Foo(self.x + other.x, self.y + other.y) + + @ti.kernel + def foo(x: Foo) -> Foo: + return x.add(x) + + b = foo(Foo(1, 2)) + assert b.x == 2 + assert b.y == 4 + + c = b.add_python(b) + assert c.x == 4 + assert c.y == 8