Skip to content

Commit

Permalink
[Lang] Make slicing a single row/column of a matrix return a vector (t…
Browse files Browse the repository at this point in the history
…aichi-dev#7068)

Issue: taichi-dev#6978, fix taichi-dev#6902

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 48b1b0b commit 0060231
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 27 deletions.
25 changes: 18 additions & 7 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,20 +229,31 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
dim = len(shape)
assert dim == len(indices)
indices = [
_calc_slice(index, shape[i])
if isinstance(index, slice) else [index]
_calc_slice(index, shape[i]) if isinstance(index, slice) else index
for i, index in enumerate(indices)
]
if dim == 1:
assert isinstance(indices[0], list)
multiple_indices = [make_expr_group(i) for i in indices[0]]
return_shape = (len(indices[0]), )
else:
assert dim == 2
multiple_indices = [
make_expr_group(i, j) for i in indices[0] for j in indices[1]
]
return_shape = (len(indices[0]), len(indices[1]))

if isinstance(indices[0], list) and isinstance(indices[1], list):
multiple_indices = [
make_expr_group(i, j) for i in indices[0]
for j in indices[1]
]
return_shape = (len(indices[0]), len(indices[1]))
elif isinstance(indices[0], list): # indices[1] is not list
multiple_indices = [
make_expr_group(i, indices[1]) for i in indices[0]
]
return_shape = (len(indices[0]), )
else: # indices[0] is not list while indices[1] is list
multiple_indices = [
make_expr_group(indices[0], j) for j in indices[1]
]
return_shape = (len(indices[1]), )
return Expr(
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
Expand Down
15 changes: 8 additions & 7 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,16 @@ def _linearize_entry_id(self, *args):
return args[0] * self.m + args[1]

def _get_slice(self, a, b):
if not isinstance(a, slice):
a = [a]
else:
if isinstance(a, slice):
a = range(a.start or 0, a.stop or self.n, a.step or 1)
if not isinstance(b, slice):
b = [b]
else:
if isinstance(b, slice):
b = range(b.start or 0, b.stop or self.m, b.step or 1)
return Matrix([[self._get_entry(i, j) for j in b] for i in a])
if isinstance(a, range) and isinstance(b, range):
return Matrix([[self._get_entry(i, j) for j in b] for i in a])
if isinstance(a, range): # b is not range
return Vector([self._get_entry(i, b) for i in a])
# a is not range while b is range
return Vector([self._get_entry(a, j) for j in b])

@python_scope
def _set_entry(self, i, j, item):
Expand Down
26 changes: 13 additions & 13 deletions tests/python/test_matrix_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def foo2() -> ti.types.matrix(2, 3, dtype=ti.i32):
assert (v2 == ti.Vector([3, 6])).all()
m2 = ti.Matrix([[2, 3], [4, 5]])[:1, 1:]
assert (m2 == ti.Matrix([[3]])).all()
v3 = ti.Matrix([[1, 2], [3, 4]])[:, 1]
assert (v3 == ti.Vector([2, 4])).all()


@test_utils.test()
Expand All @@ -52,36 +54,34 @@ def foo2():
@test_utils.test()
def test_matrix_slice_with_variable():
@ti.kernel
def test_one_row_slice(
index: ti.i32) -> ti.types.matrix(2, 1, dtype=ti.i32):
def test_one_row_slice(index: ti.i32) -> ti.types.vector(2, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
return m[:, index]

@ti.kernel
def test_one_col_slice(
index: ti.i32) -> ti.types.matrix(1, 3, dtype=ti.i32):
def test_one_col_slice(index: ti.i32) -> ti.types.vector(3, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
return m[index, :]

r1 = test_one_row_slice(1)
assert (r1 == ti.Matrix([[2], [5]])).all()
assert (r1 == ti.Vector([2, 5])).all()
c1 = test_one_col_slice(1)
assert (c1 == ti.Matrix([[4, 5, 6]])).all()
assert (c1 == ti.Vector([4, 5, 6])).all()


@test_utils.test()
def test_matrix_slice_write():
@ti.kernel
def assign_row() -> ti.types.matrix(3, 4, ti.i32):
def assign_col() -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)])
row = ti.Matrix([[1, 2, 3, 4]])
mat[0, :] = row
col = ti.Vector([1, 2, 3])
mat[:, 0] = col
return mat

@ti.kernel
def assign_partial_row() -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)])
mat[1, 1:3] = ti.Matrix([[1, 2]])
mat[1, 1:3] = ti.Vector([1, 2])
return mat

@ti.kernel
Expand All @@ -91,8 +91,8 @@ def augassign_rows() -> ti.types.matrix(3, 4, ti.i32):
mat[:2, :] += rows
return mat

assert (assign_row() == ti.Matrix([[1, 2, 3, 4], [0, 0, 0, 0],
[0, 0, 0, 0]])).all()
assert (assign_col() == ti.Matrix([[1, 0, 0, 0], [2, 0, 0, 0],
[3, 0, 0, 0]])).all()
assert (assign_partial_row() == ti.Matrix([[0, 0, 0, 0], [0, 1, 2, 0],
[0, 0, 0, 0]])).all()
assert (augassign_rows() == ti.Matrix([[2, 3, 4, 5], [2, 3, 4, 5],
Expand All @@ -104,7 +104,7 @@ def test_matrix_slice_write_dynamic_index():
@ti.kernel
def foo(i: ti.i32) -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)])
mat[i, :] = ti.Matrix([[1, 2, 3, 4]])
mat[i, :] = ti.Vector([1, 2, 3, 4])
return mat

for i in range(3):
Expand Down

0 comments on commit 0060231

Please sign in to comment.