Skip to content

Commit

Permalink
[bug][lang] Fix copyback for fortran contiguous numpy arrays
Browse files Browse the repository at this point in the history
fixes taichi-dev#6305

Taichi kernel assumes input external array are row major / c_contiguous
for now so this PR throws an error message when the input numpy array
isn't contiguous.

Note that we only supports inplace update for c_contiguous
numpy arrays but for historical reason support for f_contiguous array
was added via copying (`ascontiguousarray()`) as well, thus this PR tries to
preserve the support to avoid breaking old code.

Also the support for f_contiguous numpy array was halfly done since
`ascontiguousarray()` may return a new numpy array, Taichi kernels just
read/write on the copied array and don't copy the values back
to the original numpy array.

This PR fixes the bug mentioned above by adding a callback function to
copy values back, although copying behavior isn't efficient it
guarantees correctness so that we can improve it in the future.
  • Loading branch information
Ailing Zhang committed Oct 19, 2022
1 parent 2ecff6b commit 7291e97
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
28 changes: 22 additions & 6 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,12 +705,28 @@ def func__(*args):
array_shape = v.shape[
element_dim:] if is_soa else v.shape[:-element_dim]
if is_numpy:
tmp = np.ascontiguousarray(v)
# Purpose: DO NOT GC |tmp|!
tmps.append(tmp)
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.ctypes.data),
tmp.nbytes, array_shape)
if v.flags.c_contiguous:
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(v.ctypes.data),
v.nbytes, array_shape)
elif v.flags.f_contiguous:
# TODO: A better way that avoids copying is saving strides info.
tmp = np.ascontiguousarray(v)
# Purpose: DO NOT GC |tmp|!
tmps.append(tmp)

def callback(original, updated):
np.copyto(original, np.asfortranarray(updated))

callbacks.append(
functools.partial(callback, v, tmp))
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.ctypes.data),
tmp.nbytes, array_shape)
else:
raise ValueError(
"Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) before passing it into taichi kernel."
)
elif is_torch:
is_ndarray = False
tmp, torch_callbacks = self.get_torch_callbacks(
Expand Down
17 changes: 16 additions & 1 deletion tests/python/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,21 @@ def test_numpy_2d_transpose():
def test_numpy(arr: ti.types.ndarray()):
for i in ti.grouped(val):
val[i] = arr[i]
arr[i] = 1

a = np.empty(shape=(n, m), dtype=np.int32)
b = a.transpose()

for i in range(n):
for j in range(m):
a[i, j] = i * j + i * 4

test_numpy(a.transpose())
test_numpy(b)

for i in range(n):
for j in range(m):
assert val[i, j] == i * j + j * 4
assert a[i][j] == 1


@test_utils.test()
Expand Down Expand Up @@ -234,3 +237,15 @@ def test():
assert all(y == [1.0, 2.0])

test()


@test_utils.test()
def test_numpy_view():
@ti.kernel
def fill(img: ti.types.ndarray()):
img[0] = 1

a = np.zeros(shape=(2, 2))[:, 0]
with pytest.raises(ValueError,
match='Non contiguous numpy arrays are not supported'):
fill(a)

0 comments on commit 7291e97

Please sign in to comment.