Skip to content

Commit

Permalink
[Misc] Strictly check ndim with external array (#7126)
Browse files Browse the repository at this point in the history
Issue: #6572 

* Check external ndarray total dim against type annotation
* Error out when element shape mismatch
* Refine error messages

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
turbo0628 and pre-commit-ci[bot] authored Jan 12, 2023
1 parent bef8e49 commit 05b8095
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 11 deletions.
6 changes: 3 additions & 3 deletions docs/lang/articles/get-started/accelerate_pytorch.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ The Taichi reference code is almost identical to its Python counterpart. And a g
```python
@ti.kernel
def taichi_forward_v0(
out: ti.types.ndarray(field_dim=3),
w: ti.types.ndarray(field_dim=3),
k: ti.types.ndarray(field_dim=3),
out: ti.types.ndarray(ndim=3),
w: ti.types.ndarray(ndim=3),
k: ti.types.ndarray(ndim=3),
eps: ti.f32):

for b, c, t in out:
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _ndarray_matrix_from_numpy(self, arr, as_vector):
raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
if tuple(self.arr.total_shape()) != tuple(arr.shape):
raise ValueError(
f"Mismatch shape: {tuple(self.arr.shape)} expected, but {tuple(arr.shape)} provided"
f"Mismatch shape: {tuple(self.arr.total_shape())} expected, but {tuple(arr.shape)} provided"
)
if not arr.flags.c_contiguous:
arr = np.ascontiguousarray(arr)
Expand Down
27 changes: 23 additions & 4 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,30 @@ def extract_arg(arg, anno):
shape = tuple(shape)
element_shape = ()
if isinstance(anno.dtype, MatrixType):
if len(shape) < anno.dtype.ndim:
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required element_dim={anno.dtype.ndim}, "
f"but the argument has only {len(shape)} dimensions")
if anno.ndim is not None:
if len(shape) != anno.dtype.ndim + anno.ndim:
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim} element_dim={anno.dtype.ndim}, "
f"but the argument has {len(shape)} dimensions")
else:
if len(shape) < anno.dtype.ndim:
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required element_dim={anno.dtype.ndim}, "
f"but the argument has only {len(shape)} dimensions"
)
element_shape = shape[-anno.dtype.ndim:]
anno_element_shape = anno.dtype.get_shape()
if None not in anno_element_shape and element_shape != anno_element_shape:
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required element_shape={anno_element_shape}, "
f"but the argument has element shape of {element_shape}"
)
elif anno.dtype is not None:
# User specified scalar dtype
if anno.ndim is not None and len(shape) != anno.ndim:
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim}, "
f"but the argument has {len(shape)} dimensions")
return to_taichi_type(
arg.dtype), len(shape), element_shape, Layout.AOS
if isinstance(anno, sparse_matrix_builder):
Expand Down
3 changes: 0 additions & 3 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,6 @@ def func(a: ti.types.ndarray(ti.types.vector(n=10, dtype=ti.i32))):
v = np.zeros((6, 10), dtype=np.int32)
func(v)
assert impl.get_runtime().get_num_compiled_functions() == 1
v = np.zeros((6, 11), dtype=np.int32)
func(v)
assert impl.get_runtime().get_num_compiled_functions() == 2


@test_utils.test(arch=supported_archs_taichi_ndarray)
Expand Down
41 changes: 41 additions & 0 deletions tests/python/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,44 @@ def fill(img: ti.types.ndarray()):
with pytest.raises(ValueError,
match='Non contiguous numpy arrays are not supported'):
fill(a)


@test_utils.test()
def test_numpy_ndarray_dim_check():
@ti.kernel
def add_one_mat(arr: ti.types.ndarray(dtype=ti.math.mat3, ndim=2)):
for i in ti.grouped(arr):
arr[i] = arr[i] + 1.0

@ti.kernel
def add_one_scalar(arr: ti.types.ndarray(dtype=ti.f32, ndim=2)):
for i in ti.grouped(arr):
arr[i] = arr[i] + 1.0

a = np.zeros(shape=(2, 2, 3, 3), dtype=np.float32)
b = np.zeros(shape=(2, 2, 2, 3), dtype=np.float32)
c = np.zeros(shape=(2, 2, 3), dtype=np.float32)
d = np.zeros(shape=(2, 2), dtype=np.float32)
add_one_mat(a)
add_one_scalar(d)
np.testing.assert_allclose(a, np.ones(shape=(2, 2, 3, 3),
dtype=np.float32))
np.testing.assert_allclose(d, np.ones(shape=(2, 2), dtype=np.float32))
with pytest.raises(
ValueError,
match=
r'Invalid argument into ti.types.ndarray\(\) - required element_shape=\(.*\), but the argument has element shape of \(.*\)'
):
add_one_mat(b)
with pytest.raises(
ValueError,
match=
r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2 element_dim=2, but the argument has 3 dimensions'
):
add_one_mat(c)
with pytest.raises(
ValueError,
match=
r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2, but the argument has 4 dimensions'
):
add_one_scalar(a)

0 comments on commit 05b8095

Please sign in to comment.