Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This PR adds JAX as a new nplike #1399

Merged
merged 15 commits into from
Apr 12, 2022
631 changes: 0 additions & 631 deletions dev/generate-cuda.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/awkward/_v2/_connect/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(self, name, error_context):
self.error_context = error_context


def import_cupy(name):
def import_cupy(name="Awkward Arrays with CUDA"):
if cupy is None:
raise ImportError(error_message.format(name))
return cupy
Expand Down
22 changes: 22 additions & 0 deletions src/awkward/_v2/_connect/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

try:
import jax

error_message = None

except ModuleNotFoundError:
jax = None
error_message = """to use {0}, you must install jax:

pip install jax jaxlib

or

conda install -c conda-forge jax jaxlib
"""


def import_jax(name="Awkward Arrays with JAX"):
if jax is None:
raise ImportError(error_message.format(name))
return jax
221 changes: 220 additions & 1 deletion src/awkward/_v2/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import threading
import traceback

from collections.abc import Sequence, Mapping
from collections.abc import Sequence, Mapping, Iterable

import awkward as ak

Expand All @@ -32,6 +32,7 @@
_backends = {
"cpu": ak.nplike.Numpy,
"cuda": ak.nplike.Cupy,
"jax": ak.nplike.Jax,
}


Expand Down Expand Up @@ -959,3 +960,221 @@ def expand_braces(text, seen=None):


expand_braces.regex = re.compile(r"\{[^\{\}]*\}")


def from_arraylib(array, regulararray, recordarray, highlevel, behavior):
np = ak.nplike.NumpyMetadata.instance()
numpy = ak.nplike.Numpy.instance()

def recurse(array, mask=None):
if regulararray and len(array.shape) > 1:
return ak._v2.contents.RegularArray(
recurse(array.reshape((-1,) + array.shape[2:])),
array.shape[1],
array.shape[0],
)

if len(array.shape) == 0:
array = ak._v2.contents.NumpyArray(array.reshape(1))

if array.dtype.kind == "S":
asbytes = array.reshape(-1)
itemsize = asbytes.dtype.itemsize
starts = numpy.arange(0, len(asbytes) * itemsize, itemsize, dtype=np.int64)
stops = starts + numpy.char.str_len(asbytes)
data = ak._v2.contents.ListArray(
ak._v2.index.Index64(starts),
ak._v2.index.Index64(stops),
ak._v2.contents.NumpyArray(
asbytes.view("u1"), parameters={"__array__": "byte"}, nplike=numpy
),
parameters={"__array__": "bytestring"},
)
for i in range(len(array.shape) - 1, 0, -1):
data = ak._v2.contents.RegularArray(
data, array.shape[i], array.shape[i - 1]
)

elif array.dtype.kind == "U":
asbytes = numpy.char.encode(array.reshape(-1), "utf-8", "surrogateescape")
itemsize = asbytes.dtype.itemsize
starts = numpy.arange(0, len(asbytes) * itemsize, itemsize, dtype=np.int64)
stops = starts + numpy.char.str_len(asbytes)
data = ak._v2.contents.ListArray(
ak._v2.index.Index64(starts),
ak._v2.index.Index64(stops),
ak._v2.contents.NumpyArray(
asbytes.view("u1"), parameters={"__array__": "char"}, nplike=numpy
),
parameters={"__array__": "string"},
)
for i in range(len(array.shape) - 1, 0, -1):
data = ak._v2.contents.RegularArray(
data, array.shape[i], array.shape[i - 1]
)

else:
data = ak._v2.contents.NumpyArray(array)

if mask is None:
return data

elif mask is False or (isinstance(mask, np.bool_) and not mask):
# NumPy's MaskedArray with mask == False is an UnmaskedArray
if len(array.shape) == 1:
return ak._v2.contents.UnmaskedArray(data)
else:

def attach(x):
if isinstance(x, ak._v2.contents.NumpyArray):
return ak._v2.contents.UnmaskedArray(x)
else:
return ak._v2.contents.RegularArray(
attach(x.content), x.size, len(x)
)

return attach(data.toRegularArray())

else:
# NumPy's MaskedArray is a ByteMaskedArray with valid_when=False
return ak._v2.contents.ByteMaskedArray(
ak._v2.index.Index8(mask), data, valid_when=False
)

return data

if isinstance(array, numpy.ma.MaskedArray):
mask = numpy.ma.getmask(array)
array = numpy.ma.getdata(array)
if isinstance(mask, np.ndarray) and len(mask.shape) > 1:
regulararray = True
mask = mask.reshape(-1)
else:
mask = None

if not recordarray or array.dtype.names is None:
layout = recurse(array, mask)

else:
contents = []
for name in array.dtype.names:
contents.append(recurse(array[name], mask))
layout = ak._v2.contents.RecordArray(contents, array.dtype.names)

return ak._v2._util.wrap(layout, behavior, highlevel)


