diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index f414bf5e75cab..d98bf2f8af2fe 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -4,7 +4,7 @@ from taichi.lang import ops from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.expr import Expr -from taichi.lang.matrix import _IntermediateMatrix +from taichi.lang.matrix import Matrix from taichi.types.utils import is_integral @@ -144,7 +144,7 @@ def __init__(self, r): def __iter__(self): for ind in self.r: - yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1) + yield Matrix(list(ind)) __all__ = ['ndrange'] diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index a99c218008418..8e66c9564dfcc 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1250,49 +1250,6 @@ def ndarray(cls, n, dtype, shape): return VectorNdarray(n, dtype, shape) -class _IntermediateMatrix(Matrix): - """Intermediate matrix class for compiler internal use only. - - Args: - n (int): Number of rows of the matrix. - m (int): Number of columns of the matrix. - entries (List[Expr]): All entries of the matrix. - """ - def __init__(self, n, m, entries, ndim=None): - assert isinstance(entries, list) - assert n * m == len(entries), "Number of entries doesn't match n * m" - self.n = n - self.m = m - if ndim is not None: - self.ndim = ndim - else: - if len(entries) == 0: - self.ndim = 0 - else: - self.ndim = 2 if isinstance(entries[0], Iterable) else 1 - self._impl = _PyScopeMatrixImpl(m, n, entries) - - -class _MatrixFieldElement(_IntermediateMatrix): - """Matrix field element class for compiler internal use only. - - Args: - field (MatrixField): The matrix field. - indices (taichi_python.ExprGroup): Indices of the element. - """ - def __init__(self, field, indices): - super().__init__( - field.n, - field.m, [ - expr.Expr( - ti_python_core.subscript( - e.ptr, indices, - impl.get_runtime().get_current_src_info())) - for e in field._get_field_members() - ], - ndim=getattr(field, "ndim", 2)) - - class MatrixField(Field): """Taichi matrix field with SNode implementation. diff --git a/python/taichi/lang/mesh.py b/python/taichi/lang/mesh.py index f4d721d8276ad..c1d7169c50c1e 100644 --- a/python/taichi/lang/mesh.py +++ b/python/taichi/lang/mesh.py @@ -6,7 +6,7 @@ from taichi.lang.enums import Layout from taichi.lang.exception import TaichiSyntaxError from taichi.lang.field import Field, ScalarField -from taichi.lang.matrix import Matrix, MatrixField, _MatrixFieldElement +from taichi.lang.matrix import Matrix, MatrixField from taichi.lang.struct import StructField from taichi.lang.util import python_scope from taichi.types import u16, u32 @@ -619,8 +619,12 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, global_entry_expr_group = impl.make_expr_group( *tuple([global_entry_expr])) if isinstance(attr, MatrixField): - setattr(self, key, - _MatrixFieldElement(attr, global_entry_expr_group)) + setattr( + self, key, + impl.Expr( + _ti_core.subscript( + attr.ptr, global_entry_expr_group, + impl.get_runtime().get_current_src_info()))) elif isinstance(attr, StructField): raise RuntimeError( 'MeshTaichi has not support StructField yet')