diff --git a/jax/BUILD b/jax/BUILD index b5efb3d90dbd..f845654d7ee8 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -498,6 +498,7 @@ pytype_strict_library( ":traceback_util", ":typing", ":util", + "//jax/_src/lib", ] + py_deps("ml_dtypes") + py_deps("numpy"), ) diff --git a/jax/_src/api.py b/jax/_src/api.py index e4ead66236b9..56d91cdf8539 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -80,7 +80,6 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla - traceback_util.register_exclusion(__file__) _dtype = partial(dtypes.dtype, canonicalize=True) @@ -984,6 +983,14 @@ def vmap_f(*args, **kwargs): "to the positional arguments passed to the function, " f"but got {len(in_axes)=}, {len(args)=}") args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable) + + # StringDTtype arrays are not supported for vmap. + if any( + hasattr(x, "dtype") and dtypes.is_string_dtype(x.dtype) + for x in args_flat + ): + raise TypeError("StringDType arrays are not supported for vmap") + f = lu.wrap_init(fun) flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree) in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True) @@ -2194,6 +2201,19 @@ def _infer_src_sharding(src, x) -> Sharding | None: return None +@lru_cache(maxsize=2048) +def _check_string_compatible_sharding(s): + """Checks if sharding is compatible with StringDType arrays.""" + if isinstance(s, xc.Device) and s.device_kind == "cpu": + return + if isinstance(s, Sharding) and next(iter(s.device_set)).device_kind == "cpu": + return + raise TypeError( + "StringDType arrays can only be sharded to CPU devices. Received" + f" unsupported device or sharding: {s}" + ) + # TODO(jmudigonda): Add checks for Layout and TransferToMemoryKind. + # TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use # that to check if shardings are compatible with the input. @lru_cache(maxsize=2048) @@ -2204,6 +2224,10 @@ def _check_sharding(aval, s): "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," " `jax.Device`, `Layout` or a pytree of these values. Received" f" invalid value: {s}") + + if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype): + _check_string_compatible_sharding(s) + if isinstance(s, Sharding): if isinstance(aval, core.AbstractToken): aval = core.token_shaped_array diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 4893e833532c..3ab3fb1f2128 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -55,7 +55,6 @@ is_single_device_sharding) import numpy as np - JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration" BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration" @@ -279,12 +278,12 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool: type(d.aval.dtype) is core.bint) return False - def check_arg(arg: Any): - if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)): - raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid " - "JAX type.") - + if isinstance(arg, core.Tracer): + return + aval = core.abstractify(arg) + if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype): + raise TypeError("StringDType arrays are not supported by jit") def jaxpr_replicas(jaxpr: core.Jaxpr) -> int: """The number of replicas needed for a jaxpr. diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 04b07843a324..3b971499ab07 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -33,6 +33,7 @@ import numpy as np from jax._src import config +from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax._src.typing import Array, DType, DTypeLike from jax._src.util import set_module, StrictABC @@ -478,18 +479,41 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('complex64'), np.dtype('complex128'), ] -_jax_types = _bool_types + _int_types + _float_types + _complex_types -_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types} +_string_types: list[JAXType] = [] +try: + import numpy.dtypes as np_dtypes + if hasattr(np_dtypes, 'StringDType') and xla_extension_version >= 304: + _string_types: list[JAXType] = [np_dtypes.StringDType()] # type: ignore +except ImportError: + np_dtypes = None # type: ignore + +_jax_types = ( + _bool_types + _int_types + _float_types + _complex_types + _string_types +) +_jax_dtype_set = { + float0, + *_bool_types, + *_int_types, + *_float_types, + *_complex_types, + *_string_types, +} _dtype_kinds: dict[str, set] = { - 'bool': {*_bool_types}, - 'signed integer': {*_signed_types}, - 'unsigned integer': {*_unsigned_types}, - 'integral': {*_signed_types, *_unsigned_types}, - 'real floating': {*_float_types}, - 'complex floating': {*_complex_types}, - 'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types}, + 'bool': {*_bool_types}, + 'signed integer': {*_signed_types}, + 'unsigned integer': {*_unsigned_types}, + 'integral': {*_signed_types, *_unsigned_types}, + 'real floating': {*_float_types}, + 'complex floating': {*_complex_types}, + 'numeric': { + *_signed_types, + *_unsigned_types, + *_float_types, + *_complex_types, + }, + 'string': {*_string_types}, } @@ -855,8 +879,15 @@ def check_user_dtype_supported(dtype, fun_name=None): uint2, uint4 ] - if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: - msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" + if ( + np_dtype.kind not in 'biufcT' + and not is_custom_dtype + and not dtype == float0 + ): + msg = ( + 'JAX only supports number, bool and StringDType dtypes, got dtype' + f' {dtype}' + ) msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) if dtype is not None and np_dtype != canonicalize_dtype(np_dtype): @@ -934,3 +965,7 @@ def short_dtype_name(dtype) -> str: else: return (dtype.name.replace('float', 'f').replace('uint' , 'u') .replace('int' , 'i').replace('complex', 'c')) + + +def is_string_dtype(dtype: DTypeLike | None) -> bool: + return dtype in _dtype_kinds['string'] diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index dc689b619a6c..2496c9c13316 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -57,21 +57,32 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import ( + NamedSharding, + PartitionSpec as P, + SingleDeviceSharding, + canonicalize_sharding, +) from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, ) from jax._src.util import ( - NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, set_module, unzip2, - tuple_replace) + NumpyComplexWarning, + canonicalize_axis as _canonicalize_axis, + ceil_of_ratio, + partition_list, + safe_zip, + set_module, + tuple_replace, + unzip2, +) from jax.sharding import Sharding -from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding, - PartitionSpec as P, canonicalize_sharding) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum @@ -5564,6 +5575,34 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # Keep the output uncommitted. return jax.device_put(object) + # Do a device_put for string arrays since XLA does not support string dtype. + if ( + isinstance(object, np.ndarray) + and dtypes.is_string_dtype(object.dtype) + and xla_extension_version >= 304 + ): + if dtype is not None and dtype != object.dtype: + raise TypeError( + "Cannot make a non-StringDType array from a StringDType numpy array." + f" Got dtype: {dtype}" + ) + if ndmin > 0 and ndmin != object.ndim: + raise TypeError( + f"ndmin {ndmin} does not match ndims {object.ndim} of input array" + ) + return jax.device_put(x=object, device=device) + + if ( + isinstance(object, np.ndarray) + and dtypes.is_string_dtype(dtype) + and not dtypes.is_string_dtype(object.dtype) + ): + raise TypeError( + "Cannot make a StringDType array from a non-StringDType numpy array." + f" Got numpy array of dtype: {object.dtype}" + ) + + # For Python scalar literals, call coerce_to_array to catch any overflow # errors. We don't use dtypes.is_python_scalar because we don't want this # triggering for traced values. We do this here because it matters whether or diff --git a/tests/BUILD b/tests/BUILD index 61b1608f86ef..c8ad82a076ec 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1568,6 +1568,11 @@ jax_py_test( ], ) +jax_multiplatform_test( + name = "string_array_test", + srcs = ["string_array_test.py"], +) + jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], @@ -1603,6 +1608,7 @@ exports_files( "shard_map_test.py", "transfer_guard_test.py", "layout_test.py", + "string_array_test.py", ], visibility = jax_test_file_visibility, ) diff --git a/tests/api_test.py b/tests/api_test.py index 5bf9a1203cc1..b7d266cc761a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1611,7 +1611,6 @@ def f(x): with self.assertRaisesRegex(TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"): grad(f)("foo") - err_str = ("Error interpreting argument to .* as an abstract array. The problematic " "value is of type .* and was passed to the function at path x.") with self.assertRaisesRegex(TypeError, err_str): @@ -2033,7 +2032,6 @@ def f(inp1): result = jax.device_put(x, s2) result.block_until_ready() - @jax.default_matmul_precision("float32") def test_jacobian(self): R = self.rng().randn @@ -3291,7 +3289,7 @@ def f(x, y): return x + y def test_grad_object_array_error(self): x = np.array([1, 2, 3], dtype=object) - with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"): + with self.assertRaisesRegex(TypeError, ".*is not a valid JAX array type"): jax.grad(lambda x: x)(x) @jtu.thread_unsafe_test() # logging isn't thread-safe diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index f0b8f5367bb7..6d87ebf63576 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -26,6 +26,12 @@ import numpy as np +try: + import numpy.dtypes as np_dtypes +except ImportError: + np_dtypes = None # type: ignore + + import jax from jax import numpy as jnp from jax._src import earray @@ -110,6 +116,15 @@ 32: np.uint32, 64: np.uint64, } +# Not all types are promotable. For example, currently StringDType is not +# promotable. +if hasattr(np_dtypes, 'StringDType'): + _promotable_types = [ + x for x in dtypes._jax_types if not isinstance(x, np_dtypes.StringDType) + ] +else: + _promotable_types = dtypes._jax_types + def identity(x): """A named identity function for use in tests""" @@ -772,14 +787,15 @@ def g(x): class TestPromotionTables(jtu.JaxTestCase): @parameterized.named_parameters( - {"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype} - for jaxtype in dtypes._jax_types + dtypes._weak_types) + {'testcase_name': f'_{jaxtype=}', 'jaxtype': jaxtype} + for jaxtype in _promotable_types + dtypes._weak_types + ) def testJaxTypeFromType(self, jaxtype): self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype) @parameterized.named_parameters( {"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype} - for jaxtype in dtypes._jax_types + dtypes._weak_types) + for jaxtype in _promotable_types + dtypes._weak_types) def testJaxTypeFromVal(self, jaxtype): from jax._src.export import shape_poly if jaxtype is shape_poly._DimExpr: @@ -795,7 +811,7 @@ def testJaxTypeFromVal(self, jaxtype): @parameterized.named_parameters( {"testcase_name": f"_{dtype=}", "dtype": dtype} - for dtype in dtypes._jax_types) + for dtype in _promotable_types) def testJaxTypeWeak(self, dtype): jax_type = dtypes._jax_type(dtype, weak_type=True) if dtypes.issubdtype(jax_type, np.complexfloating): diff --git a/tests/export_test.py b/tests/export_test.py index b13cf3a623e6..cae3efacf709 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -48,6 +48,11 @@ import numpy as np +try: + import numpy.dtypes as np_dtypes +except ImportError: + np_dtypes = None # type: ignore + # ruff: noqa: F401 try: import flatbuffers @@ -407,7 +412,6 @@ def f(x1, x2): self.assertEqual(tree_util.tree_structure(res2), tree_util.tree_structure(res)) - def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c @@ -1002,6 +1006,12 @@ def f_jax(x): # x: bool[b] for dtype in dtypes._jax_types if dtype != np.dtype("bool") ]) def test_poly_numeric_dtypes(self, dtype=np.int32): + if hasattr(np_dtypes, "StringDType") and isinstance( + dtype, np_dtypes.StringDType + ): + self.skipTest( + "StringDType is not a numeric type" + ) # TODO(jmudigonda): revisit. if str(dtype) in {"float8_e4m3b11fnuz", "float8_e4m3fnuz", "float8_e5m2fnuz", @@ -1617,7 +1627,6 @@ def test_multi_platform_unknown_platform(self): platforms=("tpu", "cpu", "cuda", "other"))(x) self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other")) - def test_multi_platform_with_donation(self): f = jax.jit(jnp.sin, donate_argnums=(0,)) x = np.arange(3, dtype=np.float32) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 62b0fc994e60..a9852e6db26c 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3758,8 +3758,9 @@ def testArrayCopyVmap(self): self.assertIsNot(x, y) def testArrayUnsupportedDtypeError(self): - with self.assertRaisesRegex(TypeError, - "JAX only supports number and bool dtypes.*"): + with self.assertRaisesRegex( + TypeError, 'JAX only supports number, bool and StringDType dtypes.*' + ): jnp.array(3, [('a','