From 4ae87d7af6c92b977e6db086d24b3bdbb3cbc315 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 21 Apr 2023 16:30:25 +0800 Subject: [PATCH] [Lang] Support the functions of dataclass as kernel argument and return value --- python/taichi/lang/struct.py | 2 ++ tests/python/test_struct.py | 52 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 62f68171a6b5f..89aab933afb46 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -630,6 +630,7 @@ def from_taichi_object(self, func_ret, ret_index=()): d[name] = dtype.from_taichi_object(func_ret, ret_index + (index,)) else: d[name] = expr.Expr(_ti_core.make_get_element_expr(func_ret.ptr, ret_index + (index,))) + d["__struct_methods"] = self.methods return Struct(d) @@ -650,6 +651,7 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()): d[name] = launch_ctx.get_struct_ret_float(ret_index + (index,)) else: raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}") + d["__struct_methods"] = self.methods return Struct(d) diff --git a/tests/python/test_struct.py b/tests/python/test_struct.py index e499f41a190b7..39933db5e64bf 100644 --- a/tests/python/test_struct.py +++ b/tests/python/test_struct.py @@ -76,3 +76,55 @@ def test_2d_nested(): for i in range(n * 2): for j in range(n): assert x[i, j] == i + j * 10 + + +@test_utils.test() +def test_func_of_data_class_as_kernel_arg(): + @ti.dataclass + class Foo: + x: ti.f32 + y: ti.f32 + + @ti.func + def add(self, other: ti.template()): + return Foo(self.x + other.x, self.y + other.y) + + @ti.kernel + def foo_x(x: Foo) -> ti.f32: + return x.add(x).x + + assert foo_x(Foo(1, 2)) == 2 + + @ti.kernel + def foo_y(x: Foo) -> ti.f32: + return x.add(x).y + + assert foo_y(Foo(1, 2)) == 4 + + +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.amdgpu]) +def test_func_of_data_class_as_kernel_return(): + # TODO: enable this test in SPIR-V based backends after SPIR-V based backends can return structs. + @ti.dataclass + class Foo: + x: ti.f32 + y: ti.f32 + + @ti.func + def add(self, other: ti.template()): + return Foo(self.x + other.x, self.y + other.y) + + def add_python(self, other): + return Foo(self.x + other.x, self.y + other.y) + + @ti.kernel + def foo(x: Foo) -> Foo: + return x.add(x) + + b = foo(Foo(1, 2)) + assert b.x == 2 + assert b.y == 4 + + c = b.add_python(b) + assert c.x == 4 + assert c.y == 8