Skip to content

Commit

Permalink
[Lang] Support the functions of dataclass as kernel argument and retu…
Browse files Browse the repository at this point in the history
…rn value (taichi-dev#7865)

Issue: fixes taichi-dev#7822

<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at 4ae87d7</samp>

### Summary
🧪🛠️🚀

<!--
1. 🧪 - This emoji represents testing, experimentation, or science, and
can be used to indicate the addition of test cases or the verification
of functionality.
2. 🛠️ - This emoji represents tools, construction, or repair, and can be
used to indicate the addition of a new feature or the improvement of an
existing one.
3. 🚀 - This emoji represents speed, launch, or innovation, and can be
used to indicate the enhancement of performance, the expansion of
capabilities, or the introduction of a novel idea.
-->
This pull request enables data classes defined by `ti.dataclass` to have
methods decorated by `ti.func` and `ti.kernel`. It also adds tests to
check the functionality and correctness of this feature.

> _`StructType` creates the data class of doom_
> _With `__struct_methods` to unleash its power_
> _Call them from Python, or pass them to `ti.kernel`_
> _`ti.func` returns them, the ultimate metal_

### Walkthrough
* Add `__struct_methods` attribute to data class dictionaries to enable
access of methods defined in `ti.dataclass` decorator
([link](https://github.com/taichi-dev/taichi/pull/7865/files?diff=unified&w=0#diff-3154e0533b9fd63e663c16c2e87c65068c3b002a65cf80529b6577d173bdb5feR633),
[link](https://github.com/taichi-dev/taichi/pull/7865/files?diff=unified&w=0#diff-3154e0533b9fd63e663c16c2e87c65068c3b002a65cf80529b6577d173bdb5feR654))
* Add test cases for using `ti.func` and `ti.kernel` as methods of data
classes, both as kernel arguments and kernel returns, in
`test_struct.py`
([link](https://github.com/taichi-dev/taichi/pull/7865/files?diff=unified&w=0#diff-e87bf5cb1cd09e10b5cfa001ab2ef18f31a242db3a7b66ee98a76d60b1615e71R79-R130))
  • Loading branch information
lin-hitonami authored and quadpixels committed May 13, 2023
1 parent 4e0c379 commit d767e1f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ def from_taichi_object(self, func_ret, ret_index=()):
d[name] = dtype.from_taichi_object(func_ret, ret_index + (index,))
else:
d[name] = expr.Expr(_ti_core.make_get_element_expr(func_ret.ptr, ret_index + (index,)))
d["__struct_methods"] = self.methods

return Struct(d)

Expand All @@ -650,6 +651,7 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
d[name] = launch_ctx.get_struct_ret_float(ret_index + (index,))
else:
raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}")
d["__struct_methods"] = self.methods

return Struct(d)

Expand Down
52 changes: 52 additions & 0 deletions tests/python/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,55 @@ def test_2d_nested():
for i in range(n * 2):
for j in range(n):
assert x[i, j] == i + j * 10


@test_utils.test()
def test_func_of_data_class_as_kernel_arg():
@ti.dataclass
class Foo:
x: ti.f32
y: ti.f32

@ti.func
def add(self, other: ti.template()):
return Foo(self.x + other.x, self.y + other.y)

@ti.kernel
def foo_x(x: Foo) -> ti.f32:
return x.add(x).x

assert foo_x(Foo(1, 2)) == 2

@ti.kernel
def foo_y(x: Foo) -> ti.f32:
return x.add(x).y

assert foo_y(Foo(1, 2)) == 4


@test_utils.test(arch=[ti.cpu, ti.cuda, ti.amdgpu])
def test_func_of_data_class_as_kernel_return():
# TODO: enable this test in SPIR-V based backends after SPIR-V based backends can return structs.
@ti.dataclass
class Foo:
x: ti.f32
y: ti.f32

@ti.func
def add(self, other: ti.template()):
return Foo(self.x + other.x, self.y + other.y)

def add_python(self, other):
return Foo(self.x + other.x, self.y + other.y)

@ti.kernel
def foo(x: Foo) -> Foo:
return x.add(x)

b = foo(Foo(1, 2))
assert b.x == 2
assert b.y == 4

c = b.add_python(b)
assert c.x == 4
assert c.y == 8

0 comments on commit d767e1f

Please sign in to comment.