Skip to content

Commit

Permalink
Update the trace_to_mlir split for #520
Browse files Browse the repository at this point in the history
  • Loading branch information
dime10 committed Feb 23, 2024
1 parent 1c555ba commit c71c322
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 26 deletions.
13 changes: 7 additions & 6 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
ClosedJaxpr,
DynamicJaxprTrace,
DynamicJaxprTracer,
DynshapedJaxpr,
PyTreeDef,
PyTreeRegistry,
ShapedArray,
Expand All @@ -70,7 +71,6 @@
convert_element_type,
deduce_avals,
eval_jaxpr,
jaxpr_remove_implicit,
jaxpr_to_mlir,
make_jaxpr2,
sort_eqns,
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(self, fn):
self.__name__ = fn.__name__

def __call__(self, *args, **kwargs):
jaxpr, _, out_tree = make_jaxpr2(self.fn)(*args)
jaxpr, out_tree = make_jaxpr2(self.fn)(*args)

def _eval_jaxpr(*args):
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
Expand Down Expand Up @@ -354,10 +354,7 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs):
"static_argnums": static_argnums,
"abstracted_axes": abstracted_axes,
}
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

# We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
jaxpr, _ = jaxpr_remove_implicit(jaxpr, out_type)
jaxpr, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

return jaxpr, out_treedef

Expand All @@ -381,6 +378,10 @@ def lower_jaxpr_to_mlir(jaxpr, func_name):
mlir_fn_cache.clear()

with transient_jax_config():
# We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
if isinstance(jaxpr, DynshapedJaxpr):
jaxpr = jaxpr.remove_implicit_results()

mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)

return mlir_module, ctx
Expand Down
53 changes: 33 additions & 20 deletions frontend/catalyst/utils/jax_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

__all__ = (
"ClosedJaxpr",
"DynshapedJaxpr",
"DynamicJaxprTrace",
"DynamicJaxprTracer",
"Jaxpr",
Expand All @@ -107,7 +108,6 @@
"_initial_style_jaxpr",
"_input_type_to_tracers",
"jaxpr_to_mlir",
"jaxpr_remove_implicit",
"make_jaxpr_effects",
"make_jaxpr2",
"new_dynamic_main2",
Expand All @@ -125,6 +125,35 @@
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin


class DynshapedJaxpr(ClosedJaxpr):
"""A wrapper class to handle implicit/explicit result information used by JAX for dynamically
shaped arrays. Can be used inplace of any other ClosedJaxpr instance."""

def __init__(self, jaxpr: Jaxpr, consts: Sequence, output_type: OutputType):
super().__init__(jaxpr, consts)
self.output_type = output_type

def remove_implicit_results(self):
"""Remove all implicit result values from this JAXPR.
Returns:
ClosedJaxpr
"""
# Note: a more idiomatic way of doing this would be to re-trace the jaxpr and drop the
# unneeded tracers.
if not self.output_type:
return self

jaxpr = self.jaxpr
out_keep = tuple(zip(*self.output_type))[1]
outvars = [o for o, keep in zip(jaxpr._outvars, out_keep) if keep]
filtered_jaxpr = Jaxpr(
jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, jaxpr.effects, jaxpr.debug_info
)

return ClosedJaxpr(filtered_jaxpr, self.consts)


@contextmanager
def transient_jax_config() -> Generator[None, None, None]:
"""Context manager which updates transient JAX configuration options,
Expand Down Expand Up @@ -467,22 +496,6 @@ def get_implicit_and_explicit_flat_args(abstracted_axes, *args, **kwargs):
return args_flat


def jaxpr_remove_implicit(
closed_jaxpr: ClosedJaxpr, out_type: OutputType
) -> tuple[ClosedJaxpr, OutputType]:
"""Remove all the implicit result values of the ``closed_jaxpr``."""
# Note: a more idiomatic way of doing this would be to re-trace the jaxpr and drop the unneeded
# tracers.
jaxpr = closed_jaxpr.jaxpr
out_keep = list(tuple(zip(*out_type))[1]) if len(out_type) > 0 else []
outvars = [o for o, keep in zip(jaxpr._outvars, out_keep) if keep]
out_type2 = [o for o, keep in zip(out_type, out_keep) if keep]
jaxpr2 = Jaxpr(
jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, jaxpr.effects, jaxpr.debug_info
)
return ClosedJaxpr(jaxpr2, closed_jaxpr.consts), out_type2


def make_jaxpr2(
fun: Callable,
static_argnums: Any | None = None,
Expand Down Expand Up @@ -521,9 +534,9 @@ def make_jaxpr_f(*args, **kwargs):
in_type, in_tree = abstractify(args, kwargs)
f, out_tree_promise = flatten_fun(f, in_tree)
f = annotate(f, in_type)
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
return closed_jaxpr, out_type, out_tree_promise()
jaxpr, output_type, consts = trace_to_jaxpr_dynamic2(f)
closed_jaxpr = DynshapedJaxpr(jaxpr, consts, output_type)
return closed_jaxpr, out_tree_promise()

make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})"
return make_jaxpr_f
Expand Down

0 comments on commit c71c322

Please sign in to comment.