Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm] Support nested struct with matrix return value on real function #6734

Merged
merged 7 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,8 @@ def build_FunctionDef(ctx, node):
def transform_as_kernel():
# Treat return type
if node.returns is not None:
if isinstance(ctx.func.return_type, StructType):
for tp in ctx.func.return_type.members.values():
kernel_arguments.decl_ret(tp)
else:
kernel_arguments.decl_ret(ctx.func.return_type)
kernel_arguments.decl_ret(ctx.func.return_type,
ctx.is_real_function)

for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation,
Expand Down Expand Up @@ -756,7 +753,7 @@ def build_Return(ctx, node):
values = node.value.ptr
assert isinstance(values, Struct)
ctx.ast_builder.create_kernel_exprgroup_return(
expr.make_expr_group(values._members))
expr.make_expr_group(expr._get_flattened_ptrs(values)))
else:
raise TaichiSyntaxError(
"The return type is not supported now!")
Expand Down
13 changes: 13 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,17 @@ def make_expr_group(*exprs, real_func_arg=False):
return expr_group


def _get_flattened_ptrs(val):
    """Recursively flatten *val* into a list of expression pointers.

    Taichi composite values (structs/matrices) are walked member by
    member.  When ``real_matrix`` is enabled, a tensor-typed expression
    is expanded into its scalar component pointers via the AST builder;
    everything else contributes a single pointer.
    """
    if is_taichi_class(val):
        flattened = []
        for member in val._members:
            flattened.extend(_get_flattened_ptrs(member))
        return flattened
    is_real_tensor = (impl.current_cfg().real_matrix
                      and isinstance(val, Expr) and val.ptr.is_tensor())
    if is_real_tensor:
        builder = impl.get_runtime().prog.current_ast_builder()
        return builder.expand_expr([val.ptr])
    return [Expr(val).ptr]


__all__ = []
13 changes: 11 additions & 2 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from taichi.lang.enums import Layout
from taichi.lang.expr import Expr
from taichi.lang.matrix import Matrix, MatrixType, Vector, VectorType
from taichi.lang.struct import StructType
from taichi.lang.util import cook_dtype
from taichi.types.primitive_types import RefType, f32, u64

Expand Down Expand Up @@ -102,10 +103,18 @@ def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod):
channel_format, lod), num_dimensions)


def decl_ret(dtype, real_func=False):
    """Declare the return slot(s) for a kernel or real function.

    Struct returns are declared member by member.  Matrix returns are
    declared element by element for real functions, and as a single
    tensor type otherwise.  Scalars are cooked into a core dtype first.
    """
    if isinstance(dtype, StructType):
        # Recurse into each member so nested structs flatten fully.
        for member_type in dtype.members.values():
            decl_ret(member_type, real_func)
        return
    if isinstance(dtype, MatrixType):
        if real_func:
            # Real functions return one scalar slot per matrix element.
            for _ in range(dtype.n * dtype.m):
                decl_ret(dtype.dtype)
            return
        # Kernels return the matrix as a single tensor-typed value.
        dtype = _ti_core.get_type_factory_instance().get_tensor_type(
            [dtype.n, dtype.m], dtype.dtype)
    else:
        dtype = cook_dtype(dtype)
    impl.get_runtime().prog.decl_ret(dtype)
6 changes: 6 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,12 @@ def __call__(self, *args):
# type cast
return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim))

def from_real_func_ret(self, func_ret, ret_index=0):
    """Rebuild a matrix from the flattened scalar returns of a real function.

    Reads ``self.m * self.n`` consecutive elements of *func_ret*
    starting at *ret_index* and returns a tuple of the constructed
    matrix and the index of the first unread element, so callers can
    chain reads for several return values.
    """
    count = self.m * self.n
    entries = [
        expr.Expr(ti_python_core.make_get_element_expr(func_ret.ptr, idx))
        for idx in range(ret_index, ret_index + count)
    ]
    return self(entries), ret_index + count

