Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.

PiperOrigin-RevId: 691086496
dougalm authored and Google-ML-Automation committed Oct 29, 2024
1 parent c67cf51 commit c36e1f7
Showing 47 changed files with 1,411 additions and 2,631 deletions.
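In sketch form, the dispatch pattern this change moves to looks like the following. The wrapper below is hypothetical, not part of the commit; `core.take_current_trace` and `Primitive.bind_with_trace` are the calls visible in the `jax/_src/api.py` hunk further down.

```python
# Hypothetical helper showing context-driven dispatch: the active trace comes
# from the ambient context via take_current_trace, and the primitive is bound
# on that trace; the argument values are never inspected to choose a trace.
from jax._src import core

def dispatch_bind(prim, *args, **params):
  with core.take_current_trace() as trace:
    return prim.bind_with_trace(trace, args, params)

# e.g. dispatch_bind(jax.lax.add_p, 1.0, 2.0) should behave like
# jax.lax.add_p.bind(1.0, 2.0): outside any transformation the current trace
# is the eval trace; under vmap/jvp/etc. it is that transformation's trace.
```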
11 changes: 4 additions & 7 deletions jax/_src/ad_checkpoint.py
@@ -701,20 +701,17 @@ def transposed(*args_flat):
   transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
   return transposed_jaxpr, cell.in_cts_zero  # pytype: disable=attribute-error
 
-def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
-               jaxpr, **params):
+def remat_vmap(axis_data, args, dims, *, jaxpr, **params):
   assert not jaxpr.constvars
   jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
-      pe.close_jaxpr(jaxpr), axis_size, dims,
-      [batching.zero_if_mapped] * len(jaxpr.outvars),
-      axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
+      pe.close_jaxpr(jaxpr), axis_data, dims,
+      [batching.zero_if_mapped] * len(jaxpr.outvars))
   jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
   if consts:
     jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched)
   out_dims = [0 if b else None for b in out_batched]
   return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
-batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None)
-batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap
+batching.fancy_primitive_batchers[remat_p] = remat_vmap
 
 # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
 def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
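The hunk above shows the new shape of a "fancy" batching rule: a single `axis_data` argument replaces the old `(spmd_axis_name, axis_size, axis_name, main_type)` argument list. A hedged, self-contained sketch of the registration pattern follows; `my_p` and `my_vmap_rule` are hypothetical, while `fancy_primitive_batchers` and the `AxisData` constructor order come from this diff.

```python
# Hypothetical single-output primitive and vmap rule, mirroring the new
# remat_vmap signature above.
from jax._src import core
from jax._src.interpreters import batching

my_p = core.Primitive("my_prim")   # stand-in primitive for illustration only

def my_vmap_rule(axis_data, vals, dims, **params):
  # axis_data bundles the axis name, size, and spmd axis name in one object
  # (see the batching.AxisData(...) call in the api.py hunk below).
  (x,), (d,) = vals, dims
  return my_p.bind(x, **params), d   # already elementwise: keep the batch dim

batching.fancy_primitive_batchers[my_p] = my_vmap_rule
```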
29 changes: 12 additions & 17 deletions jax/_src/api.py
Expand Up @@ -34,7 +34,7 @@
import weakref

import numpy as np
from contextlib import contextmanager, ExitStack
from contextlib import contextmanager

from jax._src import linear_util as lu
from jax._src import stages
@@ -989,10 +989,10 @@ def vmap_f(*args, **kwargs):
     axis_size_ = (axis_size if axis_size is not None else
                   _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
     try:
+      axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
       out_flat = batching.batch(
-          flat_fun, axis_name, axis_size_, in_axes_flat,
-          lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
-          spmd_axis_name=spmd_axis_name
+          flat_fun, axis_data, in_axes_flat,
+          lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
       ).call_wrapped(*args_flat)
     except batching.SpecMatchError as e:
       out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
@@ -1546,16 +1546,13 @@ def cache_miss(*args, **kwargs):
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
     )
 
-    map_bind_continuation, top_trace, fun_, tracers, params = (
-        core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun,
-                                        *p.flat_args, **params))
     execute: Callable | None = None
-    if isinstance(top_trace, core.EvalTrace):
-      execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
-      out = map_bind_continuation(execute(*tracers))
-    else:
-      out = map_bind_continuation(
-          pxla.xla_pmap_p.process(top_trace, fun_, tracers, params))
+    with core.take_current_trace() as trace:
+      if isinstance(trace, core.EvalTrace):
+        execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
+        out = execute(*p.flat_args)
+      else:
+        out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
 
     out_tree, out_flat = p.out_tree, out
     out_pytree_def = out_tree()
@@ -1802,7 +1799,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
   >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
   ...
   >>> jax.jvp(f, (2.,), (3.,))
-  (Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
+  (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
   >>> y, f_jvp = jax.linearize(f, 2.)
   >>> print(y)
   3.2681944
@@ -2160,9 +2157,7 @@ def make_jaxpr(
   @wraps(fun)
   @api_boundary
   def make_jaxpr_f(*args, **kwargs):
-    with ExitStack() as stack:
-      for axis_name, size in axis_env or []:
-        stack.enter_context(core.extend_axis_env(axis_name, size, None))
+    with core.extend_axis_env_nd(axis_env or []):
       traced = jit(fun, static_argnums=static_argnums,
                    abstracted_axes=abstracted_axes).trace(*args, **kwargs)
     # `jit` converts tracers in consts to args but that breaks the semantics of
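The `make_jaxpr` hunk above replaces the per-axis `ExitStack` loop with a single call that installs the whole axis environment at once. A small sketch of that usage, with hypothetical axis names and sizes:

```python
# extend_axis_env_nd takes the whole list of (name, size) pairs in one call,
# instead of one extend_axis_env call per axis as in the old code.
from jax._src import core

axis_env = [("rows", 4), ("cols", 2)]   # hypothetical named axes
with core.extend_axis_env_nd(axis_env):
  pass  # tracing in here sees both named axes, mirroring make_jaxpr above
```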
1 change: 0 additions & 1 deletion jax/_src/callback.py
@@ -633,7 +633,6 @@ def io_callback(
   flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
   flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype),
                           flat_shape_dtypes)
-  flat_args = map(core.raise_as_much_as_possible, flat_args)
   out_flat = io_callback_p.bind(
       *flat_args,
       callback=_FlatCallback(callback, in_tree),
24 changes: 21 additions & 3 deletions jax/_src/config.py
@@ -217,7 +217,9 @@ def trace_context():
   return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
           compute_on_context_manager, enable_x64.value,
           numpy_rank_promotion.value, default_matmul_precision.value,
-          dynamic_shapes.value, numpy_dtype_promotion.value,
+          dynamic_shapes.value,
+          eager_constant_folding.value,
+          numpy_dtype_promotion.value,
           default_device.value, random_seed_offset.value,
           threefry_partitionable.value,
           threefry_gpu_kernel_lowering.value,
@@ -832,6 +834,7 @@ class _GlobalExtraJitContext(NamedTuple):
   numpy_dtype_promotion: str | None = None
   default_matmul_precision: Any | None = None
   dynamic_shapes: bool = False
+  eager_constant_folding: bool = False
   random_seed_offset: int = 0
   threefry_partitionable: bool = False
   threefry_gpu_kernel_lowering: bool = False
@@ -858,7 +861,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
   The initialization, which uses both config.py and core.py is done using
   `_update_thread_local_jit_state` in core.py to prevent circular imports.
   """
-  dynamic_trace_state: Any | None = None
+  trace_state: Any | None = None
   axis_env_state: Hashable = ()
   mesh_context_manager: Hashable = ()
   compute_on_context_manager: Hashable = ()
@@ -873,6 +876,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
   numpy_dtype_promotion: str | None = None
   default_matmul_precision: Any | None = None
   dynamic_shapes: bool | None = None
+  eager_constant_folding : bool | None = None
   random_seed_offset: int | None = None
   threefry_partitionable: bool | None = None
   threefry_gpu_kernel_lowering: bool | None = None
@@ -909,7 +913,6 @@ def update_thread_local_jit_state(**kw):
   tmp = context._replace(**kw)
   tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
 
-
 # TODO(b/214340779): remove flag when XLA:CPU is improved.
 jax2tf_associative_scan_reductions = bool_state(
     name='jax2tf_associative_scan_reductions',
@@ -1163,6 +1166,11 @@ def _update_jax_memories_thread_local(val):
     update_thread_local_hook=lambda val: update_thread_local_jit_state(
         sharding_in_types=val))
 
+data_dependent_tracing_fallback = bool_state(
+    name='jax_data_dependent_tracing_fallback',
+    default=False,
+    help=('When True, falls back to trace dispatch based on data dependence '
+          'instead of throwing an escaped tracer error.'))
 
 softmax_custom_jvp = bool_state(
     name='jax_softmax_custom_jvp',
@@ -1530,6 +1538,16 @@ def _update_disable_jit_thread_local(val):
     update_thread_local_hook=lambda val: \
       update_thread_local_jit_state(dynamic_shapes=val))
 
+# This is for stackless backward compat with e.g. equinox
+eager_constant_folding = bool_state(
+    name='eager_constant_folding',
+    default=False,
+    help=('Attempt constant folding during staging.'),
+    update_global_hook=lambda val: \
+      _update_global_jit_state(eager_constant_folding=val),
+    update_thread_local_hook=lambda val: \
+      update_thread_local_jit_state(eager_constant_folding=val))
+
 # This flag is temporary during rollout of the remat barrier.
 # TODO(parkers): Remove if there are no complaints.
 remat_opt_barrier = bool_state(
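The two new options above are ordinary `bool_state` flags, so, assuming the usual config plumbing, they can be flipped like any other JAX option; the flag names below are taken verbatim from the definitions above and both default to False.

```python
# Opt-in knobs added by this change (names from the config.py hunk above).
import jax

# Fall back to data-dependent trace dispatch instead of raising an
# escaped-tracer error.
jax.config.update("jax_data_dependent_tracing_fallback", True)

# Attempt constant folding during staging (backward compat, e.g. for equinox).
jax.config.update("eager_constant_folding", True)
```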
(Diffs for the remaining 43 of the 47 changed files are not shown here.)
