Commit 4fa7e91 ("wip")
mattjj committed Feb 9, 2025 · 1 parent b3df4d5
Showing 5 changed files with 58 additions and 43 deletions.
23 changes: 15 additions & 8 deletions jax/_src/interpreters/ad.py
@@ -87,6 +87,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(None) # TODO(necula): fill-in the debug info (use JAX_USE_DIRECT_LINEARIZATION=1)
+    tangent_trace.tag = _tag
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
@@ -100,11 +101,15 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
-  residual_avals = map(get_aval, consts)
+  which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
+                getattr(c._trace, 'tag', None) is _tag) for c in consts]
+  jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
+  res, env = partition_list(which_env, consts)
+  residual_avals = map(get_aval, res)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
-  _store.store((residual_avals, nzs_out, jaxpr))
-  return tuple(consts) + tuple(out_primals)
+  _store.store((residual_avals, nzs_out, jaxpr, env))
+  return *res, *out_primals
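Context for the new partitioning: constants of the tangent jaxpr that are tracers tagged by this very linearization are environment values rather than true residuals, so they are split out and reported separately. A standalone sketch of the partition_list contract relied on here (the real helper lives in jax._src.util; the data values are hypothetical):

def partition_list(which, xs):
  # Entries whose mask is False land in the first list, True in the second.
  out = ([], [])
  for w, x in zip(which, xs):
    out[bool(w)].append(x)
  return out

consts = ['residual0', 'env_tracer', 'residual1']
which_env = [False, True, False]
res, env = partition_list(which_env, consts)
assert res == ['residual0', 'residual1'] and env == ['env_tracer']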

@lu.transformation2
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
@@ -153,6 +158,7 @@ def _linearize_jaxpr(
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
+  tangent_trace.tag = lin_trace.tag

def new_arg(trace, primal_aval, nz):
primal = primal_trace.new_arg(primal_aval)
@@ -193,6 +199,7 @@ def direct_linearize(traceable: lu.WrappedFun,
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
+    tangent_trace.tag = linearize_trace.tag
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with (core.set_current_trace(linearize_trace, check_leaks=True),
@@ -686,15 +693,15 @@ def process_call(self, call_primitive, f, tracers, params):
if isinstance(call_primitive, core.MapPrimitive):
@as_hashable_function(closure=(linearize_outs_thunk))
def new_out_axes_thunk():
-        residual_avals, _, _ = linearize_outs_thunk()
+        residual_avals, _, _, _, _ = linearize_outs_thunk()
out_axes = params['out_axes_thunk']()
return (*(0 for _ in residual_avals), *out_axes)
primal_params = dict(params, out_axes_thunk=new_out_axes_thunk)
else:
primal_params = params

all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params)
-    residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
+    residual_avals, nzs_out, lin_jaxpr, env = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = all_primal_results[:num_residuals]
primals_out = all_primal_results[num_residuals:]
@@ -719,13 +726,13 @@ def new_out_axes_thunk():
new_params = update_params(params, residual_avals, nzs_in) if update_params else params

def f_tangent(*args):
-      residuals = args[:num_residuals]
+      consts = args[:num_residuals]
nz_tangents = args[num_residuals:]
-      return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
+      return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents)

nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
-        self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params)
+        self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *env, *nz_tangents_in), new_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
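For orientation, the calling convention at the tangent bind: after pe.move_envvars the linearized jaxpr's invars are [*envvars, *tangent_invars], so the bind passes (*residuals, *env, *nz_tangents_in) and f_tangent peels the residuals back off as the jaxpr's consts. A toy check of that layout (all values hypothetical):

num_residuals = 2
args = ('r0', 'r1', 'e0', 't0', 't1')      # (*residuals, *env, *nz_tangents)
consts, rest = args[:num_residuals], args[num_residuals:]
assert consts == ('r0', 'r1')              # bound as the jaxpr's consts
assert rest == ('e0', 't0', 't1')          # fills [*envvars, *tangent_invars]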
7 changes: 6 additions & 1 deletion jax/_src/interpreters/partial_eval.py
@@ -834,6 +834,11 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals

+@weakref_lru_cache
+def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
+  constvars, envvars = partition_list(which, jaxpr.constvars)
+  return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])
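The new move_envvars rewrites only the binder structure: masked constvars stop being implicitly closed-over constants and become leading positional inputs, leaving the equations untouched. A sketch of the same transformation on a toy record (ToyJaxpr is a hypothetical stand-in for the real Jaxpr class):

from dataclasses import dataclass, replace as dc_replace

@dataclass(frozen=True)
class ToyJaxpr:
  constvars: tuple
  invars: tuple

def toy_move_envvars(jaxpr, which):
  constvars = tuple(v for v, w in zip(jaxpr.constvars, which) if not w)
  envvars = tuple(v for v, w in zip(jaxpr.constvars, which) if w)
  return dc_replace(jaxpr, constvars=constvars, invars=(*envvars, *jaxpr.invars))

j = ToyJaxpr(constvars=('c', 'e'), invars=('x',))
assert toy_move_envvars(j, (False, True)) == ToyJaxpr(constvars=('c',), invars=('e', 'x'))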

@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
@@ -1824,7 +1829,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:


class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
__slots__ = ("frame", "tag")

def __init__(self, debug_info: core.DebugInfo | None):
self.frame = JaxprStackFrame(debug_info)
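Adding "tag" to __slots__ lets LinearizeTrace stamp tangent traces while leaving the slot unset elsewhere; that is why ad.py probes it with getattr(c._trace, 'tag', None). A small illustration of unset-slot behavior (the class below is a stand-in, not the real trace):

class Trace:
  __slots__ = ('frame', 'tag')
  def __init__(self, frame):
    self.frame = frame               # 'tag' stays unset unless assigned later

t = Trace(frame=object())
assert getattr(t, 'tag', None) is None     # unset slot: default, no AttributeError
t.tag = 'linearize-tag'
assert getattr(t, 'tag', None) == 'linearize-tag'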
50 changes: 26 additions & 24 deletions jax/experimental/shard_map.py
@@ -899,7 +899,7 @@ def to_val_rep_pair(self, val):
if isinstance(val, ShardMapTracer):
return val.val, val.rep
elif isinstance(val, Tracer):
raise Exception("Shouldn't have any non-shard_map tracers")
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
else:
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
return val_, None
@@ -1589,7 +1589,7 @@ def _shard_map_linearize(trace, shard_map_p, f, tracers, mesh, in_names,

@as_hashable_function(closure=(linearize_outs_thunk))
def primal_out_names_thunk():
-    residual_avals, _, _ = linearize_outs_thunk()
+    residual_avals, _, _, _ = linearize_outs_thunk()
out_names = out_names_thunk()
# This is incorrect so we set `check_rep=False` as we do in the JVP rule.
return (*({0: all_names} for _ in residual_avals), *out_names)
@@ -1598,17 +1598,19 @@ def primal_out_names_thunk():
out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
all_primal_results = shard_map_p.bind_with_trace(
-      trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
-  residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
-  num_residuals = len(residual_avals)
-  residuals = all_primal_results[:num_residuals]
-  primals_out = all_primal_results[num_residuals:]
+      trace.parent_trace, (f_primal, *primals), primal_params)
+  residual_avals, nzs_out, lin_jaxpr, env = linearize_outs_thunk()
+  num_res = len(residual_avals)
+  residuals = all_primal_results[:num_res]
+  primals_out = all_primal_results[num_res:]
args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals]
-  lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
+  with core.extend_axis_env_nd(mesh.shape.items()):
+    lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
out_names = out_names_thunk()
-  new_in_names = (*({0: all_names} for _ in residual_avals),
+  new_in_names = (*({0: all_names} for _ in range(num_res)),
+                  *({} for _ in range(len(env))),
*(ax for ax, nz in zip(in_names, nzs_in) if nz))
-  new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
+  new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz)
@as_hashable_function(closure=(new_out_names))
def tangent_out_names_thunk():
return new_out_names
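For reference, each in_names entry maps positional axes of one argument to mesh axis names: {0: all_names} shards axis 0 of a stacked residual over every mesh axis, while the {} entries added for the env values mention no mesh axes, meaning those inputs arrive replicated. An illustration of the convention with hypothetical mesh axes:

residual_names = {0: ('x', 'y')}   # axis 0 split across mesh axes 'x' and 'y'
env_names = {}                     # no axis mesh-mapped: value is replicated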
@@ -1618,13 +1620,11 @@ def tangent_out_names_thunk():
rewrite=rewrite, auto=auto)

def f_tangent(*args):
-    residuals = args[:num_residuals]
-    nz_tangents = args[num_residuals:]
-    return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents)
+    return core.eval_jaxpr(lin_jaxpr, (), *args)

nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
-      (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), tangent_params)
+      (lu.wrap_init(f_tangent), *residuals, *env, *nz_tangents_in), tangent_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
@@ -1634,13 +1634,13 @@ def f_tangent(*args):
@lu.transformation2
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
ans = f(*args, **kwargs)
-  residual_avals, _, _ = linearize_outs_thunk()
+  residual_avals, _, _, _ = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = ans[:num_residuals]
primals = ans[num_residuals:]
-  residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
-                    for x in residuals)
-  return residuals + primals
+  residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
+               for x in residuals]
+  return *residuals, *primals
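The broadcast exists because a rank-0 residual has no axis on which to hang the {0: all_names} sharding spec, so scalars are promoted to shape (1,). A minimal sketch of the promotion step on its own:

import jax
import jax.numpy as jnp

x = jnp.float32(3.0)               # rank-0 residual
y = jax.lax.broadcast(x, (1,))     # lax.broadcast prepends the given sizes
assert y.shape == (1,) and y[0] == x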

@lu.transformation2
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
@@ -1679,9 +1679,9 @@ def _shard_map_transpose(out_cts, *args,
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
-             else x if rewrite or dtypes.dtype(x) == dtypes.float0
-             else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
-             for ns, x in zip(out_names, out_cts)]
+              else x if rewrite or dtypes.dtype(x) == dtypes.float0
+              else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
+              for ns, x in zip(out_names, out_cts)]
args = tuple(x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
for ns, x in zip(in_names, args))
@@ -1692,13 +1692,14 @@ def fun_trans_callable(out_cts, args):
jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False)
res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res)
+    # TODO TODO transpose only wrt invars!
out = ad.backward_pass(
jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
)
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
-           else x if rewrite
-           else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
-           for ns, x in zip(in_names, out)]
+            else x if rewrite
+            else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
+            for ns, x in zip(in_names, out)]
return out

fun_trans = lu.wrap_init(fun_trans_callable,
@@ -1832,6 +1833,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,

# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
+  del trace
spmd_names = core.get_axis_env().spmd_axis_names
return tuple(name for name in mesh.axis_names if name not in spmd_names)
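The added del trace keeps the parameter for call-site compatibility while making explicit that the body ignores it. The same idiom in a toy function (names illustrative, not the real API):

def all_names_except(axis_names, spmd_names, trace=None):
  del trace  # accepted but intentionally unused
  return tuple(n for n in axis_names if n not in spmd_names)

assert all_names_except(('x', 'y'), {'y'}) == ('x',)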

3 changes: 2 additions & 1 deletion tests/core_test.py
@@ -45,12 +45,13 @@
def call(f, *args):
return jit(f)(*args)

+@util.curry
def core_call(f, *args):
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out = core.call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
-core_call = util.curry(core_call)
+# call = core_call
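Both spellings bind the same curried function; the decorator just applies the rebinding at definition time. A sketch modeled on jax._src.util.curry, assuming it amounts to partial application of functools.partial:

import functools

def curry(f):
  return functools.partial(functools.partial, f)

@curry
def call_a(f, *args):
  return f(*args)

def call_b(f, *args):
  return f(*args)
call_b = curry(call_b)

assert call_a(lambda x: x + 1)(41) == call_b(lambda x: x + 1)(41) == 42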

@util.curry
def core_closed_call(f, *args):
18 changes: 9 additions & 9 deletions tests/shard_map_test.py
@@ -2615,7 +2615,7 @@ def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]:
name, *case = sample_one(rng, make_gen())
if name not in seen:
seen.add(name)
-      yield name, *case
+      yield case
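This change pairs with the named_parameters -> parameters switch in the tests below: the sampler now yields bare argument tuples with no leading test-name string. A minimal absl illustration of the two decorators (the Demo class is hypothetical):

from absl.testing import absltest, parameterized

class Demo(parameterized.TestCase):

  @parameterized.parameters((1, 2, 3), (2, 3, 5))       # bare argument tuples
  def test_add(self, a, b, want):
    self.assertEqual(a + b, want)

  @parameterized.named_parameters(('small', 1, 2, 3))   # name string comes first
  def test_add_named(self, a, b, want):
    self.assertEqual(a + b, want)

if __name__ == '__main__':
  absltest.main()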

# To sample one test spec, we run the generator, getting back sequences of
# options from it and sending in our choices from those options until finally a
@@ -2730,7 +2730,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
def make_mesh(mesh_shape):
return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))

-  @parameterized.named_parameters(
+  @parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@@ -2739,7 +2739,7 @@ def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
expected = ref(fun, mesh, in_specs, out_specs)(*args)
self.assertAllClose(expected, out, check_dtypes=False)

-  @parameterized.named_parameters(
+  @parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@@ -2748,9 +2748,9 @@ def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
expected = ref(fun, mesh, in_specs, out_specs)(*args)
self.assertAllClose(expected, out, check_dtypes=False)

-  @parameterized.named_parameters(
-      (name + f'_check_rep={check_rep}', *params, check_rep)
-      for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
+  @parameterized.parameters(
+      (*params, check_rep)
+      for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
@@ -2762,7 +2762,7 @@ def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
f = jax.jit(f)
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)

-  @parameterized.named_parameters(
+  @parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
@@ -2781,7 +2781,7 @@ def g(*args):
return g(*args)
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)

-  @parameterized.named_parameters(
+  @parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
@@ -2804,7 +2804,7 @@ def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)

-  @parameterized.named_parameters(
+  @parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
