Skip to content

Commit

Permalink
[lang] Add support for real matrix args on real function (#6522)
Browse files Browse the repository at this point in the history
Issue: #602

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Nov 7, 2022
1 parent 38b8ef7 commit 107e783
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
6 changes: 6 additions & 0 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ def func_call_rvalue(self, key, args):
elif isinstance(anno, primitive_types.RefType):
non_template_args.append(
_ti_core.make_reference(args[i].ptr))
elif impl.current_cfg().real_matrix and isinstance(
args[i], impl.Expr) and args[i].ptr.is_tensor():
non_template_args.extend([
Expr(x) for x in impl.get_runtime().prog.
current_ast_builder().expand_expr([args[i].ptr])
])
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args,
Expand Down
27 changes: 21 additions & 6 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,30 @@ def test_func_matrix_arg_real_matrix():
_test_func_matrix_arg()


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_matrix_arg():
def _test_real_func_matrix_arg():
@ti.experimental.real_func
def mat_arg(a: ti.math.mat2) -> float:
return a[0, 0] + a[0, 1] + a[1, 0] + a[1, 1]
def mat_arg(a: ti.math.mat2, b: ti.math.vec2) -> float:
return a[0, 0] + a[0, 1] + a[1, 0] + a[1, 1] + b[0] + b[1]

b = ti.Vector.field(n=2, dtype=float, shape=())
b[()][0] = 5
b[()][1] = 6

@ti.kernel
def foo() -> float:
a = ti.math.mat2(1, 2, 3, 4)
return mat_arg(a)
return mat_arg(a, b[()])

assert foo() == pytest.approx(21)


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_matrix_arg():
_test_real_func_matrix_arg()


assert foo() == pytest.approx(10)
@test_utils.test(arch=[ti.cpu, ti.cuda],
real_matrix=True,
real_matrix_scalarize=True)
def test_real_func_matrix_arg_real_matrix():
_test_real_func_matrix_arg()

0 comments on commit 107e783

Please sign in to comment.