Skip to content

Commit

Permalink
[refactor] Remove _IntermediateMatrix and _MatrixFieldElement (taichi…
Browse files Browse the repository at this point in the history
…-dev#6932)

Issue: taichi-dev#5819

### Brief Summary

These two intermediate classes are unnecessary now, so let's remove them
for simplicity.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent dc1a47b commit 8c69313
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 48 deletions.
4 changes: 2 additions & 2 deletions python/taichi/lang/_ndrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from taichi.lang import ops
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.expr import Expr
from taichi.lang.matrix import _IntermediateMatrix
from taichi.lang.matrix import Matrix
from taichi.types.utils import is_integral


Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(self, r):

def __iter__(self):
for ind in self.r:
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
yield Matrix(list(ind))


__all__ = ['ndrange']
43 changes: 0 additions & 43 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,49 +1250,6 @@ def ndarray(cls, n, dtype, shape):
return VectorNdarray(n, dtype, shape)


class _IntermediateMatrix(Matrix):
"""Intermediate matrix class for compiler internal use only.
Args:
n (int): Number of rows of the matrix.
m (int): Number of columns of the matrix.
entries (List[Expr]): All entries of the matrix.
"""
def __init__(self, n, m, entries, ndim=None):
assert isinstance(entries, list)
assert n * m == len(entries), "Number of entries doesn't match n * m"
self.n = n
self.m = m
if ndim is not None:
self.ndim = ndim
else:
if len(entries) == 0:
self.ndim = 0
else:
self.ndim = 2 if isinstance(entries[0], Iterable) else 1
self._impl = _PyScopeMatrixImpl(m, n, entries)


class _MatrixFieldElement(_IntermediateMatrix):
"""Matrix field element class for compiler internal use only.
Args:
field (MatrixField): The matrix field.
indices (taichi_python.ExprGroup): Indices of the element.
"""
def __init__(self, field, indices):
super().__init__(
field.n,
field.m, [
expr.Expr(
ti_python_core.subscript(
e.ptr, indices,
impl.get_runtime().get_current_src_info()))
for e in field._get_field_members()
],
ndim=getattr(field, "ndim", 2))


class MatrixField(Field):
"""Taichi matrix field with SNode implementation.
Expand Down
10 changes: 7 additions & 3 deletions python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.field import Field, ScalarField
from taichi.lang.matrix import Matrix, MatrixField, _MatrixFieldElement
from taichi.lang.matrix import Matrix, MatrixField
from taichi.lang.struct import StructField
from taichi.lang.util import python_scope
from taichi.types import u16, u32
Expand Down Expand Up @@ -619,8 +619,12 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
global_entry_expr_group = impl.make_expr_group(
*tuple([global_entry_expr]))
if isinstance(attr, MatrixField):
setattr(self, key,
_MatrixFieldElement(attr, global_entry_expr_group))
setattr(
self, key,
impl.Expr(
_ti_core.subscript(
attr.ptr, global_entry_expr_group,
impl.get_runtime().get_current_src_info())))
elif isinstance(attr, StructField):
raise RuntimeError(
'MeshTaichi has not support StructField yet')
Expand Down

0 comments on commit 8c69313

Please sign in to comment.