Skip to content

Commit

Permalink
fix: handle more buffer types & fix nplike mixing (#1769)
Browse files Browse the repository at this point in the history
* fix: handle more buffer types, fix nplike mixing

* test: add test for Jacobian computation

* fix: don't regularise `NumpyArray`s

Remove assert - NumpyArray-only layouts might have Numpy nplike

* fix: use node parameters and identifiers

* test: skip Jacobian test for now.

* fix: restore `TypeError`

* refactor: use `_errors` module directly
.
  • Loading branch information
agoose77 authored Oct 5, 2022
1 parent 5cc8307 commit df377bd
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 22 deletions.
61 changes: 39 additions & 22 deletions src/awkward/_connect/jax/trees.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

from typing import Generic, NoReturn, TypeVar, Union
from typing import Generic, TypeVar, Union

import jax

Expand All @@ -22,23 +22,31 @@ def action(node, **kwargs):
if isinstance(node, ak.contents.NumpyArray):
data_ptrs.append(node.data)

layout.recursively_apply(action=action, return_array=False)
layout.recursively_apply(action=action, return_array=False, numpy_to_regular=False)

return data_ptrs


def replace_all_buffers(
layout: contents.Content | record.Record, buffers: list[numpy.ndarray]
layout: contents.Content | record.Record,
buffers: list,
nplike: nplikes.NumpyLike,
):
nplike = nplikes.nplike_of(*buffers)
jax = nplikes.Jax.instance()
numpy = nplikes.Numpy.instance()

def action(node, **kwargs):
if isinstance(node, ak.contents.NumpyArray):
return ak.contents.NumpyArray(
buffers.pop(0), layout.identifier, layout.parameters, nplike=nplike
)
buffer = buffers.pop(0)
# JAX might give us non-buffers, so ignore them
if not (numpy.is_own_array(buffer) or jax.is_own_array(buffer)):
return
else:
return ak.contents.NumpyArray(
buffer, node.identifier, node.parameters, nplike=nplike
)

return layout.recursively_apply(action=action)
return layout.recursively_apply(action=action, numpy_to_regular=False)


T = TypeVar(
Expand All @@ -59,8 +67,6 @@ def __init__(

@classmethod
def from_array_or_layout(cls, obj: T):
import numpy

is_highlevel = isinstance(obj, (highlevel.Array, highlevel.Record))
if is_highlevel:
layout = obj.layout
Expand All @@ -69,23 +75,32 @@ def from_array_or_layout(cls, obj: T):
else:
raise _errors.wrap_error(TypeError)

# First, make sure we're all JAX
layout = layout.to_backend("jax")

# Now pull out the Jax tracers / arrays
buffers = find_all_buffers(layout)

# Drop the references to the existing buffers by replacing them with empty buffers
# FIXME: This works-around the fact that AuxData should probably contain only a form and length,
# rather than the actual layout (which holds references to the buffers that we're returning)
# We use NumPy buffers here to ensure that we don't create any new tracers (they're just placeholders)
# This is particularly unpleasant, because we're mixing nplikes here (deliberately)
# We should use `to_buffers`.
import numpy as _numpy

def create_placeholder_like(array) -> numpy.ndarray:
data = numpy.empty(1, dtype=array.dtype)
def create_placeholder_like(array) -> _numpy.ndarray:
data = _numpy.empty(1, dtype=array.dtype)
strides = tuple(0 for _ in array.shape)
return numpy.lib.stride_tricks.as_strided(
return _numpy.lib.stride_tricks.as_strided(
data, array.shape, strides=strides, writeable=False
)

return buffers, AuxData(
layout=replace_all_buffers(
layout, [create_placeholder_like(n) for n in buffers]
layout,
[create_placeholder_like(n) for n in buffers],
nplike=nplikes.Numpy.instance(),
),
is_highlevel=is_highlevel,
behavior=ak._util.behavior_of(obj),
Expand All @@ -103,10 +118,12 @@ def behavior(self) -> dict | None:
def is_highlevel(self) -> bool:
return self._is_highlevel

def _validate_buffers(self, buffers: list) -> NoReturn:
def unflatten(self, buffers: tuple) -> T:
for buffer in buffers:
if buffer.dtype == np.dtype([("float0", "V")]):
raise ak._errors.wrap_error(
# Check that JAX isn't trying to give us float0 types
dtype = getattr(buffer, "dtype", None)
if dtype == np.dtype([("float0", "V")]):
raise _errors.wrap_error(
TypeError(
f"a buffer with the dtype {buffer.dtype} was encountered during unflattening. "
"JAX uses this dtype for the tangents of integer/boolean outputs; these cannot "
Expand All @@ -115,10 +132,10 @@ def _validate_buffers(self, buffers: list) -> NoReturn:
)
)

def unflatten(self, buffers: list) -> T:
self._validate_buffers(buffers)

layout = replace_all_buffers(self._layout, list(buffers))
# Replace the mixed NumPy-JAX layout leaves with the given buffers (and use the JAX nplike)
layout = replace_all_buffers(
self._layout, list(buffers), nplike=nplikes.Jax.instance()
)
return ak._util.wrap(
layout, behavior=self._behavior, highlevel=self._is_highlevel
)
Expand All @@ -136,7 +153,7 @@ def jax_flatten(
return result


def jax_unflatten(aux_data: AuxData, children: list[numpy.ndarray]) -> T:
def jax_unflatten(aux_data: AuxData, children: list[numpy.ndarray]) -> T | None:
return aux_data.unflatten(children)


Expand Down
25 changes: 25 additions & 0 deletions tests/test_1764-jax-jacobian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import pytest

import awkward as ak

jax = pytest.importorskip("jax")
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

ak.jax.register_and_check()


@pytest.mark.skip("Jacobian support not implemented")
def test():
array = ak.Array([[1, 2, 3], [4, 5, 6.0]])

def func(x):
return x * 2 - 1

array_np = ak.to_numpy(array)
jac_np = jax.jacfwd(func)(array_np)

jac = jax.jacfwd(func)(array)
assert jac.to_list() == jac_np.tolist()

0 comments on commit df377bd

Please sign in to comment.