Skip to content

Commit

Permalink
[llvm] Support nested struct with matrix return value on real function (
Browse files Browse the repository at this point in the history
#6734)

Issue: #602 #6590
Also fixed a bug where the scalarize pass was not run on real functions;
thanks to @jim19930609.
### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Nov 28, 2022
1 parent 8ee95c9 commit 52a90c9
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 31 deletions.
9 changes: 3 additions & 6 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,8 @@ def build_FunctionDef(ctx, node):
def transform_as_kernel():
# Treat return type
if node.returns is not None:
if isinstance(ctx.func.return_type, StructType):
for tp in ctx.func.return_type.members.values():
kernel_arguments.decl_ret(tp)
else:
kernel_arguments.decl_ret(ctx.func.return_type)
kernel_arguments.decl_ret(ctx.func.return_type,
ctx.is_real_function)

for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation,
Expand Down Expand Up @@ -756,7 +753,7 @@ def build_Return(ctx, node):
values = node.value.ptr
assert isinstance(values, Struct)
ctx.ast_builder.create_kernel_exprgroup_return(
expr.make_expr_group(values._members))
expr.make_expr_group(expr._get_flattened_ptrs(values)))
else:
raise TaichiSyntaxError(
"The return type is not supported now!")
Expand Down
13 changes: 13 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,17 @@ def make_expr_group(*exprs, real_func_arg=False):
return expr_group


def _get_flattened_ptrs(val):
    """Recursively collect the flattened expression pointers of ``val``.

    Taichi-class values (structs/matrices) are expanded member by member.
    When ``real_matrix`` is enabled, a tensor-typed ``Expr`` is expanded
    into the pointers of its scalar elements. Anything else yields a
    single ``Expr`` pointer.
    """
    if is_taichi_class(val):
        result = []
        for member in val._members:
            result += _get_flattened_ptrs(member)
        return result
    is_tensor_expr = isinstance(val, Expr) and val.ptr.is_tensor()
    if impl.current_cfg().real_matrix and is_tensor_expr:
        builder = impl.get_runtime().prog.current_ast_builder()
        return builder.expand_expr([val.ptr])
    return [Expr(val).ptr]


__all__ = []
13 changes: 11 additions & 2 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from taichi.lang.enums import Layout
from taichi.lang.expr import Expr
from taichi.lang.matrix import Matrix, MatrixType, Vector, VectorType
from taichi.lang.struct import StructType
from taichi.lang.util import cook_dtype
from taichi.types.primitive_types import RefType, f32, u64

Expand Down Expand Up @@ -102,10 +103,18 @@ def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod):
channel_format, lod), num_dimensions)


def decl_ret(dtype, real_func=False):
    """Declare a kernel/function return value of type ``dtype``.

    Struct types are flattened recursively, one declaration per leaf
    member. Matrix types are flattened into ``n * m`` scalar returns for
    real functions, or declared as a single tensor-typed return otherwise.

    Args:
        dtype: a ``StructType``, ``MatrixType``, or primitive data type.
        real_func (bool): whether this declaration is for a real function
            (real functions return matrix elements as separate scalars).
    """
    if isinstance(dtype, StructType):
        # Flatten struct members; nested structs/matrices are handled by
        # the recursive calls.
        for member in dtype.members.values():
            decl_ret(member, real_func)
        return
    if isinstance(dtype, MatrixType):
        if real_func:
            # Real functions return each matrix element individually.
            # Propagate real_func for consistency with the struct branch.
            for _ in range(dtype.n * dtype.m):
                decl_ret(dtype.dtype, real_func)
            return
        dtype = _ti_core.get_type_factory_instance().get_tensor_type(
            [dtype.n, dtype.m], dtype.dtype)
    else:
        dtype = cook_dtype(dtype)
    impl.get_runtime().prog.decl_ret(dtype)
6 changes: 6 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,12 @@ def __call__(self, *args):
# type cast
return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim))

def from_real_func_ret(self, func_ret, ret_index=0):
    """Build a matrix from flattened real-function return values.

    Reads ``m * n`` scalar elements of ``func_ret`` starting at
    ``ret_index`` and returns ``(matrix, next_ret_index)``.
    """
    count = self.m * self.n
    entries = []
    for offset in range(ret_index, ret_index + count):
        element = ti_python_core.make_get_element_expr(func_ret.ptr, offset)
        entries.append(expr.Expr(element))
    return self(entries), ret_index + count

