Skip to content

Commit

Permalink
[lang] [refactor] Refine implementation and tests for matrix slice (t…
Browse files Browse the repository at this point in the history
…aichi-dev#6373)

Issue: taichi-dev#4257, taichi-dev#5819

### Brief Summary

Let's refine the matrix slice feature a bit before supporting it with
the new MatrixType.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Oct 19, 2022
1 parent 51ee752 commit 2ecff6b
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 31 deletions.
19 changes: 17 additions & 2 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
from taichi.lang.any_array import AnyArray, AnyArrayAccess
from taichi.lang.enums import SNodeGradType
from taichi.lang.exception import (TaichiRuntimeError, TaichiSyntaxError,
TaichiTypeError)
from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError,
TaichiSyntaxError, TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
Expand Down Expand Up @@ -135,6 +135,21 @@ def begin_frontend_if(ast_builder, cond):
ast_builder.begin_frontend_if(Expr(cond).ptr)


@taichi_scope
def _calc_slice(index, default_stop):
start, stop, step = index.start or 0, index.stop or default_stop, index.step or 1

def check_validity(x):
# TODO(mzmzm): support variable in slice
if isinstance(x, Expr):
raise TaichiCompilationError(
"Taichi does not support variables in slice now, please use constant instead of it."
)

check_validity(start), check_validity(stop), check_validity(step)
return [_ for _ in range(start, stop, step)]


@taichi_scope
def subscript(ast_builder,
value,
Expand Down
19 changes: 2 additions & 17 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ def _subscript(self, is_global_mat, *indices, get_ref=False):
j = 0 if len(indices) == 1 else indices[1]
has_slice = False
if isinstance(i, slice):
i = self._calc_slice(i, 0)
i = impl._calc_slice(i, self.n)
has_slice = True
if isinstance(j, slice):
j = self._calc_slice(j, 1)
j = impl._calc_slice(j, self.m)
has_slice = True

if has_slice:
Expand Down Expand Up @@ -280,21 +280,6 @@ def _subscript(self, is_global_mat, *indices, get_ref=False):
self.dynamic_index_stride)
return self._get_entry(i, j)

def _calc_slice(self, index, dim):
start, stop, step = index.start or 0, index.stop or (
self.n if dim == 0 else self.m), index.step or 1

def helper(x):
# TODO(mzmzm): support variable in slice
if isinstance(x, expr.Expr):
raise TaichiCompilationError(
"Taichi does not support variables in slice now, please use constant instead of it."
)
return x

start, stop, step = helper(start), helper(stop), helper(step)
return [_ for _ in range(start, stop, step)]


class _MatrixEntriesInitializer:
def pyscope_or_ref(self, arr):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,52 @@


@test_utils.test()
def test_slice():
b = 3
def test_matrix_slice_read():
b = 6

@ti.kernel
def foo1() -> ti.types.vector(3, dtype=ti.i32):
c = ti.Vector([0, 1, 2, 3, 4, 5, 6])
return c[:5:2]
return c[:b:2]

@ti.kernel
def foo2() -> ti.types.matrix(2, 2, dtype=ti.i32):
a = ti.Matrix([[1, 2, 3], [4, 5, 6]])
return a[:, :b:2]
def foo2() -> ti.types.matrix(2, 3, dtype=ti.i32):
a = ti.Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
return a[1::, :]

v1 = foo1()
assert (v1 == ti.Vector([0, 2, 4])).all() == 1
assert (v1 == ti.Vector([0, 2, 4])).all()
m1 = foo2()
assert (m1 == ti.Matrix([[1, 3], [4, 6]])).all() == 1
assert (m1 == ti.Matrix([[4, 5, 6], [7, 8, 9]])).all()
v2 = ti.Vector([1, 2, 3, 4, 5, 6])[2::3]
assert (v2 == ti.Vector([3, 6])).all()
m2 = ti.Matrix([[2, 3], [4, 5]])[:1, 1:]
assert (m2 == ti.Matrix([[3]])).all()


@test_utils.test()
def test_matrix_slice_invalid():
@ti.kernel
def foo1(i: ti.i32):
a = ti.Vector([0, 1, 2, 3, 4, 5, 6])
b = a[i::2]

@ti.kernel
def foo2():
i = 2
a = ti.Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = a[:i:, :i]

with pytest.raises(ti.TaichiCompilationError,
match='Taichi does not support variables in slice now'):
foo1(1)
with pytest.raises(ti.TaichiCompilationError,
match='Taichi does not support variables in slice now'):
foo2()


@test_utils.test(dynamic_index=True)
def test_dyn():
def test_matrix_slice_with_variable():
@ti.kernel
def test_one_row_slice() -> ti.types.matrix(2, 1, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
Expand All @@ -39,13 +64,13 @@ def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32):
return m[index, :]

r1 = test_one_row_slice()
assert (r1 == ti.Matrix([[2], [5]])).all() == 1
assert (r1 == ti.Matrix([[2], [5]])).all()
c1 = test_one_col_slice()
assert (c1 == ti.Matrix([[4, 5, 6]])).all() == 1
assert (c1 == ti.Matrix([[4, 5, 6]])).all()


@test_utils.test(dynamic_index=False)
def test_no_dyn():
def test_matrix_slice_with_variable_invalid():
@ti.kernel
def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
Expand Down

0 comments on commit 2ecff6b

Please sign in to comment.