From 5cf629c10cdb441139ba6cb3e86f02f06f903392 Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Fri, 25 Nov 2022 13:09:54 +0800 Subject: [PATCH 1/6] [llvm] Support nested struct with matrix return value on real function --- python/taichi/lang/ast/ast_transformer.py | 8 ++----- python/taichi/lang/expr.py | 13 ++++++++++ python/taichi/lang/kernel_arguments.py | 16 ++++++++++--- python/taichi/lang/matrix.py | 4 ++++ python/taichi/lang/snode.py | 15 +----------- tests/python/test_function.py | 29 +++++++++++++++++++++++ 6 files changed, 62 insertions(+), 23 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index c6ab4e84464ea..687670436a521 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -563,11 +563,7 @@ def build_FunctionDef(ctx, node): def transform_as_kernel(): # Treat return type if node.returns is not None: - if isinstance(ctx.func.return_type, StructType): - for tp in ctx.func.return_type.members.values(): - kernel_arguments.decl_ret(tp) - else: - kernel_arguments.decl_ret(ctx.func.return_type) + kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function) for i, arg in enumerate(args.args): if not isinstance(ctx.func.arguments[i].annotation, @@ -756,7 +752,7 @@ def build_Return(ctx, node): values = node.value.ptr assert isinstance(values, Struct) ctx.ast_builder.create_kernel_exprgroup_return( - expr.make_expr_group(values._members)) + expr.make_expr_group(expr._get_flattened_ptrs(values))) else: raise TaichiSyntaxError( "The return type is not supported now!") diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 0d78cdfc0f133..4f38c049a17b2 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -167,4 +167,17 @@ def make_expr_group(*exprs, real_func_arg=False): return expr_group +def _get_flattened_ptrs(val): + if is_taichi_class(val): + ptrs = [] + for item in val._members: + ptrs.extend(_get_flattened_ptrs(item)) + return ptrs + if impl.current_cfg().real_matrix and isinstance( + val, Expr) and val.ptr.is_tensor(): + return impl.get_runtime().prog.current_ast_builder().expand_expr( + [val.ptr]) + return [Expr(val).ptr] + + __all__ = [] diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index a286eaabf0178..f419e013f77df 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -8,6 +8,7 @@ from taichi.lang.enums import Layout from taichi.lang.expr import Expr from taichi.lang.matrix import Matrix, MatrixType, Vector, VectorType +from taichi.lang.struct import StructType from taichi.lang.util import cook_dtype from taichi.types.primitive_types import RefType, f32, u64 @@ -102,10 +103,19 @@ def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod): channel_format, lod), num_dimensions) -def decl_ret(dtype): +def decl_ret(dtype, real_func=False): + if isinstance(dtype, StructType): + for member in dtype.members.values(): + decl_ret(member, real_func) + return if isinstance(dtype, MatrixType): - dtype = _ti_core.get_type_factory_instance().get_tensor_type( - [dtype.n, dtype.m], dtype.dtype) + if real_func: + for i in range(dtype.n * dtype.m): + decl_ret(dtype.dtype) + return + else: + dtype = _ti_core.get_type_factory_instance().get_tensor_type( + [dtype.n, dtype.m], dtype.dtype) else: dtype = cook_dtype(dtype) return impl.get_runtime().prog.decl_ret(dtype) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 90050709042dd..d38a6ff831a97 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1776,6 +1776,10 @@ def __call__(self, *args): # type cast return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim)) + def from_real_func_ret(self, func_ret, ret_index=0): + return self([expr.Expr( + ti_python_core.make_get_element_expr(func_ret.ptr, i)) for i in range(ret_index, ret_index + self.m * self.n)]), ret_index + self.m * self.n + def cast(self, mat): if in_python_scope(): return Matrix([[ diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index 3d96752356a3a..2ad3c6934c52f 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -363,19 +363,6 @@ def rescale_index(a, b, I): return matrix.Vector(entries) -def _get_flattened_ptrs(val): - if is_taichi_class(val): - ptrs = [] - for item in val._members: - ptrs.extend(_get_flattened_ptrs(item)) - return ptrs - if impl.current_cfg().real_matrix and isinstance( - val, expr.Expr) and val.ptr.is_tensor(): - return impl.get_runtime().prog.current_ast_builder().expand_expr( - [val.ptr]) - return [expr.Expr(val).ptr] - - def append(node, indices, val): """Append a value `val` to a SNode `node` at index `indices`. @@ -384,7 +371,7 @@ def append(node, indices, val): indices (Union[int, :class:`~taichi.Vector`]): the indices to visit. val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now. """ - ptrs = _get_flattened_ptrs(val) + ptrs = expr._get_flattened_ptrs(val) append_expr = expr.Expr(_ti_core.expr_snode_append( node._snode.ptr, expr.make_expr_group(indices), ptrs), tb=impl.get_runtime().get_current_src_info()) diff --git a/tests/python/test_function.py b/tests/python/test_function.py index ef35f0934b27a..491cca887a2e9 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -498,3 +498,32 @@ def foo() -> ti.f64: return a.a * a.b assert foo() == pytest.approx(123 * 1.2345e300) + + +def _test_real_func_struct_ret_with_matrix(): + s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16) + s1 = ti.types.struct(a=ti.f32, b=s0) + + @ti.experimental.real_func + def bar() -> s1: + return s1(a=1, b=s0(a=ti.Vector([100, 0.2, 3], dt=ti.f32), b=65537)) + + @ti.kernel + def foo() -> ti.f32: + s = bar() + return s.a + s.b.a[0] + s.b.a[1] + s.b.a[2] + s.b.b + + assert foo() == pytest.approx(105.2) + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_real_func_struct_ret_with_matrix(): + _test_real_func_struct_ret_with_matrix() + + +@test_utils.test(arch=[ti.cpu, ti.cuda], + real_matrix=True, + real_matrix_scalarize=True) +def _test_real_func_struct_ret_with_matrix_real_matrix(): + # fails: Assertion failure: a->is() && b->is() + _test_real_func_struct_ret_with_matrix() From 4b948d6d0ed96484bcb62b99ed046c6f4daf77ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Nov 2022 05:11:58 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 3 ++- python/taichi/lang/matrix.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 687670436a521..18ce3bc2a05aa 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -563,7 +563,8 @@ def build_FunctionDef(ctx, node): def transform_as_kernel(): # Treat return type if node.returns is not None: - kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function) + kernel_arguments.decl_ret(ctx.func.return_type, + ctx.is_real_function) for i, arg in enumerate(args.args): if not isinstance(ctx.func.arguments[i].annotation, diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index d38a6ff831a97..d240b0db0b55d 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1777,8 +1777,10 @@ def __call__(self, *args): return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim)) def from_real_func_ret(self, func_ret, ret_index=0): - return self([expr.Expr( - ti_python_core.make_get_element_expr(func_ret.ptr, i)) for i in range(ret_index, ret_index + self.m * self.n)]), ret_index + self.m * self.n + return self([ + expr.Expr(ti_python_core.make_get_element_expr(func_ret.ptr, i)) + for i in range(ret_index, ret_index + self.m * self.n) + ]), ret_index + self.m * self.n def cast(self, mat): if in_python_scope(): From 9b136422cc52c4f8ac1949dfbed1929cc6d823f7 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 28 Nov 2022 11:58:09 +0800 Subject: [PATCH 3/6] fix real matrix --- taichi/ir/transforms.h | 2 +- taichi/transforms/compile_to_offloads.cpp | 10 +++++++++- taichi/transforms/scalarize.cpp | 4 ++-- tests/cpp/transforms/scalarize_test.cpp | 8 ++++---- tests/python/test_function.py | 3 +-- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index b46f94767b0eb..083ed6aaa0722 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -29,7 +29,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); -void scalarize(IRNode *root); +void scalarize(IRNode *root, const CompileConfig& config); void lower_matrix_ptr(IRNode *root); bool die(IRNode *root); bool simplify(IRNode *root, const CompileConfig &config); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index f50a0caac9a73..810d8c657fb91 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -53,7 +53,7 @@ void compile_to_offloads(IRNode *ir, } if (config.real_matrix && config.real_matrix_scalarize) { - irpass::scalarize(ir); + irpass::scalarize(ir, config); // Remove redundant MatrixInitStmt inserted during scalarization irpass::die(ir); @@ -337,6 +337,14 @@ void compile_function(IRNode *ir, irpass::lower_ast(ir); print("Lowered"); } + + if (config.real_matrix && config.real_matrix_scalarize) { + irpass::scalarize(ir, config); + + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::die(ir); + print("Scalarized"); + } irpass::lower_access(ir, config, {{}, true}); print("Access lowered"); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 002e83739fcfd..032235f75855f 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -613,10 +613,10 @@ class ScalarizePointers : public BasicStmtVisitor { namespace irpass { -void scalarize(IRNode *root) { +void scalarize(IRNode *root, const CompileConfig& config) { TI_AUTO_PROF; Scalarize scalarize_pass(root); - if (!root->get_kernel()->program->this_thread_config().dynamic_index) { + if (!config.dynamic_index) { ScalarizePointers scalarize_pointers_pass(root); } } diff --git a/tests/cpp/transforms/scalarize_test.cpp b/tests/cpp/transforms/scalarize_test.cpp index 279c47f2bb098..9600f83420157 100644 --- a/tests/cpp/transforms/scalarize_test.cpp +++ b/tests/cpp/transforms/scalarize_test.cpp @@ -45,7 +45,7 @@ TEST(Scalarize, ScalarizeGlobalStore) { block->push_back(dest_stmt, matrix_init_stmt); - irpass::scalarize(block.get()); + irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); irpass::lower_matrix_ptr(block.get()); irpass::die(block.get()); @@ -102,7 +102,7 @@ TEST(Scalarize, ScalarizeGlobalLoad) { // Without this GlobalStoreStmt, nothing survives irpass::die() block->push_back(src_stmt, load_stmt); - irpass::scalarize(block.get()); + irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); irpass::lower_matrix_ptr(block.get()); irpass::die(block.get()); @@ -163,7 +163,7 @@ TEST(Scalarize, ScalarizeLocalStore) { // LocalStoreStmt survives irpass::die() block->push_back(dest_stmt, matrix_init_stmt); - irpass::scalarize(block.get()); + irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); irpass::die(block.get()); EXPECT_EQ(block->size(), 2 /*const*/ + 4 /*alloca*/ + 4 /*store*/); @@ -211,7 +211,7 @@ TEST(Scalarize, ScalarizeLocalLoad) { // Without this GlobalStoreStmt, nothing survives irpass::die() block->push_back(src_stmt, load_stmt); - irpass::scalarize(block.get()); + irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); irpass::die(block.get()); EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 4 /*store*/); diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 491cca887a2e9..fb3d37fe9aa80 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -524,6 +524,5 @@ def test_real_func_struct_ret_with_matrix(): @test_utils.test(arch=[ti.cpu, ti.cuda], real_matrix=True, real_matrix_scalarize=True) -def _test_real_func_struct_ret_with_matrix_real_matrix(): - # fails: Assertion failure: a->is() && b->is() +def test_real_func_struct_ret_with_matrix_real_matrix(): _test_real_func_struct_ret_with_matrix() From 39eef5f972ecc5afef409650069b9b59bddc3a51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Nov 2022 03:59:24 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/transforms.h | 2 +- taichi/transforms/compile_to_offloads.cpp | 2 +- taichi/transforms/scalarize.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 083ed6aaa0722..ed8571fe02825 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -29,7 +29,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); -void scalarize(IRNode *root, const CompileConfig& config); +void scalarize(IRNode *root, const CompileConfig &config); void lower_matrix_ptr(IRNode *root); bool die(IRNode *root); bool simplify(IRNode *root, const CompileConfig &config); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 810d8c657fb91..975e1ecd6adbc 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -337,7 +337,7 @@ void compile_function(IRNode *ir, irpass::lower_ast(ir); print("Lowered"); } - + if (config.real_matrix && config.real_matrix_scalarize) { irpass::scalarize(ir, config); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 032235f75855f..626ed168fe536 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -613,7 +613,7 @@ class ScalarizePointers : public BasicStmtVisitor { namespace irpass { -void scalarize(IRNode *root, const CompileConfig& config) { +void scalarize(IRNode *root, const CompileConfig &config) { TI_AUTO_PROF; Scalarize scalarize_pass(root); if (!config.dynamic_index) { From 5ef03ee672fb26c02ce5394b97c6bd34a01e3fa0 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 28 Nov 2022 12:06:46 +0800 Subject: [PATCH 5/6] fix pylint --- python/taichi/lang/snode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index 2ad3c6934c52f..54c45299be1a7 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -3,7 +3,7 @@ from taichi._lib import core as _ti_core from taichi.lang import expr, impl, matrix from taichi.lang.field import BitpackedFields, Field -from taichi.lang.util import get_traceback, is_taichi_class +from taichi.lang.util import get_traceback class SNode: From 538eb981c7ad8ed5dcee930e43720ca4e05d2ad4 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 28 Nov 2022 12:38:40 +0800 Subject: [PATCH 6/6] fix pylint --- python/taichi/lang/kernel_arguments.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index f419e013f77df..d94b16844a86b 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -113,9 +113,8 @@ def decl_ret(dtype, real_func=False): for i in range(dtype.n * dtype.m): decl_ret(dtype.dtype) return - else: - dtype = _ti_core.get_type_factory_instance().get_tensor_type( - [dtype.n, dtype.m], dtype.dtype) + dtype = _ti_core.get_type_factory_instance().get_tensor_type( + [dtype.n, dtype.m], dtype.dtype) else: dtype = cook_dtype(dtype) - return impl.get_runtime().prog.decl_ret(dtype) + impl.get_runtime().prog.decl_ret(dtype)