def to_arraylib(module, array, allow_missing):
def _impl(array):
if isinstance(array, (bool, numbers.Number)):
return module.array(array)

elif isinstance(array, module.ndarray):
return array

elif isinstance(array, np.ndarray):
return module.asarray(array)

elif isinstance(array, ak._v2.highlevel.Array):
return _impl(array.layout)

elif isinstance(array, ak._v2.highlevel.Record):
raise ak._v2._util.error(
ValueError(f"{module.__name__} does not support record structures")
)

elif isinstance(array, ak._v2.highlevel.ArrayBuilder):
return _impl(array.snapshot().layout)

elif isinstance(array, ak.layout.ArrayBuilder):
return _impl(array.snapshot())

elif (
ak._v2.operations.describe.parameters(array).get("__array__")
== "bytestring"
or ak._v2.operations.describe.parameters(array).get("__array__") == "string"
):
raise ak._v2._util.error(
ValueError(f"{module.__name__} does not support arrays of strings")
)

elif isinstance(array, ak._v2.contents.EmptyArray):
return module.array([])

elif isinstance(array, ak._v2.contents.IndexedArray):
return _impl(array.project())

elif isinstance(array, ak._v2.contents.UnionArray):
contents = [_impl(array.project(i)) for i in range(len(array.contents))]
out = module.concatenate(contents)

tags = module.asarray(array.tags)
for tag, content in enumerate(contents):
mask = tags == tag
if type(out).__module__.startswith("jaxlib."):
out = out.at[mask].set(content)
else:
out[mask] = content
return out

elif isinstance(array, ak._v2.contents.UnmaskedArray):
return _impl(array.content)

elif isinstance(array, ak._v2.contents.IndexedOptionArray):
content = _impl(array.project())

mask0 = module.asarray(array.bytemask()).view(np.bool_)
if mask0.any():
raise ak._v2._util.error(
ValueError(f"{module.__name__} does not support masked arrays")
)
else:
return content

elif isinstance(array, ak._v2.contents.RegularArray):
out = _impl(array.content)
head, tail = out.shape[0], out.shape[1:]
shape = (head // array.size, array.size) + tail
return out[: shape[0] * array.size].reshape(shape)

elif isinstance(
array, (ak._v2.contents.ListArray, ak._v2.contents.ListOffsetArray)
):
return _impl(array.toRegularArray())

elif isinstance(array, ak._v2.contents.recordarray.RecordArray):
raise ak._v2._util.error(
ValueError(f"{module.__name__} does not support record structures")
)

elif isinstance(array, ak._v2.contents.NumpyArray):
return module.asarray(array.data)

elif isinstance(array, ak._v2.contents.Content):
raise ak._v2._util.error(
AssertionError(f"unrecognized Content type: {type(array)}")
)

elif isinstance(array, Iterable):
return module.asarray(array)

else:
raise ak._v2._util.error(
ValueError(f"cannot convert {array} into {type(module.array([]))}")
)

if module.__name__.startswith("jax") or module.__name__.startswith("cupy"):
return _impl(array)
elif module.__name__.startswith("numpy"):
layout = ak._v2.operations.convert.to_layout(
array, allow_record=True, allow_other=True
)
jpivarski marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(layout, (ak._v2.contents.Content, ak._v2.record.Record)):
return layout.to_numpy(allow_missing=allow_missing)
else:
return module.asarray(array)
else:
ak._v2._util.error(
ValueError(f"{module.__name__} is not supported by to_arraylib")
)
3 changes: 3 additions & 0 deletions src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ def is_contiguous(self):
# Alternatively, self._data.flags["C_CONTIGUOUS"], but the following assumes
# less of the nplike.

if type(self._data).__module__.startswith("jaxlib."):
return True

x = self._data.dtype.itemsize

for i in range(len(self._data.shape), 0, -1):
Expand Down
27 changes: 3 additions & 24 deletions src/awkward/_v2/operations/convert/ak_from_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import awkward as ak

np = ak.nplike.NumpyMetadata.instance()


def from_cupy(array, regulararray=False, highlevel=True, behavior=None):
"""
Expand Down Expand Up @@ -38,25 +36,6 @@ def from_cupy(array, regulararray=False, highlevel=True, behavior=None):
behavior=behavior,
),
):
return _impl(array, regulararray, highlevel, behavior)


def _impl(array, regulararray, highlevel, behavior):
def recurse(array):
if regulararray and len(array.shape) > 1:
return ak._v2.contents.RegularArray(
recurse(array.reshape((-1,) + array.shape[2:])),
array.shape[1],
array.shape[0],
)

if len(array.shape) == 0:
data = ak._v2.contents.NumpyArray(array.reshape(1))
else:
data = ak._v2.contents.NumpyArray(array)

return data

layout = recurse(array)

return ak._v2._util.wrap(layout, behavior, highlevel)
return ak._v2._util.from_arraylib(
array, regulararray, False, highlevel, behavior
)
Loading