diff --git a/python/taichi/lang/_texture.py b/python/taichi/lang/_texture.py index 9e1c33b467e29..ceb0e4be33950 100644 --- a/python/taichi/lang/_texture.py +++ b/python/taichi/lang/_texture.py @@ -13,7 +13,8 @@ def _get_entries(mat): if isinstance(mat, Matrix): return mat.entries assert isinstance(mat, Expr) and mat.is_tensor() - return impl.get_runtime().prog.current_ast_builder().expand_expr([mat.ptr]) + return impl.get_runtime().prog.current_ast_builder().expand_exprs( + [mat.ptr]) class TextureSampler: diff --git a/python/taichi/lang/any_array.py b/python/taichi/lang/any_array.py index c7c44ba20e244..4b980797cd243 100644 --- a/python/taichi/lang/any_array.py +++ b/python/taichi/lang/any_array.py @@ -78,6 +78,8 @@ def __init__(self, arr, indices_first): @taichi_scope def subscript(self, i, j): + ast_builder = impl.get_runtime().prog.current_ast_builder() + indices_second = (i, ) if len(self.arr.element_shape()) == 1 else (i, j) if self.arr.layout() == Layout.SOA: @@ -85,8 +87,9 @@ def subscript(self, i, j): else: indices = self.indices_first + indices_second return Expr( - _ti_core.subscript(self.arr.ptr, make_expr_group(*indices), - impl.get_runtime().get_current_src_info())) + ast_builder.expr_subscript( + self.arr.ptr, make_expr_group(*indices), + impl.get_runtime().get_current_src_info())) __all__ = [] diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 221ef10e7fdba..8e7d13d5215bc 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -15,7 +15,7 @@ ReturnStatus) from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import (TaichiIndexError, TaichiSyntaxError, - TaichiTypeError) + TaichiTypeError, handle_exception_from_cpp) from taichi.lang.expr import Expr, make_expr_group from taichi.lang.field import Field from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector @@ -156,7 +156,7 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): raise ValueError( 'Matrices with more than one columns cannot be unpacked') - values = ctx.ast_builder.expand_expr([values.ptr]) + values = ctx.ast_builder.expand_exprs([values.ptr]) if len(values) == 1: values = values[0] @@ -302,7 +302,7 @@ def process_generators(ctx, node, now_comp, func, result): if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor(): shape = _iter.ptr.get_shape() flattened = [ - Expr(x) for x in ctx.ast_builder.expand_expr([_iter.ptr]) + Expr(x) for x in ctx.ast_builder.expand_exprs([_iter.ptr]) ] _iter = reshape_list(flattened, shape) @@ -514,7 +514,7 @@ def build_Call(ctx, node): # Expand Expr with Matrix-type return into list of Exprs arg_list = [ Expr(x) - for x in ctx.ast_builder.expand_expr([arg_list.ptr]) + for x in ctx.ast_builder.expand_exprs([arg_list.ptr]) ] for i in arg_list: @@ -730,7 +730,7 @@ def build_Return(ctx, node): elif isinstance(ctx.func.return_type, MatrixType): values = node.value.ptr if isinstance(values, Expr) and values.ptr.is_tensor(): - values = ctx.ast_builder.expand_expr([values.ptr]) + values = ctx.ast_builder.expand_exprs([values.ptr]) else: assert isinstance(values, Matrix) values = itertools.chain.from_iterable(values.to_list()) if\ @@ -819,12 +819,15 @@ def build_Attribute(ctx, node): # we continue to process it as a normal attribute node. 
try: build_stmt(ctx, node.value) - except TaichiIndexError as e: - node.value.ptr = None - if ASTTransformer.build_attribute_if_is_dynamic_snode_method( - ctx, node): - return node.ptr + except Exception as e: + e = handle_exception_from_cpp(e) + if isinstance(e, TaichiIndexError): + node.value.ptr = None + if ASTTransformer.build_attribute_if_is_dynamic_snode_method( + ctx, node): + return node.ptr raise e + if ASTTransformer.build_attribute_if_is_dynamic_snode_method( ctx, node): return node.ptr @@ -837,11 +840,11 @@ def build_Attribute(ctx, node): node.attr) attr_len = len(node.attr) if attr_len == 1: - node.ptr = Expr( - _ti_core.subscript( - node.value.ptr.ptr, - make_expr_group(keygroup.index(node.attr)), - impl.get_runtime().get_current_src_info())) + node.ptr = Expr(impl.get_runtime( + ).prog.current_ast_builder().expr_subscript( + node.value.ptr.ptr, + make_expr_group(keygroup.index(node.attr)), + impl.get_runtime().get_current_src_info())) else: node.ptr = Expr( _ti_core.subscript_with_multiple_indices( diff --git a/python/taichi/lang/exception.py b/python/taichi/lang/exception.py index 319ab0f4b5102..5d66bf47ad5be 100644 --- a/python/taichi/lang/exception.py +++ b/python/taichi/lang/exception.py @@ -56,6 +56,8 @@ def handle_exception_from_cpp(exc): return TaichiTypeError(str(exc)) if isinstance(exc, core.TaichiSyntaxError): return TaichiSyntaxError(str(exc)) + if isinstance(exc, core.TaichiIndexError): + return TaichiIndexError(str(exc)) if isinstance(exc, core.TaichiAssertionError): return TaichiAssertionError(str(exc)) return exc diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 1fe288585e1ed..29fcf8a436630 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -170,7 +170,7 @@ def _get_flattened_ptrs(val): ptrs.extend(_get_flattened_ptrs(item)) return ptrs if isinstance(val, Expr) and val.ptr.is_tensor(): - return impl.get_runtime().prog.current_ast_builder().expand_expr( + return impl.get_runtime().prog.current_ast_builder().expand_exprs( [val.ptr]) return [Expr(val).ptr] diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index ca99c9fa22465..330519c0e58d9 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -9,9 +9,8 @@ from taichi.lang._texture import RWTextureAccessor from taichi.lang.any_array import AnyArray from taichi.lang.enums import SNodeGradType -from taichi.lang.exception import (TaichiCompilationError, TaichiIndexError, - TaichiRuntimeError, TaichiSyntaxError, - TaichiTypeError) +from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError, + TaichiSyntaxError, TaichiTypeError) from taichi.lang.expr import Expr, make_expr_group from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy @@ -132,6 +131,7 @@ def check_validity(x): @taichi_scope def subscript(ast_builder, value, *_indices, skip_reordered=False): + ast_builder = get_runtime().prog.current_ast_builder() # Directly evaluate in Python for non-Taichi types if not isinstance( value, @@ -150,9 +150,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): elif isinstance(_index, slice): ind = [_index] has_slice = True - elif isinstance(_index, Expr) and _index.is_tensor(): - # Expand Expr with TensorType return - ind = [Expr(e) for e in ast_builder.expand_expr([_index.ptr])] else: ind = [_index] flattened_indices += ind @@ -167,7 +164,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): f"The type {type(value)} do not 
support index of slice type") else: indices_expr_group = make_expr_group(*indices) - index_dim = indices_expr_group.size() if isinstance(value, SharedArray): return value.subscript(*indices) @@ -178,13 +174,13 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): if isinstance(value, (MeshReorderedScalarFieldProxy, MeshReorderedMatrixFieldProxy)) and not skip_reordered: - assert index_dim == 1 + reordered_index = tuple([ Expr( - _ti_core.get_index_conversion(value.mesh_ptr, - value.element_type, - Expr(indices[0]).ptr, - ConvType.g2r)) + ast_builder.mesh_index_conversion(value.mesh_ptr, + value.element_type, + Expr(indices[0]).ptr, + ConvType.g2r)) ]) return subscript(ast_builder, value, @@ -203,13 +199,12 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): raise RuntimeError( f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`" ) - field_dim = snode.num_active_indices() - if field_dim != index_dim: - raise TaichiIndexError( - f'Field with dim {field_dim} accessed with indices of dim {index_dim}' - ) + if isinstance(value, MatrixField): - return make_index_expr(value.ptr, indices_expr_group) + return Expr( + ast_builder.expr_subscript( + value.ptr, indices_expr_group, + get_runtime().get_current_src_info())) if isinstance(value, StructField): entries = { k: subscript(ast_builder, v, *indices) @@ -217,15 +212,13 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): } entries['__struct_methods'] = value.struct_methods return _IntermediateStruct(entries) - return make_index_expr(_var, indices_expr_group) + return Expr( + ast_builder.expr_subscript(_var, indices_expr_group, + get_runtime().get_current_src_info())) if isinstance(value, AnyArray): - dim = _ti_core.get_external_tensor_dim(value.ptr) - element_dim = len(value.element_shape()) - if dim != index_dim + element_dim: - raise IndexError( - f'Field with dim {dim - element_dim} accessed with indices of dim {index_dim}' - ) - return make_index_expr(value.ptr, indices_expr_group) + return Expr( + ast_builder.expr_subscript(value.ptr, indices_expr_group, + get_runtime().get_current_src_info())) assert isinstance(value, Expr) # Index into TensorType # value: IndexExpression with ret_type = TensorType @@ -249,18 +242,14 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): make_expr_group(i, j) for i in indices[0] for j in indices[1] ] return_shape = (len(indices[0]), len(indices[1])) + return Expr( _ti_core.subscript_with_multiple_indices( value.ptr, multiple_indices, return_shape, get_runtime().get_current_src_info())) - return make_index_expr(value.ptr, indices_expr_group) - - -@taichi_scope -def make_index_expr(_var, indices_expr_group): return Expr( - _ti_core.subscript(_var, indices_expr_group, - get_runtime().get_current_src_info())) + ast_builder.expr_subscript(value.ptr, indices_expr_group, + get_runtime().get_current_src_info())) class SrcInfoGuard: diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 670a0639d02b7..e933f8de190f1 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -262,7 +262,7 @@ def func_call_rvalue(self, key, args): impl.Expr) and args[i].ptr.is_tensor(): non_template_args.extend([ Expr(x) for x in impl.get_runtime().prog. 
- current_ast_builder().expand_expr([args[i].ptr]) + current_ast_builder().expand_exprs([args[i].ptr]) ]) else: non_template_args.append(args[i]) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index a4291535c0ddf..7dec71786400d 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1503,7 +1503,7 @@ def __call__(self, *args): elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): entries += [ impl.Expr(e) for e in impl.get_runtime().prog. - current_ast_builder().expand_expr([x.ptr]) + current_ast_builder().expand_exprs([x.ptr]) ] elif isinstance(x, Matrix): entries += x.entries @@ -1616,7 +1616,7 @@ def __call__(self, *args): elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): entries += [ impl.Expr(e) for e in impl.get_runtime().prog. - current_ast_builder().expand_expr([x.ptr]) + current_ast_builder().expand_exprs([x.ptr]) ] else: entries.append(x) diff --git a/python/taichi/lang/mesh.py b/python/taichi/lang/mesh.py index c1d7169c50c1e..4de03673abdaa 100644 --- a/python/taichi/lang/mesh.py +++ b/python/taichi/lang/mesh.py @@ -605,14 +605,17 @@ def _TetMesh(): class MeshElementFieldProxy: def __init__(self, mesh: MeshInstance, element_type: MeshElementType, entry_expr: impl.Expr): + ast_builder = impl.get_runtime().prog.current_ast_builder() + self.mesh = mesh self.element_type = element_type self.entry_expr = entry_expr element_field = self.mesh.fields[self.element_type] for key, attr in element_field.field_dict.items(): + global_entry_expr = impl.Expr( - _ti_core.get_index_conversion( + ast_builder.mesh_index_conversion( self.mesh.mesh_ptr, element_type, entry_expr, ConvType.l2r if element_field.attr_dict[key].reorder else ConvType.l2g)) # transform index space @@ -622,7 +625,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, setattr( self, key, impl.Expr( - _ti_core.subscript( + ast_builder.expr_subscript( attr.ptr, global_entry_expr_group, impl.get_runtime().get_current_src_info()))) elif isinstance(attr, StructField): @@ -633,7 +636,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, setattr( self, key, impl.Expr( - _ti_core.subscript( + ast_builder.expr_subscript( var, global_entry_expr_group, impl.get_runtime().get_current_src_info()))) @@ -650,10 +653,11 @@ def ptr(self): @property def id(self): # return the global non-reordered index + ast_builder = impl.get_runtime().prog.current_ast_builder() l2g_expr = impl.Expr( - _ti_core.get_index_conversion(self.mesh.mesh_ptr, - self.element_type, self.entry_expr, - ConvType.l2g)) + ast_builder.mesh_index_conversion(self.mesh.mesh_ptr, + self.element_type, + self.entry_expr, ConvType.l2g)) return l2g_expr diff --git a/python/taichi/lang/simt/block.py b/python/taichi/lang/simt/block.py index a9fc6af89eb94..c1154793df23f 100644 --- a/python/taichi/lang/simt/block.py +++ b/python/taichi/lang/simt/block.py @@ -54,5 +54,8 @@ def __init__(self, shape, dtype): @taichi_scope def subscript(self, *indices): - return impl.make_index_expr(self.shared_array_proxy, - make_expr_group(*indices)) + ast_builder = impl.get_runtime().prog.current_ast_builder() + return impl.Expr( + ast_builder.expr_subscript( + self.shared_array_proxy, make_expr_group(*indices), + impl.get_runtime().get_current_src_info())) diff --git a/taichi/common/exceptions.h b/taichi/common/exceptions.h index 5ad9d8d789609..eb3570ffe23b9 100644 --- a/taichi/common/exceptions.h +++ b/taichi/common/exceptions.h @@ -23,6 +23,10 @@ class TaichiSyntaxError : public TaichiExceptionImpl { using 
TaichiExceptionImpl::TaichiExceptionImpl; }; +class TaichiIndexError : public TaichiExceptionImpl { + using TaichiExceptionImpl::TaichiExceptionImpl; +}; + class TaichiRuntimeError : public TaichiExceptionImpl { using TaichiExceptionImpl::TaichiExceptionImpl; }; diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 767d05ef5eeef..da2c541fdf1de 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -26,12 +26,6 @@ Expr bit_cast(const Expr &input, DataType dt) { return Expr::make(UnaryOpType::cast_bits, input, dt); } -Expr Expr::operator[](const ExprGroup &indices) const { - TI_ASSERT(is() || is() || - is() || is_tensor(expr->ret_type)); - return Expr::make(*this, indices); -} - Expr &Expr::operator=(const Expr &o) { set(o); return *this; diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 0d5f974c7aa0a..b8058c7f1ad22 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -83,8 +83,6 @@ class Expr { // std::variant in FrontendPrintStmt. Expr &operator=(const Expr &o); - Expr operator[](const ExprGroup &indices) const; - template static Expr make(Args &&...args) { return Expr(std::make_shared(std::forward(args)...)); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 4a81f56287141..b440a316bc172 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -25,7 +25,7 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(ASTBuilder *builder, const Expr &val) : op_type(op_type), snode(snode), val(val) { this->indices = indices; - std::vector expanded_exprs = builder->expand_expr(this->indices.exprs); + std::vector expanded_exprs = builder->expand_exprs(this->indices.exprs); this->indices.exprs = expanded_exprs; if (val.expr != nullptr) { @@ -688,6 +688,26 @@ void MatrixExpression::flatten(FlattenContext *ctx) { stmt->ret_type = this->dt; } +IndexExpression::IndexExpression(const Expr &var, + const ExprGroup &indices, + std::string tb) + : var(var), indices_group({indices}) { + this->tb = tb; +} + +IndexExpression::IndexExpression(const Expr &var, + const std::vector &indices_group, + const std::vector &ret_shape, + std::string tb) + : var(var), indices_group(indices_group), ret_shape(ret_shape) { + // IndexExpression with ret_shape is used for matrix slicing, where each entry + // of ExprGroup is interpreted as a group of indices to return within each + // axis. For example, mat[0, 3:5] has indices_group={0, [3, 4]}, where [3, 4] + // means "m"-axis will return a TensorType with size of 2. In this case, we + // should not expand indices_group due to its special semantics. + this->tb = tb; +} + bool IndexExpression::is_field() const { return var.is(); } @@ -721,20 +741,43 @@ bool IndexExpression::is_global() const { return is_field() || is_matrix_field() || is_ndarray(); } +static void field_validation(FieldExpression *field_expr, int index_dim) { + TI_ASSERT(field_expr != nullptr); + TI_ASSERT(field_expr->snode != nullptr); + int field_dim = field_expr->snode->num_active_indices; + + if (field_dim != index_dim) { + throw TaichiIndexError( + fmt::format("Field with dim {} accessed with indices of dim {}", + field_dim, index_dim)); + } +} + void IndexExpression::type_check(CompileConfig *) { // TODO: Change to type-based solution // Currently, dimension compatibility check happens in Python TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape), end(ret_shape), 1, std::multiplies<>())); - if (!ret_shape.empty()) { + int index_dim = indices_group.empty() ? 
0 : indices_group[0].size(); + bool has_slice = !ret_shape.empty(); + if (has_slice) { TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices"); auto element_type = var->ret_type->as()->get_element_type(); ret_type = TypeFactory::create_tensor_type(ret_shape, element_type); + } else if (is_field()) { // field - ret_type = var.cast()->dt->get_compute_type(); + auto field_expr = var.cast(); + field_validation(field_expr.get(), index_dim); + ret_type = field_expr->dt->get_compute_type(); + } else if (is_matrix_field()) { auto matrix_field_expr = var.cast(); + + TI_ASSERT(!matrix_field_expr->fields.empty()); + auto field_expr = matrix_field_expr->fields[0].cast(); + field_validation(field_expr.get(), index_dim); + ret_type = TypeFactory::create_tensor_type(matrix_field_expr->element_shape, matrix_field_expr->fields[0] .cast() @@ -742,7 +785,12 @@ void IndexExpression::type_check(CompileConfig *) { } else if (is_ndarray()) { // ndarray auto external_tensor_expr = var.cast(); int total_dim = external_tensor_expr->dim; - int index_dim = indices_group[0].exprs.size(); + int element_dim = external_tensor_expr->dt.get_shape().size(); + if (total_dim != index_dim + element_dim) { + throw TaichiTypeError( + fmt::format("Array with dim {} accessed with indices of dim {}", + total_dim - element_dim, index_dim)); + } if (index_dim == total_dim) { // Access all the way to a single element @@ -910,7 +958,7 @@ SNodeOpExpression::SNodeOpExpression(ASTBuilder *builder, SNodeOpType op_type, const ExprGroup &indices) : snode(snode), op_type(op_type) { - std::vector expanded_indices = builder->expand_expr(indices.exprs); + std::vector expanded_indices = builder->expand_exprs(indices.exprs); this->indices = indices; this->indices.exprs = std::move(expanded_indices); } @@ -921,7 +969,7 @@ SNodeOpExpression::SNodeOpExpression(ASTBuilder *builder, const ExprGroup &indices, const std::vector &values) : SNodeOpExpression(builder, snode, op_type, indices) { - this->values = builder->expand_expr(values); + this->values = builder->expand_exprs(values); } void SNodeOpExpression::type_check(CompileConfig *config) { @@ -1163,6 +1211,14 @@ void MeshRelationAccessExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +MeshIndexConversionExpression::MeshIndexConversionExpression( + mesh::Mesh *mesh, + mesh::MeshElementType idx_type, + const Expr idx, + mesh::ConvType conv_type) + : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { +} + void MeshIndexConversionExpression::type_check(CompileConfig *) { ret_type = PrimitiveType::i32; } @@ -1359,6 +1415,25 @@ void ASTBuilder::expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) { this->insert(std::move(stmt)); } +Expr ASTBuilder::expr_subscript(const Expr &expr, + const ExprGroup &indices, + std::string tb) { + TI_ASSERT(expr.is() || expr.is() || + expr.is() || + is_tensor(expr.expr->ret_type)); + + // IndexExpression without ret_shape is used for matrix indexing, + // where each entry of ExprGroup is interpreted as indexing into a specific + // axis. For example, mat[3, 4] has indices_group={[3, 4]}, where [3, 4] + // corresponds to "n"-axis and "m"-axis of the matrix. Therefore we expand + // indices_group={[3, 4]} into {3, 4} to avoid TensorType in indices. 
+ std::vector expanded_indices = this->expand_exprs(indices.exprs); + auto expanded_expr_group = ExprGroup(); + expanded_expr_group.exprs = expanded_indices; + + return Expr::make(expr, expanded_expr_group, tb); +} + void ASTBuilder::create_assert_stmt(const Expr &cond, const std::string &msg, const std::vector &args) { @@ -1479,62 +1554,82 @@ Expr ASTBuilder::snode_get_addr(SNode *snode, const ExprGroup &indices) { indices); } -std::vector ASTBuilder::expand_expr(const std::vector &exprs) { - if (exprs.size() > 1 || exprs.size() == 0) { +std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { + if (exprs.size() == 0) { return exprs; } - Expr index_expr = exprs[0]; - TI_ASSERT_TYPE_CHECKED(index_expr); - if (!index_expr->ret_type->is()) { - return exprs; - } - - // Expand TensorType expr - /* - Before: - TensorType<4 x i32> index = Expr; - - After: - TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>) - i32 ind0 = IndexExpression(id_expr, 0) - i32 ind1 = IndexExpression(id_expr, 1) - i32 ind2 = IndexExpression(id_expr, 2) - i32 ind3 = IndexExpression(id_expr, 3) - - return {ind0, ind1, ind2, ind3} - - */ std::vector expanded_exprs; + for (auto expr : exprs) { + TI_ASSERT_TYPE_CHECKED(expr); + if (!expr->ret_type->is()) { + expanded_exprs.push_back(expr); + } else { + // Expand TensorType expr + /* + Before: + TensorType<4 x i32> index = Expr; + + After: + TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>) + i32 ind0 = IndexExpression(id_expr, 0) + i32 ind1 = IndexExpression(id_expr, 1) + i32 ind2 = IndexExpression(id_expr, 2) + i32 ind3 = IndexExpression(id_expr, 3) + + return {ind0, ind1, ind2, ind3} + + */ + auto tensor_type = expr->ret_type->cast(); + + Expr id_expr; + if (expr.is()) { + id_expr = expr; + } else { + id_expr = make_var(expr, expr->tb); + } + auto shape = tensor_type->get_shape(); + if (shape.size() == 1) { + for (int i = 0; i < shape[0]; i++) { + auto ind = Expr(std::make_shared( + id_expr, ExprGroup(Expr(i)), expr->tb)); + ind.expr->ret_type = tensor_type->get_element_type(); + expanded_exprs.push_back(ind); + } + } else { + TI_ASSERT(shape.size() == 2); + for (int i = 0; i < shape[0]; i++) { + for (int j = 0; j < shape[1]; j++) { + auto ind = Expr(std::make_shared( + id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb)); + ind.expr->ret_type = tensor_type->get_element_type(); + expanded_exprs.push_back(ind); + } + } + } + } + } - auto tensor_type = index_expr->ret_type->cast(); + return expanded_exprs; +} - Expr id_expr; - if (index_expr.is()) { - id_expr = index_expr; +Expr ASTBuilder::mesh_index_conversion(mesh::MeshPtr mesh_ptr, + mesh::MeshElementType idx_type, + const Expr &idx, + mesh::ConvType &conv_type) { + Expr expanded_idx; + if (idx.is() && idx.get_ret_type() == PrimitiveType::unknown) { + expanded_idx = idx; } else { - id_expr = make_var(index_expr, index_expr->tb); - } - auto shape = tensor_type->get_shape(); - if (shape.size() == 1) { - for (int i = 0; i < shape[0]; i++) { - auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i)), index_expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); - expanded_exprs.push_back(ind); - } - } else { - TI_ASSERT(shape.size() == 2); - for (int i = 0; i < shape[0]; i++) { - for (int j = 0; j < shape[1]; j++) { - auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i), Expr(j)), index_expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); - expanded_exprs.push_back(ind); - } + if (idx.expr->ret_type->is()) { + 
TI_ASSERT(idx.expr->ret_type->cast()->get_num_elements() == + 1); } + expanded_idx = this->expand_exprs({idx})[0]; } - return expanded_exprs; + + return Expr::make(mesh_ptr.ptr.get(), idx_type, + expanded_idx, conv_type); } void ASTBuilder::create_scope(std::unique_ptr &list, LoopType tp) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 955dd1d7c69fc..9e03107090bbe 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -590,18 +590,12 @@ class IndexExpression : public Expression { IndexExpression(const Expr &var, const ExprGroup &indices, - std::string tb = "") - : var(var), indices_group({indices}) { - this->tb = tb; - } + std::string tb = ""); IndexExpression(const Expr &var, const std::vector &indices_group, const std::vector &ret_shape, - std::string tb = "") - : var(var), indices_group(indices_group), ret_shape(ret_shape) { - this->tb = tb; - } + std::string tb = ""); void type_check(CompileConfig *config) override; @@ -868,9 +862,7 @@ class MeshIndexConversionExpression : public Expression { MeshIndexConversionExpression(mesh::Mesh *mesh, mesh::MeshElementType idx_type, const Expr idx, - mesh::ConvType conv_type) - : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { - } + mesh::ConvType conv_type); void flatten(FlattenContext *ctx) override; @@ -960,6 +952,15 @@ class ASTBuilder { Expr expr_alloca(); Expr expr_alloca_shared_array(const std::vector &shape, const DataType &element_type); + Expr expr_subscript(const Expr &expr, + const ExprGroup &indices, + std::string tb = ""); + + Expr mesh_index_conversion(mesh::MeshPtr mesh_ptr, + mesh::MeshElementType idx_type, + const Expr &idx, + mesh::ConvType &conv_type); + void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); void create_assert_stmt(const Expr &cond, const std::string &msg, @@ -995,7 +996,7 @@ class ASTBuilder { Expr snode_length(SNode *snode, const ExprGroup &indices); Expr snode_get_addr(SNode *snode, const ExprGroup &indices); - std::vector expand_expr(const std::vector &exprs); + std::vector expand_exprs(const std::vector &exprs); void create_scope(std::unique_ptr &list, LoopType tp = NotLoop); void pop_scope(); diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index 15fbc2587a5e4..0d4d77f9662b7 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -475,15 +475,6 @@ MeshRelationAccessStmt *IRBuilder::get_relation_access( mesh, mesh_idx, to_type, neighbor_idx)); } -MeshIndexConversionStmt *IRBuilder::get_index_conversion( - mesh::Mesh *mesh, - mesh::MeshElementType idx_type, - Stmt *idx, - mesh::ConvType conv_type) { - return insert(Stmt::make_typed(mesh, idx_type, idx, - conv_type)); -} - MeshPatchIndexStmt *IRBuilder::get_patch_index() { return insert(Stmt::make_typed()); } diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 316a2c14a39d2..08bdd1e83b9c1 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -278,10 +278,6 @@ class IRBuilder { Stmt *mesh_idx, mesh::MeshElementType to_type, Stmt *neighbor_idx); - MeshIndexConversionStmt *get_index_conversion(mesh::Mesh *mesh, - mesh::MeshElementType idx_type, - Stmt *idx, - mesh::ConvType conv_type); MeshPatchIndexStmt *get_patch_index(); private: diff --git a/taichi/math/svd.h b/taichi/math/svd.h index d0da10beb3da2..898d030660192 100644 --- a/taichi/math/svd.h +++ b/taichi/math/svd.h @@ -43,7 +43,7 @@ std::tuple sifakis_svd_export(ASTBuilder *ast_builder, const Expr &mat, int num_iters) { - auto expanded_exprs = ast_builder->expand_expr({mat}); + 
auto expanded_exprs = ast_builder->expand_exprs({mat}); TI_ASSERT(expanded_exprs.size() == 9); Expr a00 = expanded_exprs[0]; diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 482dce4a5bf6c..5d7df34eb6290 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -362,10 +362,13 @@ Kernel &Program::get_snode_reader(SNode *snode) { auto &ker = kernel([snode, this] { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - indices.push_back(Expr::make(i, PrimitiveType::i32)); + auto argload_expr = Expr::make(i, PrimitiveType::i32); + argload_expr->type_check(&this->this_thread_config()); + indices.push_back(std::move(argload_expr)); } - auto ret = Stmt::make( - ExprGroup(Expr(snode_to_fields_.at(snode))[indices])); + ASTBuilder *builder = this->current_ast_builder(); + auto ret = Stmt::make(ExprGroup( + builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices))); this->current_ast_builder()->insert(std::move(ret)); }); ker.set_arch(get_accessor_arch()); @@ -383,9 +386,13 @@ Kernel &Program::get_snode_writer(SNode *snode) { auto &ker = kernel([snode, this] { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - indices.push_back(Expr::make(i, PrimitiveType::i32)); + auto argload_expr = Expr::make(i, PrimitiveType::i32); + argload_expr->type_check(&this->this_thread_config()); + indices.push_back(std::move(argload_expr)); } - auto expr = Expr(snode_to_fields_.at(snode))[indices]; + ASTBuilder *builder = current_ast_builder(); + auto expr = + builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices); this->current_ast_builder()->insert_assignment( expr, Expr::make(snode->num_active_indices, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index d958e92a22325..d5485976eca1a 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -42,10 +42,6 @@ bool test_threading(); namespace taichi::lang { -Expr expr_index(const Expr &expr, const Expr &index) { - return expr[ExprGroup(index)]; -} - std::string libdevice_path(); } // namespace taichi::lang @@ -59,6 +55,8 @@ void export_lang(py::module &m) { PyExc_TypeError); py::register_exception(m, "TaichiSyntaxError", PyExc_SyntaxError); + py::register_exception(m, "TaichiIndexError", + PyExc_IndexError); py::register_exception(m, "TaichiRuntimeError", PyExc_RuntimeError); py::register_exception(m, "TaichiAssertionError", @@ -315,7 +313,9 @@ void export_lang(py::module &m) { .def("insert_expr_stmt", &ASTBuilder::insert_expr_stmt) .def("insert_thread_idx_expr", &ASTBuilder::insert_thread_idx_expr) .def("insert_patch_idx_expr", &ASTBuilder::insert_patch_idx_expr) - .def("expand_expr", &ASTBuilder::expand_expr) + .def("expand_exprs", &ASTBuilder::expand_exprs) + .def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion) + .def("expr_subscript", &ASTBuilder::expr_subscript) .def("sifakis_svd_f32", sifakis_svd_export) .def("sifakis_svd_f64", sifakis_svd_export) .def("expr_var", &ASTBuilder::make_var) @@ -861,8 +861,6 @@ void export_lang(py::module &m) { return Expr::make(AtomicOpType::bit_xor, a, b); }); - m.def("expr_index", expr_index); - m.def("expr_assume_in_range", assume_range); m.def("expr_loop_unique", loop_unique); @@ -993,13 +991,6 @@ void export_lang(py::module &m) { m.def("data_type_name", data_type_name); - m.def("subscript", - [](const Expr &expr, const ExprGroup &expr_group, std::string tb) { - Expr idx_expr = expr[expr_group]; - idx_expr.set_tb(tb); - return idx_expr; - }); - m.def( 
"subscript_with_multiple_indices", Expr::make &, @@ -1044,13 +1035,6 @@ void export_lang(py::module &m) { mesh_ptr.ptr.get(), mesh_idx, to_type, neighbor_idx); }); - m.def("get_index_conversion", - [](mesh::MeshPtr mesh_ptr, mesh::MeshElementType idx_type, - const Expr &idx, mesh::ConvType &conv_type) { - return Expr::make( - mesh_ptr.ptr.get(), idx_type, idx, conv_type); - }); - py::class_(m, "FunctionKey") .def(py::init()) .def_readonly("instance_id", &FunctionKey::instance_id); diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index 11ad9269f6b67..558a72979b80a 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -86,21 +86,39 @@ TEST(FrontendTypeInference, TernaryOp) { } TEST(FrontendTypeInference, GlobalPtr_Field) { + auto prog = std::make_unique(Arch::x64); + auto func = []() {}; + auto kernel = std::make_unique(*prog, func, "fake_kernel"); + Callable::CurrentCallableGuard _(kernel->program, kernel.get()); + auto ast_builder = prog->current_ast_builder(); + auto global_var = Expr::make(PrimitiveType::u8, Identifier(0)); + SNode snode; + snode.num_active_indices = 1; + std::dynamic_pointer_cast(global_var.expr) + ->set_snode(&snode); + auto index = value(2); index->type_check(nullptr); - auto global_ptr = global_var[ExprGroup(index)]; + auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index)); global_ptr->type_check(nullptr); EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); } TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { + auto prog = std::make_unique(Arch::x64); + auto func = []() {}; + auto kernel = std::make_unique(*prog, func, "fake_kernel"); + Callable::CurrentCallableGuard _(kernel->program, kernel.get()); + auto ast_builder = prog->current_ast_builder(); + auto index = value(2); index->type_check(nullptr); auto external_tensor = Expr::make(PrimitiveType::u16, 1, 0, 0); - auto global_ptr = external_tensor[ExprGroup(index)]; + auto global_ptr = + ast_builder->expr_subscript(external_tensor, ExprGroup(index)); EXPECT_THROW(global_ptr->type_check(nullptr), TaichiTypeError); }