Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Migrate TensorType expansion for subscription indices from Python to Frontend IR #6942

Merged
merged 25 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
db2163e
[lang] Migrate TensorType expansion for subscription indices from Pyt…
jim19930609 Dec 21, 2022
6d7824c
Bug fix
jim19930609 Dec 22, 2022
a84ec42
Bug fix
jim19930609 Dec 22, 2022
fdb92f5
Bug fix
jim19930609 Dec 22, 2022
07be464
Bug fix
jim19930609 Dec 22, 2022
7808e03
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_ref…
jim19930609 Dec 23, 2022
4c5ab88
Bug fix
jim19930609 Dec 23, 2022
c357f4d
Bug fix
jim19930609 Dec 23, 2022
397001d
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_ref…
jim19930609 Dec 23, 2022
847989b
Bug fix
jim19930609 Dec 23, 2022
1411328
Code adjustment
jim19930609 Dec 27, 2022
88c0bf8
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_ref…
jim19930609 Dec 28, 2022
f55e981
Bug fix
jim19930609 Dec 28, 2022
4abc1af
Bug fix
jim19930609 Dec 28, 2022
3973835
Code adjustment
jim19930609 Dec 28, 2022
db501e6
Code adjustment
jim19930609 Dec 28, 2022
e6822d2
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_ref…
jim19930609 Dec 29, 2022
1df4316
Bug fix
jim19930609 Dec 29, 2022
bdae239
Bug fix
jim19930609 Dec 29, 2022
72b21b9
Code adjustment
jim19930609 Jan 3, 2023
061ff41
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_ref…
jim19930609 Jan 3, 2023
726c5c0
Code adjustment
jim19930609 Jan 3, 2023
9e21cdf
Bug fix
jim19930609 Jan 4, 2023
6223005
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_ref…
jim19930609 Jan 4, 2023
1e2b48c
Code adjustment
jim19930609 Jan 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/taichi/lang/_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def _get_entries(mat):
if isinstance(mat, Matrix):
return mat.entries
assert isinstance(mat, Expr) and mat.is_tensor()
return impl.get_runtime().prog.current_ast_builder().expand_expr([mat.ptr])
return impl.get_runtime().prog.current_ast_builder().expand_exprs(
[mat.ptr])


class TextureSampler:
Expand Down
7 changes: 5 additions & 2 deletions python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,18 @@ def __init__(self, arr, indices_first):

@taichi_scope
def subscript(self, i, j):
ast_builder = impl.get_runtime().prog.current_ast_builder()

indices_second = (i, ) if len(self.arr.element_shape()) == 1 else (i,
j)
if self.arr.layout() == Layout.SOA:
indices = indices_second + self.indices_first
else:
indices = self.indices_first + indices_second
return Expr(
_ti_core.subscript(self.arr.ptr, make_expr_group(*indices),
impl.get_runtime().get_current_src_info()))
ast_builder.expr_subscript(
self.arr.ptr, make_expr_group(*indices),
impl.get_runtime().get_current_src_info()))


__all__ = []
33 changes: 18 additions & 15 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ReturnStatus)
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import (TaichiIndexError, TaichiSyntaxError,
TaichiTypeError)
TaichiTypeError, handle_exception_from_cpp)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
Expand Down Expand Up @@ -156,7 +156,7 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign):
raise ValueError(
'Matrices with more than one columns cannot be unpacked')

values = ctx.ast_builder.expand_expr([values.ptr])
values = ctx.ast_builder.expand_exprs([values.ptr])
if len(values) == 1:
values = values[0]

Expand Down Expand Up @@ -302,7 +302,7 @@ def process_generators(ctx, node, now_comp, func, result):
if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor():
shape = _iter.ptr.get_shape()
flattened = [
Expr(x) for x in ctx.ast_builder.expand_expr([_iter.ptr])
Expr(x) for x in ctx.ast_builder.expand_exprs([_iter.ptr])
]
_iter = reshape_list(flattened, shape)

Expand Down Expand Up @@ -514,7 +514,7 @@ def build_Call(ctx, node):
# Expand Expr with Matrix-type return into list of Exprs
arg_list = [
Expr(x)
for x in ctx.ast_builder.expand_expr([arg_list.ptr])
for x in ctx.ast_builder.expand_exprs([arg_list.ptr])
]

for i in arg_list:
Expand Down Expand Up @@ -730,7 +730,7 @@ def build_Return(ctx, node):
elif isinstance(ctx.func.return_type, MatrixType):
values = node.value.ptr
if isinstance(values, Expr) and values.ptr.is_tensor():
values = ctx.ast_builder.expand_expr([values.ptr])
values = ctx.ast_builder.expand_exprs([values.ptr])
else:
assert isinstance(values, Matrix)
values = itertools.chain.from_iterable(values.to_list()) if\
Expand Down Expand Up @@ -819,12 +819,15 @@ def build_Attribute(ctx, node):
# we continue to process it as a normal attribute node.
try:
build_stmt(ctx, node.value)
except TaichiIndexError as e:
node.value.ptr = None
if ASTTransformer.build_attribute_if_is_dynamic_snode_method(
ctx, node):
return node.ptr
except Exception as e:
e = handle_exception_from_cpp(e)
if isinstance(e, TaichiIndexError):
node.value.ptr = None
if ASTTransformer.build_attribute_if_is_dynamic_snode_method(
ctx, node):
return node.ptr
raise e

