Skip to content

Commit

Permalink
Properly pack and unpack int4 arrays on CPU in PJRT.
Browse files Browse the repository at this point in the history
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.

PiperOrigin-RevId: 578692796
  • Loading branch information
reedwm authored and jax authors committed Nov 2, 2023
1 parent 5d28961 commit d41078f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
24 changes: 20 additions & 4 deletions jax/_src/public_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,26 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
if a.dtype == b.dtype == _dtypes.float0:
np.testing.assert_array_equal(a, b, err_msg=err_msg)
return
custom_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn,
_dtypes.float8_e5m2, _dtypes.bfloat16]
a = a.astype(np.float32) if a.dtype in custom_dtypes else a
b = b.astype(np.float32) if b.dtype in custom_dtypes else b

custom_float_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn,
_dtypes.float8_e5m2, _dtypes.bfloat16]
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)
# TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once
# ml_dtypes has a stable release with commit
# https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10.
# Remove these checks once JAX depends on a version on ml_dtypes with that
# commit.
if x.dtype == _dtypes.int4:
return x.astype(np.int8)
if x.dtype == _dtypes.uint4:
return x.astype(np.uint8)
return x

a = maybe_upcast(a)
b = maybe_upcast(b)

kw = {}
if atol: kw["atol"] = atol
if rtol: kw["rtol"] = rtol
Expand Down
19 changes: 19 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import safe_zip, NumpyComplexWarning

Expand Down Expand Up @@ -3526,6 +3527,24 @@ def testAstypeNone(self):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@unittest.skipIf(xla_extension_version < 210, 'jaxlib version too old')
def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
args_maker = lambda: [x]
np_op = lambda x: np.asarray(x).astype(jnp.int8)
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

# Test converting from int8 to int4
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8)
args_maker = lambda: [x]
np_op = lambda x: np.asarray(x).astype(jnp.int4)
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
shape=array_shapes,
dtype=all_dtypes,
Expand Down

0 comments on commit d41078f

Please sign in to comment.