diff --git a/jax/_src/api.py b/jax/_src/api.py
index c1bb9ff72968..e4ead66236b9 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2415,8 +2415,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
     raise ValueError("`devices` argument to `device_put_replicated must be "
                      "a non-empty sequence.")
   def _device_put_replicated(x):
-    aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
-                              core.get_aval(x))
+    aval = core.unmapped_aval(len(devices), 0, core.get_aval(x))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:
diff --git a/jax/_src/callback.py b/jax/_src/callback.py
index 73c9ec8f231c..a91cc24f9cd9 100644
--- a/jax/_src/callback.py
+++ b/jax/_src/callback.py
@@ -159,8 +159,7 @@ def callback_batching_rule(
   new_args = [arg if dim is batching.not_mapped else
               batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
   batched_result_avals = tuple(
-      core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
-      for aval in result_avals)
+      core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)

   # For FFI calls we must update the layouts. We handle the output layouts
   # here, but the input layout updates depend on the vmap_method parameter.
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 4ac53378ca1b..d5becbcf75b1 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -2346,11 +2346,11 @@ def mapped_aval(size: AxisSize, axis: int | None,
   else:
     raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")

-def unmapped_aval(size: AxisSize, axis_name, axis: int | None,
+def unmapped_aval(size: AxisSize, axis: int | None,
                   aval: AbstractValue) -> AbstractValue:
   _, handler = aval_mapping_handlers.get(type(aval), (None, None))
   if handler is not None:
-    return handler(size, axis_name, axis, aval)
+    return handler(size, axis, aval)
   else:
     raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")

@@ -2366,11 +2366,10 @@ def _map_shaped_array(
                      weak_type=aval.weak_type, sharding=sharding)

 def _unmap_shaped_array(
-    size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
-  ) -> ShapedArray:
+    size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
   if axis is None: return aval
   elif type(axis) is int:
-    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name))
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, None))
                 if config.sharding_in_types.value else None)
     return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                        weak_type=aval.weak_type, sharding=sharding)
@@ -2383,7 +2382,7 @@ def _map_dshaped_array(
                       aval.weak_type)

 def _unmap_dshaped_array(
-    size: AxisSize, axis_name: AxisName, axis: int | None, aval: DShapedArray
+    size: AxisSize, axis: int | None, aval: DShapedArray
   ) -> DShapedArray:
   if axis is None: return aval
   elif type(axis) is int:
@@ -2396,7 +2395,7 @@
 aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
     DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
     ShapedArray: (_map_shaped_array, _unmap_shaped_array),
-    AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
+    AbstractToken: (lambda _, __, a: a, lambda _, __, a: a)
 }

 # When a mapped function is given no axis name, we generate a name object based
@@ -2777,7 +2776,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
     raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter")
   out_axes = params["out_axes"]

-  binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
+  binder_avals = [unmapped_aval(axis_size, in_axis, v.aval)
                   if in_axis is not None else v.aval
                   for v, in_axis in zip(call_jaxpr.invars, in_axes)]
   for binder_aval, in_aval in zip(binder_avals, in_avals):
@@ -2789,7 +2788,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
   _check_jaxpr(ctx_factory, call_jaxpr)

   mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
-  out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
+  out_avals = [unmapped_aval(axis_size, out_axis, aval)
               if out_axis is not None else aval
               for aval, out_axis in zip(mapped_out_avals, out_axes)]
   return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 4ad24fd9e928..46aad30ac97d 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -1000,7 +1000,7 @@ def out_axes_thunk():
   assert len(in_axes) == len(arg_cts)
   def unmap_zero(zero, in_axis):
     return (zero if in_axis is None else
-            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'], in_axis, zero.aval)))
+            Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
   arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
              arg_ct if in_axis is not None else
              arg_ct.sum(0)
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index f4658ec2be29..d5fb5f9856a2 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -223,7 +223,7 @@ def __init__(self, a): self.a = a
     if isinstance(d, RaggedAxis):
       raise NotImplementedError
     else:
-      new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore
+      new_avals.append(core.unmapped_aval(sz, d, a))  # type: ignore

   mentioned = {d for a in new_avals if type(a) is core.DShapedArray
                for d in a.shape if type(d) is Name}
@@ -750,7 +750,7 @@ def _batch_jaxpr2(
       handle_ragged(closed_jaxpr.in_avals, dim, aval)
       if isinstance(dim, RaggedAxis) else (dim, aval)
       for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
-  avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval)
+  avals_in2 = [core.unmapped_aval(axis_data.size, b, aval)
               if b is not not_mapped else aval
               for aval, b in unsafe_zip(avals_in, in_axes2)]
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
@@ -787,7 +787,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
   f, out_axes = _batch_jaxpr_inner(f, axis_data)
   f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
   f = _batch_jaxpr_outer(f, axis_data, in_axes)
-  avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped
+  avals_in = [core.unmapped_aval(axis_data.size, b, aval) if b is not not_mapped
              else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
   return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@@ -906,9 +906,9 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
     return x
   elif type(src) == type(dst) == int:
     aval = core.mapped_aval(sz, src, x.aval)
-    return Zero(core.unmapped_aval(sz, name, dst, aval))
+    return Zero(core.unmapped_aval(sz, dst, aval))
   elif src is not_mapped and dst is not not_mapped:
-    return Zero(core.unmapped_aval(sz, name, dst, x.aval))
+    return Zero(core.unmapped_aval(sz, dst, x.aval))
   elif dst is not_mapped and sum_match:
     return Zero(core.mapped_aval(sz, src, x.aval))
   else:
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index bd9ba286dc27..dd81d8d4a552 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -372,7 +372,7 @@ def const_out_axes_thunk():
                          out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
     del staged_params['out_axes_thunk']
     # The outputs of the staged-out call are Tracers with the new eqn as recipe.
-    out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
+    out_avals = [unmapped_aval(params['axis_size'], ax, a)
                 for ax, a in zip(staged_out_axes, out_avals_mapped)]
     out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                    for a in out_avals]
@@ -1956,7 +1956,7 @@ def process_map(self, map_primitive, f, tracers, params):
      raise ValueError("Ordered effects not supported for "
                       f"map primitives: {ordered_effects}")
    out_axes = params['out_axes_thunk']()
-    out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
+    out_avals = [core.unmapped_aval(axis_size, out_axis, a)
                 if out_axis is not None else a
                 for a, out_axis in zip(reduced_out_avals, out_axes)]
    source_info = source_info_util.current()
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index b918d3ed5726..49d95cb6f474 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -914,7 +914,7 @@ def _pmap_unmap_shaped_array(
 def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
                         aval: core.AbstractValue) -> core.AbstractValue:
   if not config.pmap_no_rank_reduction.value:
-    return core.unmapped_aval(size, axis_name, axis, aval)
+    return core.unmapped_aval(size, axis, aval)

   _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
   if handler is not None:
@@ -1350,7 +1350,7 @@ def _pmap_partial_eval_custom_params_updater(
   return new_params_known, new_params_staged

 def _pmap_partial_eval_custom_res_maker(params_known, aval):
-  return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)
+  return core.unmapped_aval(params_known['axis_size'], 0, aval)

 def _pmap_dce_rule(used_outputs, eqn):
   # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 7e54b68defc3..5a2a5608d732 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -520,7 +520,7 @@ def _stage_jaxpr_abstract_eval(*_, jaxpr):
   return jaxpr.out_avals, jaxpr.effects

 def _prepend_dim_to_aval(sz, aval):
-  return core.unmapped_aval(sz, None, 0, aval)
+  return core.unmapped_aval(sz, 0, aval)

 def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                         linear, unroll, _split_transpose):
@@ -704,7 +704,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
   extensive_res = _map(trace.new_instantiated_const, extensive_res)
   # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
   carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
-  ys_avals = [core.unmapped_aval(length, None, 0, y_aval)
+  ys_avals = [core.unmapped_aval(length, 0, y_aval)
              for y_aval in y_avals]
   out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                  for a in itertools.chain(carry_avals, ys_avals)]
@@ -1071,7 +1071,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):

   # Create residual variables.
   intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
-  ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a)
+  ext_avals = [core.unmapped_aval(eqn.params['length'], 0, a)
               for a in ext_avals_mapped]
   newvar = core.gensym()
   intensive_res = _map(newvar, intensive_avals)
@@ -1149,7 +1149,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
      jaxpr.in_avals, [num_consts, num_carry])
   carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
   x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
-  y_avals = [core.unmapped_aval(length, None, 0, a)
+  y_avals = [core.unmapped_aval(length, 0, a)
             for a in y_avals_mapped]

   if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py
index 276e40f09b57..1120c0bf44eb 100644
--- a/jax/_src/state/types.py
+++ b/jax/_src/state/types.py
@@ -367,9 +367,8 @@ def __hash__(self):
 def _map_ref(size, axis, ref_aval):
   return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))

-def _unmap_ref(size, axis_name, axis, ref_aval):
-  return AbstractRef(core.unmapped_aval(size, axis_name, axis,
-                                        ref_aval.inner_aval))
+def _unmap_ref(size, axis, ref_aval):
+  return AbstractRef(core.unmapped_aval(size, axis, ref_aval.inner_aval))

 core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)

diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 131b8a9645f2..660d2b57d4e0 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1613,7 +1613,7 @@ def fun(*res_and_args):
     res, args = split_list(res_and_args, [len(jaxpr.constvars)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
     return core.eval_jaxpr(jaxpr, res, *args)
-  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for v, w in zip(jaxpr.constvars, which)]
   in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
   jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
@@ -1740,7 +1740,7 @@ def staged(*args):
     res_, ins = split_list(args, [len(which)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
     return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
-  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
   avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
   jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
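
For reference, a minimal sketch (not part of the patch) of the call-site change the diff applies everywhere: core.unmapped_aval now takes only (size, axis, aval), with the old axis_name argument dropped. The toy aval below is a hypothetical example, not code from the patch.

    # Sketch only, assuming a default config where sharding_in_types is off.
    import jax.numpy as jnp
    from jax._src import core

    aval = core.ShapedArray((3, 4), jnp.float32)  # hypothetical example aval

    # Before this change: core.unmapped_aval(8, core.no_axis_name, 0, aval)
    # After this change, the axis-name argument is gone:
    unmapped = core.unmapped_aval(8, 0, aval)  # ShapedArray with shape (8, 3, 4)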