Skip to content

Commit

Permalink
[lang] Improve misc error report related to passing ndarray to a kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Ailing Zhang committed May 9, 2023
1 parent 975a73b commit d1aa92d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
3 changes: 3 additions & 0 deletions python/taichi/types/compound_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def shape(self):
def element_type(self):
return self.ptr.element_type()

def __repr__(self):
return f"TensorType(shape={self.shape()}, dtype={self.element_type()})"


# TODO: maybe move MatrixType, StructType here to avoid the circular import?
def matrix(n, m, dtype):
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/types/ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def check_matched(self, ndarray_type: NdarrayTypeMetadata):
# Check ndim match
if self.ndim is not None and ndarray_type.shape is not None and self.ndim != len(ndarray_type.shape):
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required ndim={self.ndim}, but {ndarray_type.element_type} is provided"
f"Invalid argument into ti.types.ndarray() - required ndim={self.ndim}, but {len(ndarray_type.shape)}d ndarray with shape {ndarray_type.shape} is provided"
)

# Check needs_grad
Expand Down
43 changes: 42 additions & 1 deletion tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest
from taichi.lang import impl
from taichi.lang.exception import TaichiIndexError
from taichi.lang.exception import TaichiIndexError, TaichiTypeError
from taichi.lang.misc import get_host_arch_list
from taichi.lang.util import has_pytorch

Expand Down Expand Up @@ -867,3 +867,44 @@ def test_ndarray_fill():

x_mat.fill(mat2x2([[2.0, 4.0], [1.0, 3.0]]))
assert (x_mat[3, 3] == [[2.0, 4.0], [1.0, 3.0]]).all()


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_wrong_dtype():
@ti.kernel
def test2(arr: ti.types.ndarray(dtype=ti.f32)):
for I in ti.grouped(arr):
arr[I] = 2.0

tp_ivec3 = ti.types.vector(3, ti.i32)

y = ti.ndarray(tp_ivec3, shape=(12, 4))
with pytest.raises(TypeError, match=r"get TensorType\(shape=\(3,\), dtype=i32\)"):
test2(y)


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_bad_assign():
tp_ivec3 = ti.types.vector(3, ti.i32)

@ti.kernel
def test4(arr: ti.types.ndarray(dtype=tp_ivec3)):
for I in ti.grouped(arr):
arr[I] = [1, 2]

y = ti.ndarray(tp_ivec3, shape=(12, 4))
with pytest.raises(TaichiTypeError, match=r"cannot assign '\[Tensor \(2\) i32\]' to '\[Tensor \(3\) i32\]'"):
test4(y)


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_bad_ndim():
x = ti.ndarray(ti.f32, shape=(12, 13))

@ti.kernel
def test5(arr: ti.types.ndarray(ndim=1)):
for i, j in arr:
arr[i, j] = 0

with pytest.raises(ValueError, match=r"required ndim=1, but 2d ndarray with shape \(12, 13\) is provided"):
test5(x)

0 comments on commit d1aa92d

Please sign in to comment.