if ASTTransformer.build_attribute_if_is_dynamic_snode_method(
ctx, node):
return node.ptr
Expand All @@ -837,11 +840,11 @@ def build_Attribute(ctx, node):
node.attr)
attr_len = len(node.attr)
if attr_len == 1:
node.ptr = Expr(
_ti_core.subscript(
node.value.ptr.ptr,
make_expr_group(keygroup.index(node.attr)),
impl.get_runtime().get_current_src_info()))
node.ptr = Expr(impl.get_runtime(
).prog.current_ast_builder().expr_subscript(
node.value.ptr.ptr,
make_expr_group(keygroup.index(node.attr)),
impl.get_runtime().get_current_src_info()))
else:
node.ptr = Expr(
_ti_core.subscript_with_multiple_indices(
Expand Down
2 changes: 2 additions & 0 deletions python/taichi/lang/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def handle_exception_from_cpp(exc):
return TaichiTypeError(str(exc))
if isinstance(exc, core.TaichiSyntaxError):
return TaichiSyntaxError(str(exc))
if isinstance(exc, core.TaichiIndexError):
return TaichiIndexError(str(exc))
if isinstance(exc, core.TaichiAssertionError):
return TaichiAssertionError(str(exc))
return exc
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _get_flattened_ptrs(val):
ptrs.extend(_get_flattened_ptrs(item))
return ptrs
if isinstance(val, Expr) and val.ptr.is_tensor():
return impl.get_runtime().prog.current_ast_builder().expand_expr(
return impl.get_runtime().prog.current_ast_builder().expand_exprs(
[val.ptr])
return [Expr(val).ptr]

Expand Down
55 changes: 22 additions & 33 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from taichi.lang._texture import RWTextureAccessor
from taichi.lang.any_array import AnyArray
from taichi.lang.enums import SNodeGradType
from taichi.lang.exception import (TaichiCompilationError, TaichiIndexError,
TaichiRuntimeError, TaichiSyntaxError,
TaichiTypeError)
from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError,
TaichiSyntaxError, TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
Expand Down Expand Up @@ -132,6 +131,7 @@ def check_validity(x):

@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
ast_builder = get_runtime().prog.current_ast_builder()
# Directly evaluate in Python for non-Taichi types
if not isinstance(
value,
Expand All @@ -150,9 +150,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
elif isinstance(_index, slice):
ind = [_index]
has_slice = True
elif isinstance(_index, Expr) and _index.is_tensor():
# Expand Expr with TensorType return
ind = [Expr(e) for e in ast_builder.expand_expr([_index.ptr])]
else:
ind = [_index]
flattened_indices += ind
Expand All @@ -167,7 +164,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
f"The type {type(value)} do not support index of slice type")
else:
indices_expr_group = make_expr_group(*indices)
index_dim = indices_expr_group.size()

if isinstance(value, SharedArray):
return value.subscript(*indices)
Expand All @@ -178,13 +174,13 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
if isinstance(value,
(MeshReorderedScalarFieldProxy,
MeshReorderedMatrixFieldProxy)) and not skip_reordered:
assert index_dim == 1

reordered_index = tuple([
Expr(
_ti_core.get_index_conversion(value.mesh_ptr,
value.element_type,
Expr(indices[0]).ptr,
ConvType.g2r))
ast_builder.mesh_index_conversion(value.mesh_ptr,
value.element_type,
Expr(indices[0]).ptr,
ConvType.g2r))
])
return subscript(ast_builder,
value,
Expand All @@ -203,29 +199,26 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
raise RuntimeError(
f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
)
field_dim = snode.num_active_indices()
if field_dim != index_dim:
raise TaichiIndexError(
f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
)

if isinstance(value, MatrixField):
return make_index_expr(value.ptr, indices_expr_group)
return Expr(
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
ast_builder.expr_subscript(
value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, StructField):
entries = {
k: subscript(ast_builder, v, *indices)
for k, v in value._items
}
entries['__struct_methods'] = value.struct_methods
return _IntermediateStruct(entries)
return make_index_expr(_var, indices_expr_group)
return Expr(
ast_builder.expr_subscript(_var, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, AnyArray):
dim = _ti_core.get_external_tensor_dim(value.ptr)
element_dim = len(value.element_shape())
if dim != index_dim + element_dim:
raise IndexError(
f'Field with dim {dim - element_dim} accessed with indices of dim {index_dim}'
)
return make_index_expr(value.ptr, indices_expr_group)
return Expr(
ast_builder.expr_subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
assert isinstance(value, Expr)
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
Expand All @@ -249,18 +242,14 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
make_expr_group(i, j) for i in indices[0] for j in indices[1]
]
return_shape = (len(indices[0]), len(indices[1]))

return Expr(
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
get_runtime().get_current_src_info()))
return make_index_expr(value.ptr, indices_expr_group)


@taichi_scope
def make_index_expr(_var, indices_expr_group):
return Expr(
_ti_core.subscript(_var, indices_expr_group,
get_runtime().get_current_src_info()))
ast_builder.expr_subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))


class SrcInfoGuard:
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def func_call_rvalue(self, key, args):
impl.Expr) and args[i].ptr.is_tensor():
non_template_args.extend([
Expr(x) for x in impl.get_runtime().prog.
current_ast_builder().expand_expr([args[i].ptr])
current_ast_builder().expand_exprs([args[i].ptr])
])
else:
non_template_args.append(args[i])
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@ def __call__(self, *args):
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_expr([x.ptr])
current_ast_builder().expand_exprs([x.ptr])
]
elif isinstance(x, Matrix):
entries += x.entries
Expand Down Expand Up @@ -1615,7 +1615,7 @@ def __call__(self, *args):
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_expr([x.ptr])
current_ast_builder().expand_exprs([x.ptr])
]
else:
entries.append(x)
Expand Down
16 changes: 10 additions & 6 deletions python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,17 @@ def _TetMesh():
class MeshElementFieldProxy:
def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
entry_expr: impl.Expr):
ast_builder = impl.get_runtime().prog.current_ast_builder()

