Adds a minimal but viable implementation of string arrays (with numpy.dtypes.StringDType) in JAX. Currently this only supports creating a string array via either jax.numpy.asarray or jax.device_put, and reading it back with jax.device_get. #25918
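A minimal sketch of the workflow the description above enables (not taken from this PR's tests; assumes NumPy >= 2.0 for `numpy.dtypes.StringDType` and an available CPU device):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Variable-width string dtype; assumes NumPy >= 2.0.
words = np.array(["apple", "banana", "cherry"], dtype=np.dtypes.StringDType())

# Build a JAX string array on the CPU backend via jax.numpy.asarray.
with jax.default_device(jax.devices("cpu")[0]):
    jax_words = jnp.asarray(words)

# Read it back to the host as a NumPy array with jax.device_get.
print(jax.device_get(jax_words))  # expected: ['apple' 'banana' 'cherry']
```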

Merged: 2 commits, Feb 6, 2025
4 changes: 4 additions & 0 deletions .github/workflows/tsan.yaml
@@ -198,4 +198,8 @@ jobs:
--test_output=errors \
--local_test_jobs=32 \
--test_timeout=600 \
--config=resultstore \
--spawn_strategy=local \
--remote_cache=remotebuildexecution.googleapis.com \
--remote_instance_name=projects/tensorflow-testing/instances/default_instance \
//tests:cpu_tests
134 changes: 132 additions & 2 deletions docs/sharded-computation.ipynb
@@ -264,10 +264,52 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "UEObolTqw4pp"
},
"source": [
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
"\n",
"The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set it to `pinned_host` if you prefer to place the data in host memory.\n",
"\n",
"To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aKNeOHTJnqmS",
"outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pinned_host\n",
"device\n"
]
}
],
"source": [
"s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n",
"s_dev = s_host.with_memory_kind('device')\n",
"arr_host = jax.device_put(arr, s_host)\n",
"arr_dev = jax.device_put(arr, s_dev)\n",
"print(arr_host.sharding.memory_kind)\n",
"print(arr_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDHYnVqHwaST"
},
"source": [
"## 1. Automatic parallelism via `jit`\n",
"\n",
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
@@ -354,10 +396,98 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "Q4N5mrr9i_ki"
},
"source": [
"The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
"\n",
"### 1.1 Sharding transformation between memory types\n",
"\n",
"The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n",
"\n",
"#### Example 1: Pinned host to device memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PXu3MhafyRHo",
"outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"device\n"
]
}
],
"source": [
"f = jax.jit(lambda x: x, out_shardings=s_dev)\n",
"out_dev = f(arr_host)\n",
"print(out_dev)\n",
"print(out_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LuYFqpcBySiX"
},
"source": [
"#### Example 2: Device to pinned_host memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qLsgNlKfybRw",
"outputId": "a16448b9-7e39-408f-b200-505f65ad4464"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"pinned_host\n"
]
}
],
"source": [
"g = jax.jit(lambda x: x, out_shardings=s_host)\n",
"out_host = g(arr_dev)\n",
"print(out_host)\n",
"print(out_host.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7BGD31-owaSU"
},
"source": [
"## 2. Semi-automated sharding with constraints\n",
"\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constrains the distribution of intermediate values and outputs.\n",
67 changes: 67 additions & 0 deletions docs/sharded-computation.md
@@ -90,8 +90,31 @@ print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```

+++ {"id": "UEObolTqw4pp"}

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.

The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set it to `pinned_host` if you prefer to place the data in host memory.

To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: aKNeOHTJnqmS
outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2
---
s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(arr, s_host)
arr_dev = jax.device_put(arr, s_dev)
print(arr_host.sharding.memory_kind)
print(arr_dev.sharding.memory_kind)
```

+++ {"id": "jDHYnVqHwaST"}

## 1. Automatic parallelism via `jit`

Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
@@ -129,8 +152,52 @@ jax.debug.visualize_array_sharding(result)
print(result)
```

+++ {"id": "Q4N5mrr9i_ki"}

The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.

### 1.1 Sharding transformation between memory types

The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.

#### Example 1: Pinned host to device memory

In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: PXu3MhafyRHo
outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
---
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print(out_dev)
print(out_dev.sharding.memory_kind)
```

+++ {"id": "LuYFqpcBySiX"}

#### Example 2: Device to pinned_host memory

In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: qLsgNlKfybRw
outputId: a16448b9-7e39-408f-b200-505f65ad4464
---
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
print(out_host)
print(out_host.sharding.memory_kind)
```

+++ {"id": "7BGD31-owaSU"}

## 2. Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constrains the distribution of intermediate values and outputs.
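The doc text above references `jax.lax.with_sharding_constraint`; for readers of this page, here is a minimal sketch of typical usage (not part of this diff; it assumes 8 available devices arranged as the tutorial's 2x4 `('x', 'y')` mesh):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 available devices, arranged as a 2x4 ('x', 'y') mesh.
mesh = Mesh(np.array(jax.devices()[:8]).reshape(2, 4), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Constrain how the compiler distributes this intermediate value.
    return jax.lax.with_sharding_constraint(y, sharding)

out = f(jnp.arange(32.0).reshape(4, 8))
jax.debug.visualize_array_sharding(out)
```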
1 change: 1 addition & 0 deletions jax/BUILD
@@ -499,6 +499,7 @@ pytype_strict_library(
":traceback_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)

50 changes: 40 additions & 10 deletions jax/_src/api.py
@@ -99,6 +99,7 @@
zip, unsafe_zip = safe_zip, zip


@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
@@ -108,12 +109,18 @@ def _nan_check_posthook(fun, args, kwargs, output):

try:
dispatch.check_special(pjit.pjit_p.name, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
except dispatch.InternalFloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value
print("Invalid nan value encountered in the output of a C++-jit/pmap "
"function. Calling the de-optimized version.")
fun._cache_miss(*args, **kwargs)[0] # probably won't return
if hasattr(fun, '_fun'):
f = fun._fun
if getattr(f, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
# compiled_fun can only raise in this case
dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
raise AssertionError("Unreachable") from e
else:
# TODO(emilyaf): Shouldn't need this fallback.
raise

def _update_debug_special_global(_):
if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
@@ -1574,11 +1581,14 @@ def cache_miss(*args, **kwargs):

execute: Callable | None = None
with core.take_current_trace() as trace:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
try:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
except dispatch.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.')

out_tree, out_flat = p.out_tree, out
out_pytree_def = out_tree()
@@ -1629,6 +1639,7 @@ def cache_miss(*args, **kwargs):
_pmap_cache_clears.add(cpp_mapped_f)

pmap_f = wraps(fun)(cpp_mapped_f)
pmap_f._fun = fun

@api_boundary
def lower(*args, **kwargs):
@@ -1674,6 +1685,7 @@ def trace(*args, **kwargs):
_pmap_cache_clears = weakref.WeakSet() # type: ignore


@api_boundary
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]:
@@ -1878,6 +1890,7 @@ def fun(*tangents):

return apply_flat_fun_nokwargs(fun, io_tree, py_args)

@api_boundary
def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
if len(py_args_) != 1:
msg = (f"The function returned by `jax.vjp` applied to {name} was called "
@@ -1937,6 +1950,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
...
@api_boundary
def vjp(
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
@@ -2225,6 +2239,18 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return None


@lru_cache(maxsize=2048)
def _check_string_compatible_sharding(s):
"""Checks if target devices are compatible with string arrays."""
if isinstance(s, xc.Device) and s.device_kind == "cpu":
return
if (isinstance(s, Sharding)
and s._internal_device_list[0].device_kind == "cpu"):
return
raise TypeError(
"String arrays can only be sharded to CPU devices. Received"
f" unsupported device or sharding: {s}")

# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
# that to check if shardings are compatible with the input.
@lru_cache(maxsize=2048)
@@ -2235,6 +2261,10 @@ def _check_sharding(aval, s):
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
" `jax.Device`, `Layout` or a pytree of these values. Received"
f" invalid value: {s}")

if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype):
_check_string_compatible_sharding(s)

if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.get_token_aval()
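Based on the `_check_string_compatible_sharding` check added above, a hedged sketch of the expected user-visible behaviour (assumes NumPy >= 2.0; the failing branch only triggers if an accelerator backend is present):

```python
import numpy as np
import jax

words = np.array(["a", "bc"], dtype=np.dtypes.StringDType())

# CPU devices pass the new check, so device_put/device_get round-trips work.
cpu_words = jax.device_put(words, jax.devices("cpu")[0])
print(jax.device_get(cpu_words))

# Non-CPU placements are expected to raise the TypeError added in this diff.
accelerators = [d for d in jax.devices() if d.platform != "cpu"]
if accelerators:
    try:
        jax.device_put(words, accelerators[0])
    except TypeError as e:
        print(e)  # "String arrays can only be sharded to CPU devices. ..."
```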
7 changes: 5 additions & 2 deletions jax/_src/core.py
@@ -1472,11 +1472,14 @@ def lattice_join(x, y):

def valid_jaxtype(x) -> bool:
try:
abstractify(x)
aval = abstractify(x)
except TypeError:
return False
else:
return True
if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype):
return False
else:
return True

def check_valid_jaxtype(x):
if not valid_jaxtype(x):
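For reference, a hedged sketch of what the `valid_jaxtype` change above implies: string-dtype values are deliberately rejected as jaxtypes, so they are not accepted as traced arguments; only the `device_put`/`device_get`/`asarray` paths described in the PR apply. (This pokes at the internal `jax._src.core` module touched by the diff, which may change without notice.)

```python
import numpy as np
from jax._src import core

strings = np.array(["x", "y"], dtype=np.dtypes.StringDType())
floats = np.array([1.0, 2.0])

# Per the change above, string-dtype values are not valid jaxtypes, so they
# cannot flow through tracing transformations such as jax.jit.
print(core.valid_jaxtype(strings))  # expected: False
print(core.valid_jaxtype(floats))   # expected: True
```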