diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index cb479b38bf174..02ab13d6036ba 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -705,12 +705,28 @@ def func__(*args): array_shape = v.shape[ element_dim:] if is_soa else v.shape[:-element_dim] if is_numpy: - tmp = np.ascontiguousarray(v) - # Purpose: DO NOT GC |tmp|! - tmps.append(tmp) - launch_ctx.set_arg_external_array_with_shape( - actual_argument_slot, int(tmp.ctypes.data), - tmp.nbytes, array_shape) + if v.flags.c_contiguous: + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(v.ctypes.data), + v.nbytes, array_shape) + elif v.flags.f_contiguous: + # TODO: A better way that avoids copying is saving strides info. + tmp = np.ascontiguousarray(v) + # Purpose: DO NOT GC |tmp|! + tmps.append(tmp) + + def callback(original, updated): + np.copyto(original, np.asfortranarray(updated)) + + callbacks.append( + functools.partial(callback, v, tmp)) + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(tmp.ctypes.data), + tmp.nbytes, array_shape) + else: + raise ValueError( + "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) before passing it into taichi kernel." + ) elif is_torch: is_ndarray = False tmp, torch_callbacks = self.get_torch_callbacks( diff --git a/tests/python/test_numpy.py b/tests/python/test_numpy.py index 87b0d8984f9a4..694f1720a21f0 100644 --- a/tests/python/test_numpy.py +++ b/tests/python/test_numpy.py @@ -89,18 +89,21 @@ def test_numpy_2d_transpose(): def test_numpy(arr: ti.types.ndarray()): for i in ti.grouped(val): val[i] = arr[i] + arr[i] = 1 a = np.empty(shape=(n, m), dtype=np.int32) + b = a.transpose() for i in range(n): for j in range(m): a[i, j] = i * j + i * 4 - test_numpy(a.transpose()) + test_numpy(b) for i in range(n): for j in range(m): assert val[i, j] == i * j + j * 4 + assert a[i][j] == 1 @test_utils.test() @@ -234,3 +237,15 @@ def test(): assert all(y == [1.0, 2.0]) test() + + +@test_utils.test() +def test_numpy_view(): + @ti.kernel + def fill(img: ti.types.ndarray()): + img[0] = 1 + + a = np.zeros(shape=(2, 2))[:, 0] + with pytest.raises(ValueError, + match='Non contiguous numpy arrays are not supported'): + fill(a)