self.mesh = mesh
self.element_type = element_type
self.entry_expr = entry_expr

element_field = self.mesh.fields[self.element_type]
for key, attr in element_field.field_dict.items():

global_entry_expr = impl.Expr(
_ti_core.get_index_conversion(
ast_builder.mesh_index_conversion(
self.mesh.mesh_ptr, element_type, entry_expr,
ConvType.l2r if element_field.attr_dict[key].reorder else
ConvType.l2g)) # transform index space
Expand All @@ -622,7 +625,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
setattr(
self, key,
impl.Expr(
_ti_core.subscript(
ast_builder.expr_subscript(
attr.ptr, global_entry_expr_group,
impl.get_runtime().get_current_src_info())))
elif isinstance(attr, StructField):
Expand All @@ -633,7 +636,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
setattr(
self, key,
impl.Expr(
_ti_core.subscript(
ast_builder.expr_subscript(
var, global_entry_expr_group,
impl.get_runtime().get_current_src_info())))

Expand All @@ -650,10 +653,11 @@ def ptr(self):

@property
def id(self): # return the global non-reordered index
ast_builder = impl.get_runtime().prog.current_ast_builder()
l2g_expr = impl.Expr(
_ti_core.get_index_conversion(self.mesh.mesh_ptr,
self.element_type, self.entry_expr,
ConvType.l2g))
ast_builder.mesh_index_conversion(self.mesh.mesh_ptr,
self.element_type,
self.entry_expr, ConvType.l2g))
return l2g_expr


Expand Down
7 changes: 5 additions & 2 deletions python/taichi/lang/simt/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,8 @@ def __init__(self, shape, dtype):

@taichi_scope
def subscript(self, *indices):
return impl.make_index_expr(self.shared_array_proxy,
make_expr_group(*indices))
ast_builder = impl.get_runtime().prog.current_ast_builder()
return impl.Expr(
ast_builder.expr_subscript(
self.shared_array_proxy, make_expr_group(*indices),
impl.get_runtime().get_current_src_info()))
4 changes: 4 additions & 0 deletions taichi/common/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class TaichiSyntaxError : public TaichiExceptionImpl {
using TaichiExceptionImpl::TaichiExceptionImpl;
};

class TaichiIndexError : public TaichiExceptionImpl {
using TaichiExceptionImpl::TaichiExceptionImpl;
};

class TaichiRuntimeError : public TaichiExceptionImpl {
using TaichiExceptionImpl::TaichiExceptionImpl;
};
Expand Down
6 changes: 0 additions & 6 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ Expr bit_cast(const Expr &input, DataType dt) {
return Expr::make<UnaryOpExpression>(UnaryOpType::cast_bits, input, dt);
}

Expr Expr::operator[](const ExprGroup &indices) const {
TI_ASSERT(is<FieldExpression>() || is<MatrixFieldExpression>() ||
is<ExternalTensorExpression>() || is_tensor(expr->ret_type));
return Expr::make<IndexExpression>(*this, indices);
}

Expr &Expr::operator=(const Expr &o) {
set(o);
return *this;
Expand Down
2 changes: 0 additions & 2 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ class Expr {
// std::variant<Expr, std::string> in FrontendPrintStmt.
Expr &operator=(const Expr &o);

Expr operator[](const ExprGroup &indices) const;

template <typename T, typename... Args>
static Expr make(Args &&...args) {
return Expr(std::make_shared<T>(std::forward<Args>(args)...));
Expand Down
Loading