def cast(self, mat):
if in_python_scope():
return Matrix([[
Expand Down
17 changes: 2 additions & 15 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, matrix
from taichi.lang.field import BitpackedFields, Field
from taichi.lang.util import get_traceback, is_taichi_class
from taichi.lang.util import get_traceback


class SNode:
Expand Down Expand Up @@ -376,19 +376,6 @@ def rescale_index(a, b, I):
return matrix.Vector(entries)


def _get_flattened_ptrs(val):
if is_taichi_class(val):
ptrs = []
for item in val._members:
ptrs.extend(_get_flattened_ptrs(item))
return ptrs
if impl.current_cfg().real_matrix and isinstance(
val, expr.Expr) and val.ptr.is_tensor():
return impl.get_runtime().prog.current_ast_builder().expand_expr(
[val.ptr])
return [expr.Expr(val).ptr]


def append(node, indices, val):
"""Append a value `val` to a SNode `node` at index `indices`.
Expand All @@ -397,7 +384,7 @@ def append(node, indices, val):
indices (Union[int, :class:`~taichi.Vector`]): the indices to visit.
val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now.
"""
ptrs = _get_flattened_ptrs(val)
ptrs = expr._get_flattened_ptrs(val)
append_expr = expr.Expr(_ti_core.expr_snode_append(
node._snode.ptr, expr.make_expr_group(indices), ptrs),
tb=impl.get_runtime().get_current_src_info())
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace irpass {

void re_id(IRNode *root);
void flag_access(IRNode *root);
void scalarize(IRNode *root);
void scalarize(IRNode *root, const CompileConfig &config);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
Expand Down
10 changes: 9 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void compile_to_offloads(IRNode *ir,
}

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir);
irpass::scalarize(ir, config);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
Expand Down Expand Up @@ -338,6 +338,14 @@ void compile_function(IRNode *ir,
print("Lowered");
}

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir, config);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
print("Scalarized");
}

irpass::lower_access(ir, config, {{}, true});
print("Access lowered");
irpass::analysis::verify(ir);
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,10 @@ class ScalarizePointers : public BasicStmtVisitor {

namespace irpass {

void scalarize(IRNode *root) {
void scalarize(IRNode *root, const CompileConfig &config) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);
if (!root->get_kernel()->program->this_thread_config().dynamic_index) {
if (!config.dynamic_index) {
ScalarizePointers scalarize_pointers_pass(root);
}
}
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/transforms/scalarize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST(Scalarize, ScalarizeGlobalStore) {

block->push_back<GlobalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::lower_matrix_ptr(block.get());
irpass::die(block.get());

Expand Down Expand Up @@ -102,7 +102,7 @@ TEST(Scalarize, ScalarizeGlobalLoad) {
// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::lower_matrix_ptr(block.get());
irpass::die(block.get());

Expand Down Expand Up @@ -163,7 +163,7 @@ TEST(Scalarize, ScalarizeLocalStore) {
// LocalStoreStmt survives irpass::die()
block->push_back<LocalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::die(block.get());

EXPECT_EQ(block->size(), 2 /*const*/ + 4 /*alloca*/ + 4 /*store*/);
Expand Down Expand Up @@ -211,7 +211,7 @@ TEST(Scalarize, ScalarizeLocalLoad) {
// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::die(block.get());

EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 4 /*store*/);
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,31 @@ def foo() -> ti.f64:
return a.a * a.b

assert foo() == pytest.approx(123 * 1.2345e300)


def _test_real_func_struct_ret_with_matrix():
    """Check that a real function can return a nested struct containing a matrix."""
    # s1 nests s0, which holds a vec3 matrix and an i16 scalar.
    s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16)
    s1 = ti.types.struct(a=ti.f32, b=s0)

    @ti.experimental.real_func
    def bar() -> s1:
        # 65537 wraps around in i16 to 1, hence the 105.2 expectation below.
        return s1(a=1, b=s0(a=ti.Vector([100, 0.2, 3], dt=ti.f32), b=65537))

    @ti.kernel
    def foo() -> ti.f32:
        s = bar()
        # Sum every leaf: 1 + 100 + 0.2 + 3 + 1 == 105.2
        return s.a + s.b.a[0] + s.b.a[1] + s.b.a[2] + s.b.b

    assert foo() == pytest.approx(105.2)


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_struct_ret_with_matrix():
    # Runs the shared check with the default matrix configuration.
    _test_real_func_struct_ret_with_matrix()


@test_utils.test(arch=[ti.cpu, ti.cuda],
                 real_matrix=True,
                 real_matrix_scalarize=True)
def test_real_func_struct_ret_with_matrix_real_matrix():
    # Same check with real_matrix + real_matrix_scalarize enabled.
    _test_real_func_struct_ret_with_matrix()

0 comments on commit 52a90c9

Please sign in to comment.