Skip to content

Commit

Permalink
Adds a minimal but viable implementation of string arrays (with `numpy.dtypes.StringDType`) in JAX.
Browse files Browse the repository at this point in the history
Currently this only supports creating a string array by means of either `jax.numpy.asarray` or `jax.device_put`, and reading it back with `jax.device_get`.

PiperOrigin-RevId: 716042460
  • Loading branch information
Google-ML-Automation committed Feb 3, 2025
1 parent 17d0b86 commit d77f1e0
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 14 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ pytype_strict_library(
":traceback_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)

Expand Down
16 changes: 16 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2223,6 +2223,18 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return None


@lru_cache(maxsize=2048)
def _check_string_compatible_sharding(s):
  """Checks if target devices are compatible with string arrays.

  The result is memoized by ``lru_cache``, so ``s`` must be hashable.

  Args:
    s: The placement target to validate: an ``xc.Device`` or a ``Sharding``.

  Returns:
    None when ``s`` is accepted.

  Raises:
    TypeError: If ``s`` is neither an ``xc.Device`` whose ``device_kind`` is
      ``"cpu"`` nor a ``Sharding`` whose first device (in ``device_set``
      iteration order) has ``device_kind == "cpu"``.
  """
  if isinstance(s, xc.Device) and s.device_kind == "cpu":
    return
  # Only a single device of the sharding is inspected.
  # NOTE(review): presumably every device in a sharding has the same kind —
  # confirm against the Sharding contract.
  if isinstance(s, Sharding) and next(iter(s.device_set)).device_kind == "cpu":
    return
  raise TypeError(
      "String arrays can only be sharded to CPU devices. Received"
      f" unsupported device or sharding: {s}"
  )

# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
# that to check if shardings are compatible with the input.
@lru_cache(maxsize=2048)
Expand All @@ -2233,6 +2245,10 @@ def _check_sharding(aval, s):
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
" `jax.Device`, `Layout` or a pytree of these values. Received"
f" invalid value: {s}")

if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype):
_check_string_compatible_sharding(s)

if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,11 +1479,14 @@ def lattice_join(x, y):

def valid_jaxtype(x) -> bool:
  """Returns True if ``x`` can be abstractified and is not a string array.

  Args:
    x: Any value.

  Returns:
    False if ``abstractify(x)`` raises ``TypeError``, or if the resulting
    abstract value has a ``dtype`` attribute that is a string dtype; True
    otherwise.
  """
  try:
    aval = abstractify(x)
  except TypeError:
    return False
  # Values with a string dtype abstractify successfully but are still
  # rejected here.
  return not (hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype))

def check_valid_jaxtype(x):
if not valid_jaxtype(x):
Expand Down
50 changes: 40 additions & 10 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import numpy as np

from jax._src import config
from jax._src.lib import xla_extension_version
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC

Expand Down Expand Up @@ -486,17 +487,36 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}


# We add the StringDType only to `_jax_dtype_set` but not to `_jax_types` and
# `_dtype_kinds`. Despite the similar-sounding name, `_jax_types` is only meant
# for the promotion-related logic, and StringDType does not participate in
# promotions at the moment. Similarly, `_dtype_kinds` is only meant for
# `jnp.isdtype`, and we want to be conservative and not allow StringDType to be
# used there.
_string_types: list[JAXType] = []
# Register the string dtype only when this NumPy exposes `StringDType` and the
# XLA extension version is at least 310.
if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 310:
  _string_types.append(np.dtypes.StringDType())  # type: ignore

_jax_dtype_set = {
    float0,
    *_bool_types,
    *_int_types,
    *_float_types,
    *_complex_types,
    *_string_types,
}

_jax_types = _bool_types + _int_types + _float_types + _complex_types

# Maps each dtype-kind name to the set of dtypes belonging to that kind.
# String dtypes are not listed under any kind.
_dtype_kinds: dict[str, set] = {
    'bool': {*_bool_types},
    'signed integer': {*_signed_types},
    'unsigned integer': {*_unsigned_types},
    'integral': {*_signed_types, *_unsigned_types},
    'real floating': {*_float_types},
    'complex floating': {*_complex_types},
    'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
}

