Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] add dlpack functions to npx (#18342)
Browse files Browse the repository at this point in the history
* add dlpack functions to npx

* improve tests

* further improve test

* fix comment
  • Loading branch information
szha authored May 17, 2020
1 parent 10b6b48 commit 7ab326c
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 2 deletions.
125 changes: 123 additions & 2 deletions python/mxnet/numpy_extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,24 @@

import ctypes
from .. util import is_np_array, is_np_shape
from .. base import _LIB, check_call, string_types, c_str_array
from .. base import _LIB, check_call, string_types, c_str_array, DLPackHandle
from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str
from ..numpy import ndarray

__all__ = ['save', 'load']
__all__ = ['save', 'load', 'to_dlpack_for_read', 'to_dlpack_for_write', 'from_dlpack']

PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')

def _dlpack_deleter(pycapsule):
pycapsule = ctypes.c_void_p(pycapsule)
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
ptr = ctypes.c_void_p(
ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor))
check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr))

_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)

def save(file, arr):
"""Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file.
Expand Down Expand Up @@ -119,3 +131,112 @@ def load(file):
return dict(
(py_str(names[i]), ndarray(NDArrayHandle(handles[i])))
for i in range(out_size.value))


def from_dlpack(dlpack):
"""Returns a np.ndarray backed by a dlpack tensor.
Parameters
----------
dlpack: PyCapsule (the pointer of DLManagedTensor)
input data
Returns
-------
np.ndarray
an ndarray backed by a dlpack tensor
Examples
--------
>>> x = mx.np.ones((2,3))
>>> y = mx.npx.to_dlpack_for_read(x)
>>> type(y)
<class 'PyCapsule'>
>>> z = mx.npx.from_dlpack(y)
>>> type(z)
<class 'mxnet.numpy.ndarray'>
>>> z
array([[1., 1., 1.],
[1., 1., 1.]])
>>> w = mx.npx.to_dlpack_for_write(x)
>>> type(w)
<class 'PyCapsule'>
>>> u = mx.npx.from_dlpack(w)
>>> u += 1
>>> x
array([[2., 2., 2.],
[2., 2., 2.]])
"""
handle = NDArrayHandle()
dlpack = ctypes.py_object(dlpack)
assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError(
'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
check_call(_LIB.MXNDArrayFromDLPackEx(dlpack_handle, False, ctypes.byref(handle)))
# Rename PyCapsule (DLPack)
ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
# delete the deleter of the old dlpack
ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
return ndarray(handle=handle)

def to_dlpack_for_read(data):
"""Returns a reference view of np.ndarray that represents as DLManagedTensor until
all previous write operations on the current array are finished.
Parameters
----------
data: np.ndarray
input data.
Returns
-------
PyCapsule (the pointer of DLManagedTensor)
a reference view of ndarray that represents as DLManagedTensor.
Examples
--------
>>> x = mx.np.ones((2,3))
>>> y = mx.npx.to_dlpack_for_read(x)
>>> type(y)
<class 'PyCapsule'>
>>> z = mx.npx.from_dlpack(y)
>>> z
array([[1., 1., 1.],
[1., 1., 1.]])
"""
data.wait_to_read()
dlpack = DLPackHandle()
check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)

def to_dlpack_for_write(data):
"""Returns a reference view of ndarray that represents as DLManagedTensor until
all previous read/write operations on the current array are finished.
Parameters
----------
data: np.ndarray
input data.
Returns
-------
PyCapsule (the pointer of DLManagedTensor)
a reference view of np.ndarray that represents as DLManagedTensor.
Examples
--------
>>> x = mx.np.ones((2,3))
>>> w = mx.npx.to_dlpack_for_write(x)
>>> type(w)
<class 'PyCapsule'>
>>> u = mx.npx.from_dlpack(w)
>>> u += 1
>>> x
array([[2., 2., 2.],
[2., 2., 2.]])
"""
check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
dlpack = DLPackHandle()
check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
26 changes: 26 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,3 +1343,29 @@ def test_np_ndarray_pickle():
a_load = pickle.load(f)
same(a.asnumpy(), a_load.asnumpy())

@pytest.mark.parametrize('dtype', [np.float32, np.int32])
@pytest.mark.parametrize('size', [
(3, 4, 5, 6),
(2, 10),
(15,),
()
])
@use_np
def test_dlpack(dtype, size):
a = mx.np.random.uniform(size=size)
a_np = a.copy()
a += 1

pack = mx.npx.to_dlpack_for_read(a)
b = mx.npx.from_dlpack(pack)

a_copy = a.copy()
pack2 = mx.npx.to_dlpack_for_write(a_copy)
c = mx.npx.from_dlpack(pack2)
c += 1

del a, pack, pack2

same(a_np+1, b)
same(a_np+2, c)
same(a_np+2, a_copy)

0 comments on commit 7ab326c

Please sign in to comment.