Skip to content

Commit

Permalink
[lang] Improve misc error report related to passing ndarray to a kern…
Browse files Browse the repository at this point in the history
…el (#7966)

Issue: #6572

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at d1aa92d</samp>

Add new tests and error handling for ndarray feature. Improve `__repr__`
method for `TensorType` and error message for `NdarrayType`.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at d1aa92d</samp>

* Add a string representation method for tensor types
([link](https://github.com/taichi-dev/taichi/pull/7966/files?diff=unified&w=0#diff-78972e3ce6c462d977b9e713e447e8f3305899c8037a6992c37546aa0c4cb291L22-R25))
* Improve the error message for ndarray type validation
([link](https://github.com/taichi-dev/taichi/pull/7966/files?diff=unified&w=0#diff-06e0109e9fa5071f7d364306981845d410fa17425db48001e3ba69337b47c152L127-R127))
* Add tests for ndarray type mismatch scenarios
([link](https://github.com/taichi-dev/taichi/pull/7966/files?diff=unified&w=0#diff-ca3c8d1edb25b6a7f4affbb79b2e3e74f73b3757e5d465258ce42ea9eb09fbc0L6-R6),
[link](https://github.com/taichi-dev/taichi/pull/7966/files?diff=unified&w=0#diff-ca3c8d1edb25b6a7f4affbb79b2e3e74f73b3757e5d465258ce42ea9eb09fbc0R870-R910))
  • Loading branch information
ailzhang authored May 10, 2023
1 parent 4a24872 commit 477a0c7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
3 changes: 3 additions & 0 deletions python/taichi/types/compound_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def shape(self):
def element_type(self):
return self.ptr.element_type()

def __repr__(self):
return f"TensorType(shape={self.shape()}, dtype={self.element_type()})"


# TODO: maybe move MatrixType, StructType here to avoid the circular import?
def matrix(n, m, dtype):
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/types/ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def check_matched(self, ndarray_type: NdarrayTypeMetadata):
# Check ndim match
if self.ndim is not None and ndarray_type.shape is not None and self.ndim != len(ndarray_type.shape):
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required ndim={self.ndim}, but {ndarray_type.element_type} is provided"
f"Invalid argument into ti.types.ndarray() - required ndim={self.ndim}, but {len(ndarray_type.shape)}d ndarray with shape {ndarray_type.shape} is provided"
)

# Check needs_grad
Expand Down
43 changes: 42 additions & 1 deletion tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest
from taichi.lang import impl
from taichi.lang.exception import TaichiIndexError
from taichi.lang.exception import TaichiIndexError, TaichiTypeError
from taichi.lang.misc import get_host_arch_list
from taichi.lang.util import has_pytorch

Expand Down Expand Up @@ -867,3 +867,44 @@ def test_ndarray_fill():

x_mat.fill(mat2x2([[2.0, 4.0], [1.0, 3.0]]))
assert (x_mat[3, 3] == [[2.0, 4.0], [1.0, 3.0]]).all()


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_wrong_dtype():
@ti.kernel
def test2(arr: ti.types.ndarray(dtype=ti.f32)):
for I in ti.grouped(arr):
arr[I] = 2.0

tp_ivec3 = ti.types.vector(3, ti.i32)

y = ti.ndarray(tp_ivec3, shape=(12, 4))
with pytest.raises(TypeError, match=r"get TensorType\(shape=\(3,\), dtype=i32\)"):
test2(y)


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_bad_assign():
tp_ivec3 = ti.types.vector(3, ti.i32)

@ti.kernel
def test4(arr: ti.types.ndarray(dtype=tp_ivec3)):
for I in ti.grouped(arr):
arr[I] = [1, 2]

y = ti.ndarray(tp_ivec3, shape=(12, 4))
with pytest.raises(TaichiTypeError, match=r"cannot assign '\[Tensor \(2\) i32\]' to '\[Tensor \(3\) i32\]'"):
test4(y)


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_bad_ndim():
x = ti.ndarray(ti.f32, shape=(12, 13))

@ti.kernel
def test5(arr: ti.types.ndarray(ndim=1)):
for i, j in arr:
arr[i, j] = 0

with pytest.raises(ValueError, match=r"required ndim=1, but 2d ndarray with shape \(12, 13\) is provided"):
test5(x)

0 comments on commit 477a0c7

Please sign in to comment.