Expand Down Expand Up @@ -870,8 +890,14 @@ def check_user_dtype_supported(dtype, fun_name=None):
uint2,
uint4
]
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
if (
np_dtype.kind not in 'biufcT'
and not is_custom_dtype
and not dtype == float0
):
msg = (
f'JAX only supports number, bool, and string dtypes, got dtype {dtype}'
)
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
if dtype is not None and np_dtype != canonicalize_dtype(np_dtype):
Expand Down Expand Up @@ -949,3 +975,7 @@ def short_dtype_name(dtype) -> str:
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))


def is_string_dtype(dtype: DTypeLike | None) -> bool:
  """Returns True if ``dtype`` is one of the dtypes in ``_string_types``.

  Membership uses ``in``, i.e. identity or equality against each registered
  string dtype. When ``_string_types`` is empty, this always returns False.

  Args:
    dtype: The dtype-like value to test; may be None.
  """
  return dtype in _string_types
28 changes: 28 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
Expand Down Expand Up @@ -5473,6 +5474,11 @@ def _supports_buffer_protocol(obj):
else:
return True

def _can_cast(from_dtype: DTypeLike, to_dtype: DTypeLike) -> bool:
  """Returns True if from_dtype can be cast to to_dtype.

  The only rule applied here concerns string dtypes: a cast is allowed when
  both dtypes are string dtypes or neither is. No other compatibility between
  the two dtypes is checked.

  Args:
    from_dtype: The dtype of the source object.
    to_dtype: The requested target dtype.
  """
  # Casting from or to StringDType is not supported.
  return dtypes.is_string_dtype(from_dtype) == dtypes.is_string_dtype(to_dtype)


@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
Expand Down Expand Up @@ -5548,6 +5554,18 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# check if the given dtype is compatible with JAX
dtypes.check_user_dtype_supported(dtype, "array")

# Check if the object's dtype is castable to the explicitly specified dtype
# arg (if any).
if (
dtype is not None
and hasattr(object, "dtype") # Is there a better check for this?
and not _can_cast(from_dtype=object.dtype, to_dtype=dtype)
):
raise TypeError(
f"Cannot make an array with dtype {dtype} from an object with dtype"
f" {object.dtype}."
)

# Here we make a judgment call: we only return a weakly-typed array when the
# input object itself is weakly typed. That ensures asarray(x) is a no-op
# whenever x is weak, but avoids introducing weak types with something like
Expand All @@ -5567,6 +5585,16 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

# Do a device_put for string arrays since XLA does not support string dtype.
if (isinstance(object, np.ndarray) and dtypes.is_string_dtype(object.dtype)
and xla_extension_version >= 310):
if ndmin > object.ndim:
raise TypeError(
f"ndmin {ndmin} cannot be greater than object's ndims"
f" {object.ndim} for string arrays."
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down
6 changes: 6 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,11 @@ jax_py_test(
],
)

jax_multiplatform_test(
name = "string_array_test",
srcs = ["string_array_test.py"],
)

jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
Expand Down Expand Up @@ -1642,6 +1647,7 @@ exports_files(
"shard_map_test.py",
"transfer_guard_test.py",
"layout_test.py",
"string_array_test.py",
],
visibility = jax_test_file_visibility,
)
Expand Down
5 changes: 3 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3758,8 +3758,9 @@ def testArrayCopyVmap(self):
self.assertIsNot(x, y)

def testArrayUnsupportedDtypeError(self):
  # A structured (record) dtype must be rejected with the "unsupported dtype"
  # TypeError.
  with self.assertRaisesRegex(
      TypeError, 'JAX only supports number, bool, and string dtypes.*'
  ):
    jnp.array(3, [('a','<i4'),('b','<i4')])

def testArrayFromInteger(self):
Expand Down
Loading

0 comments on commit d77f1e0

Please sign in to comment.