Skip to content

Commit

Permalink
[Lang] Support the functions of dataclass as kernel argument and retu…
Browse files Browse the repository at this point in the history
…rn value
  • Loading branch information
lin-hitonami committed Apr 21, 2023
1 parent 570d249 commit 4ae87d7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ def from_taichi_object(self, func_ret, ret_index=()):
d[name] = dtype.from_taichi_object(func_ret, ret_index + (index,))
else:
d[name] = expr.Expr(_ti_core.make_get_element_expr(func_ret.ptr, ret_index + (index,)))
d["__struct_methods"] = self.methods

return Struct(d)

Expand All @@ -650,6 +651,7 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
d[name] = launch_ctx.get_struct_ret_float(ret_index + (index,))
else:
raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}")
d["__struct_methods"] = self.methods

return Struct(d)

Expand Down
52 changes: 52 additions & 0 deletions tests/python/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,55 @@ def test_2d_nested():
for i in range(n * 2):
for j in range(n):
assert x[i, j] == i + j * 10


@test_utils.test()
def test_func_of_data_class_as_kernel_arg():
@ti.dataclass
class Foo:
x: ti.f32
y: ti.f32

@ti.func
def add(self, other: ti.template()):
return Foo(self.x + other.x, self.y + other.y)

@ti.kernel
def foo_x(x: Foo) -> ti.f32:
return x.add(x).x

assert foo_x(Foo(1, 2)) == 2

@ti.kernel
def foo_y(x: Foo) -> ti.f32:
return x.add(x).y

assert foo_y(Foo(1, 2)) == 4


@test_utils.test(arch=[ti.cpu, ti.cuda, ti.amdgpu])
def test_func_of_data_class_as_kernel_return():
# TODO: enable this test in SPIR-V based backends after SPIR-V based backends can return structs.
@ti.dataclass
class Foo:
x: ti.f32
y: ti.f32

@ti.func
def add(self, other: ti.template()):
return Foo(self.x + other.x, self.y + other.y)

def add_python(self, other):
return Foo(self.x + other.x, self.y + other.y)

@ti.kernel
def foo(x: Foo) -> Foo:
return x.add(x)

b = foo(Foo(1, 2))
assert b.x == 2
assert b.y == 4

c = b.add_python(b)
assert c.x == 4
assert c.y == 8

0 comments on commit 4ae87d7

Please sign in to comment.