def cast(self, mat):
if in_python_scope():
return Matrix([[
Expand Down
17 changes: 2 additions & 15 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, matrix
from taichi.lang.field import BitpackedFields, Field
from taichi.lang.util import get_traceback, is_taichi_class
from taichi.lang.util import get_traceback


class SNode:
Expand Down Expand Up @@ -363,19 +363,6 @@ def rescale_index(a, b, I):
return matrix.Vector(entries)


def _get_flattened_ptrs(val):
    """Return the flattened list of expression pointers backing *val*.

    Composite taichi values are traversed recursively; with
    ``real_matrix`` enabled, tensor-typed expressions are expanded into
    their scalar components through the current AST builder.
    """
    if is_taichi_class(val):
        ptr_list = []
        for member in val._members:
            ptr_list += _get_flattened_ptrs(member)
        return ptr_list
    if (impl.current_cfg().real_matrix and isinstance(val, expr.Expr)
            and val.ptr.is_tensor()):
        return impl.get_runtime().prog.current_ast_builder().expand_expr(
            [val.ptr])
    return [expr.Expr(val).ptr]


def append(node, indices, val):
"""Append a value `val` to a SNode `node` at index `indices`.

Expand All @@ -384,7 +371,7 @@ def append(node, indices, val):
indices (Union[int, :class:`~taichi.Vector`]): the indices to visit.
val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now.
"""
ptrs = _get_flattened_ptrs(val)
ptrs = expr._get_flattened_ptrs(val)
append_expr = expr.Expr(_ti_core.expr_snode_append(
node._snode.ptr, expr.make_expr_group(indices), ptrs),
tb=impl.get_runtime().get_current_src_info())
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace irpass {

void re_id(IRNode *root);
void flag_access(IRNode *root);
void scalarize(IRNode *root);
void scalarize(IRNode *root, const CompileConfig &config);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
Expand Down
10 changes: 9 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void compile_to_offloads(IRNode *ir,
}

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir);
irpass::scalarize(ir, config);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
Expand Down Expand Up @@ -338,6 +338,14 @@ void compile_function(IRNode *ir,
print("Lowered");
}

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir, config);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
print("Scalarized");
}

irpass::lower_access(ir, config, {{}, true});
print("Access lowered");
irpass::analysis::verify(ir);
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,10 @@ class ScalarizePointers : public BasicStmtVisitor {

namespace irpass {

void scalarize(IRNode *root) {
void scalarize(IRNode *root, const CompileConfig &config) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);
if (!root->get_kernel()->program->this_thread_config().dynamic_index) {
if (!config.dynamic_index) {
ScalarizePointers scalarize_pointers_pass(root);
}
}
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/transforms/scalarize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST(Scalarize, ScalarizeGlobalStore) {

block->push_back<GlobalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::lower_matrix_ptr(block.get());
irpass::die(block.get());

Expand Down Expand Up @@ -102,7 +102,7 @@ TEST(Scalarize, ScalarizeGlobalLoad) {
// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::lower_matrix_ptr(block.get());
irpass::die(block.get());

Expand Down Expand Up @@ -163,7 +163,7 @@ TEST(Scalarize, ScalarizeLocalStore) {
// LocalStoreStmt survives irpass::die()
block->push_back<LocalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::die(block.get());

EXPECT_EQ(block->size(), 2 /*const*/ + 4 /*alloca*/ + 4 /*store*/);
Expand Down Expand Up @@ -211,7 +211,7 @@ TEST(Scalarize, ScalarizeLocalLoad) {
// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(), test_prog.prog()->this_thread_config());
irpass::die(block.get());

EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 4 /*store*/);
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,31 @@ def foo() -> ti.f64:
return a.a * a.b

assert foo() == pytest.approx(123 * 1.2345e300)


def _test_real_func_struct_ret_with_matrix():
    """Shared body: real function returning a nested struct with a matrix."""
    inner_t = ti.types.struct(a=ti.math.vec3, b=ti.i16)
    outer_t = ti.types.struct(a=ti.f32, b=inner_t)

    @ti.experimental.real_func
    def bar() -> outer_t:
        vec = ti.Vector([100, 0.2, 3], dt=ti.f32)
        return outer_t(a=1, b=inner_t(a=vec, b=65537))

    @ti.kernel
    def foo() -> ti.f32:
        ret = bar()
        return ret.a + ret.b.a[0] + ret.b.a[1] + ret.b.a[2] + ret.b.b

    # NOTE(review): 65537 presumably wraps to 1 in i16, which is
    # consistent with the expected sum 1 + 100 + 0.2 + 3 + 1 = 105.2.
    assert foo() == pytest.approx(105.2)


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_struct_ret_with_matrix():
    # Default config: matrices handled as whole tensor-typed values.
    _test_real_func_struct_ret_with_matrix()


@test_utils.test(arch=[ti.cpu, ti.cuda],
                 real_matrix=True,
                 real_matrix_scalarize=True)
def test_real_func_struct_ret_with_matrix_real_matrix():
    # Same scenario with real_matrix + scalarization enabled, exercising
    # the scalarize pass on real-function return values.
    _test_real_func_struct_ret_with_matrix()