diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 2904c8cda0ae6..a99c218008418 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1,3 +1,4 @@ +import functools import numbers import warnings from collections.abc import Iterable @@ -102,25 +103,44 @@ def prop_setter(instance, value): return cls +def _infer_entry_dt(entry): + if isinstance(entry, (int, np.integer)): + return impl.get_runtime().default_ip + if isinstance(entry, float): + return impl.get_runtime().default_fp + if isinstance(entry, expr.Expr): + dt = entry.ptr.get_ret_type() + if dt == ti_python_core.DataType_unknown: + raise TaichiTypeError( + 'Element type of the matrix cannot be inferred. Please set dt instead for now.' + ) + return dt + raise TaichiTypeError('Element type of the matrix is invalid.') + + +def _infer_array_dt(arr): + assert len(arr) > 0 + return functools.reduce(ti_python_core.promoted_type, + map(_infer_entry_dt, arr)) + + def make_matrix(arr, dt=None): if len(arr) == 0: # the only usage of an empty vector is to serve as field indices - is_matrix = False + shape = [0] dt = primitive_types.i32 else: - is_matrix = isinstance(arr[0], Iterable) + if isinstance(arr[0], Iterable): # matrix + shape = [len(arr), len(arr[0])] + arr = [elt for row in arr for elt in row] + else: # vector + shape = [len(arr)] if dt is None: - dt = _make_entries_initializer(is_matrix).infer_dt(arr) + dt = _infer_array_dt(arr) else: dt = cook_dtype(dt) - if not is_matrix: - return impl.Expr( - impl.make_matrix_expr([len(arr)], dt, - [expr.Expr(elt).ptr for elt in arr])) - return impl.Expr( - impl.make_matrix_expr( - [len(arr), len(arr[0])], dt, - [expr.Expr(elt).ptr for row in arr for elt in row])) + return expr.Expr( + impl.make_matrix_expr(shape, dt, [expr.Expr(elt).ptr for elt in arr])) def is_vector(x): @@ -241,48 +261,6 @@ def _set_entries(self, value): self[i, j] = value[i][j] -class _MatrixEntriesInitializer: - def pyscope(self, arr): - raise NotImplementedError('Override') - - def _get_entry_to_infer(self, arr): - raise NotImplementedError('Override') - - def infer_dt(self, arr): - entry = self._get_entry_to_infer(arr) - if isinstance(entry, (int, np.integer)): - return impl.get_runtime().default_ip - if isinstance(entry, float): - return impl.get_runtime().default_fp - if isinstance(entry, expr.Expr): - dt = entry.ptr.get_ret_type() - if dt == ti_python_core.DataType_unknown: - raise TypeError( - 'Element type of the matrix cannot be inferred. Please set dt instead for now.' - ) - return dt - raise Exception( - 'dt required when using dynamic_index for local tensor') - - -def _make_entries_initializer(is_matrix: bool) -> _MatrixEntriesInitializer: - class _VecImpl(_MatrixEntriesInitializer): - def pyscope(self, arr): - return [[x] for x in arr] - - def _get_entry_to_infer(self, arr): - return arr[0] - - class _MatImpl(_MatrixEntriesInitializer): - def pyscope(self, arr): - return [list(row) for row in arr] - - def _get_entry_to_infer(self, arr): - return arr[0][0] - - return _MatImpl() if is_matrix else _VecImpl() - - @_gen_swizzles class Matrix(TaichiOperations): """The matrix class. @@ -346,13 +324,16 @@ def __init__(self, arr, dt=None, ndim=None): is_matrix = isinstance(arr[0], Iterable) and not is_vector(self) self.ndim = 2 if is_matrix else 1 - initializer = _make_entries_initializer(is_matrix) - if not is_matrix and isinstance(arr[0], Iterable): - flattened = [] - for row in arr: - flattened += row - arr = flattened - mat = initializer.pyscope(arr) + + if is_matrix: + mat = [list(row) for row in arr] + else: + if isinstance(arr[0], Iterable): + flattened = [] + for row in arr: + flattened += row + arr = flattened + mat = [[x] for x in arr] self.n, self.m = len(mat), 1 if len(mat) > 0: diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index c25fa95de7134..c187983f0421a 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1138,6 +1138,8 @@ void export_lang(py::module &m) { py::class_(m, "Type").def("to_string", &Type::to_string); + m.def("promoted_type", promoted_type); + // Note that it is important to specify py::return_value_policy::reference for // the factory methods, otherwise pybind11 will delete the Types owned by // TypeFactory on Python-scope pointer destruction. diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 54aaa1cc81ff2..2b21f0f3501b2 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1205,3 +1205,13 @@ def foo() -> ti.types.vector(4, ti.i32): return ti.Vector([a[0, 0], a[0, 1], a[1, 0], a[1, 1]]) assert (foo() == [1, 2, 3, 4]).all() + + +@test_utils.test(debug=True) +def test_matrix_type_inference(): + @ti.kernel + def foo(): + a = ti.Vector([1, 2.5])[1] # should be f32 instead of i32 + assert a == 2.5 + + foo()