diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 330519c0e58d9..e91d79e7651ad 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -229,20 +229,31 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): dim = len(shape) assert dim == len(indices) indices = [ - _calc_slice(index, shape[i]) - if isinstance(index, slice) else [index] + _calc_slice(index, shape[i]) if isinstance(index, slice) else index for i, index in enumerate(indices) ] if dim == 1: + assert isinstance(indices[0], list) multiple_indices = [make_expr_group(i) for i in indices[0]] return_shape = (len(indices[0]), ) else: assert dim == 2 - multiple_indices = [ - make_expr_group(i, j) for i in indices[0] for j in indices[1] - ] - return_shape = (len(indices[0]), len(indices[1])) - + if isinstance(indices[0], list) and isinstance(indices[1], list): + multiple_indices = [ + make_expr_group(i, j) for i in indices[0] + for j in indices[1] + ] + return_shape = (len(indices[0]), len(indices[1])) + elif isinstance(indices[0], list): # indices[1] is not list + multiple_indices = [ + make_expr_group(i, indices[1]) for i in indices[0] + ] + return_shape = (len(indices[0]), ) + else: # indices[0] is not list while indices[1] is list + multiple_indices = [ + make_expr_group(indices[0], j) for j in indices[1] + ] + return_shape = (len(indices[1]), ) return Expr( _ti_core.subscript_with_multiple_indices( value.ptr, multiple_indices, return_shape, diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 573e7a8fd0de0..178d54b62ee53 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -404,15 +404,16 @@ def _linearize_entry_id(self, *args): return args[0] * self.m + args[1] def _get_slice(self, a, b): - if not isinstance(a, slice): - a = [a] - else: + if isinstance(a, slice): a = range(a.start or 0, a.stop or self.n, a.step or 1) - if not isinstance(b, slice): - b = [b] - else: + if isinstance(b, slice): b = range(b.start or 0, b.stop or self.m, b.step or 1) - return Matrix([[self._get_entry(i, j) for j in b] for i in a]) + if isinstance(a, range) and isinstance(b, range): + return Matrix([[self._get_entry(i, j) for j in b] for i in a]) + if isinstance(a, range): # b is not range + return Vector([self._get_entry(i, b) for i in a]) + # a is not range while b is range + return Vector([self._get_entry(a, j) for j in b]) @python_scope def _set_entry(self, i, j, item): diff --git a/tests/python/test_matrix_slice.py b/tests/python/test_matrix_slice.py index cecb264716e17..3f815b7bca24c 100644 --- a/tests/python/test_matrix_slice.py +++ b/tests/python/test_matrix_slice.py @@ -26,6 +26,8 @@ def foo2() -> ti.types.matrix(2, 3, dtype=ti.i32): assert (v2 == ti.Vector([3, 6])).all() m2 = ti.Matrix([[2, 3], [4, 5]])[:1, 1:] assert (m2 == ti.Matrix([[3]])).all() + v3 = ti.Matrix([[1, 2], [3, 4]])[:, 1] + assert (v3 == ti.Vector([2, 4])).all() @test_utils.test() @@ -52,36 +54,34 @@ def foo2(): @test_utils.test() def test_matrix_slice_with_variable(): @ti.kernel - def test_one_row_slice( - index: ti.i32) -> ti.types.matrix(2, 1, dtype=ti.i32): + def test_one_row_slice(index: ti.i32) -> ti.types.vector(2, dtype=ti.i32): m = ti.Matrix([[1, 2, 3], [4, 5, 6]]) return m[:, index] @ti.kernel - def test_one_col_slice( - index: ti.i32) -> ti.types.matrix(1, 3, dtype=ti.i32): + def test_one_col_slice(index: ti.i32) -> ti.types.vector(3, dtype=ti.i32): m = ti.Matrix([[1, 2, 3], [4, 5, 6]]) return m[index, :] r1 = test_one_row_slice(1) - assert (r1 == ti.Matrix([[2], [5]])).all() + assert (r1 == ti.Vector([2, 5])).all() c1 = test_one_col_slice(1) - assert (c1 == ti.Matrix([[4, 5, 6]])).all() + assert (c1 == ti.Vector([4, 5, 6])).all() @test_utils.test() def test_matrix_slice_write(): @ti.kernel - def assign_row() -> ti.types.matrix(3, 4, ti.i32): + def assign_col() -> ti.types.matrix(3, 4, ti.i32): mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)]) - row = ti.Matrix([[1, 2, 3, 4]]) - mat[0, :] = row + col = ti.Vector([1, 2, 3]) + mat[:, 0] = col return mat @ti.kernel def assign_partial_row() -> ti.types.matrix(3, 4, ti.i32): mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)]) - mat[1, 1:3] = ti.Matrix([[1, 2]]) + mat[1, 1:3] = ti.Vector([1, 2]) return mat @ti.kernel @@ -91,8 +91,8 @@ def augassign_rows() -> ti.types.matrix(3, 4, ti.i32): mat[:2, :] += rows return mat - assert (assign_row() == ti.Matrix([[1, 2, 3, 4], [0, 0, 0, 0], - [0, 0, 0, 0]])).all() + assert (assign_col() == ti.Matrix([[1, 0, 0, 0], [2, 0, 0, 0], + [3, 0, 0, 0]])).all() assert (assign_partial_row() == ti.Matrix([[0, 0, 0, 0], [0, 1, 2, 0], [0, 0, 0, 0]])).all() assert (augassign_rows() == ti.Matrix([[2, 3, 4, 5], [2, 3, 4, 5], @@ -104,7 +104,7 @@ def test_matrix_slice_write_dynamic_index(): @ti.kernel def foo(i: ti.i32) -> ti.types.matrix(3, 4, ti.i32): mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)]) - mat[i, :] = ti.Matrix([[1, 2, 3, 4]]) + mat[i, :] = ti.Vector([1, 2, 3, 4]) return mat for i in range(3):