Skip to content

Commit

Permalink
[frontend] Use cached_primitive_lowerings instead of custom cache. (#…
Browse files Browse the repository at this point in the history
…1159)

**Context:** JAX provides a cache for primitive lowerings. Which is
exactly what we use our caches for.

**Description of the Change:** Use JAX's provided cache.

**Benefits:** Less code.

**Possible Drawbacks:** None.

**Related GitHub Issues:**
  • Loading branch information
erick-xanadu authored Sep 27, 2024
1 parent a8fbc46 commit 9a3e64b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 25 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@
* Compiling `qnode`s to asynchronous functions will no longer print to stderr in case of an error.
[(#645)](https://github.com/PennyLaneAI/catalyst/pull/645)

* Cached primitive lowerings is used instead of a custom cache structure.
[(#1159)](https://github.com/PennyLaneAI/catalyst/pull/1159)

<h3>Breaking changes</h3>

* Remove `static_size` field from `AbstractQreg` class.
Expand Down
30 changes: 13 additions & 17 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dataclasses import dataclass
from enum import Enum
from itertools import chain
from typing import Any, Dict, Iterable, List, Union
from typing import Iterable, List, Union

import jax
import numpy as np
Expand Down Expand Up @@ -315,9 +315,6 @@ def _python_callback_def_impl(*avals, callback, custom_grad, results_aval): # p
raise NotImplementedError()


CALLBACK_OP_CACHE = {}


def _python_callback_lowering(
jax_ctx: mlir.LoweringRuleContext, *args, callback, custom_grad, results_aval
):
Expand All @@ -333,8 +330,8 @@ def _python_callback_lowering(
fn_ty = FunctionType.get(inputs=params_ty, results=results_ty)
fn_ty_attr = ir.TypeAttr.get(fn_ty)
cache_key = (callback_id, *params_ty, *results_ty)
if cache_key in CALLBACK_OP_CACHE:
callbackOp = CALLBACK_OP_CACHE[cache_key]
if cache_key in jax_ctx.module_context.cached_primitive_lowerings:
callbackOp = jax_ctx.module_context.cached_primitive_lowerings[cache_key]
symbol = callbackOp.sym_name.value
symbol_attr = ir.FlatSymbolRefAttr.get(symbol)
return CallbackCallOp(results_ty, symbol_attr, args).results
Expand All @@ -346,8 +343,7 @@ def _python_callback_lowering(
# TODO: Name mangling for callbacks
name = callback.__name__
callbackOp = CallbackOp(f"callback_{name}_{callback_id}", *attrs)
CALLBACK_OP_CACHE[cache_key] = callbackOp
callbackOp = CALLBACK_OP_CACHE[cache_key]
jax_ctx.module_context.cached_primitive_lowerings[cache_key] = callbackOp
symbol = callbackOp.sym_name.value
symbol_attr = ir.FlatSymbolRefAttr.get(symbol)
retval = CallbackCallOp(results_ty, symbol_attr, args).results
Expand Down Expand Up @@ -555,7 +551,6 @@ def _apply_registered_pass_lowering(
#
# func
#
mlir_fn_cache: Dict["catalyst.jax_tracer.Function", Any] = {}


@func_p.def_impl
Expand All @@ -578,6 +573,7 @@ def _func_def_lowering(ctx, fn, call_jaxpr, name_stack) -> str:
diff_method = "parameter-shift" if fn.diff_method == "best" else str(fn.diff_method)
func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)

ctx.cached_primitive_lowerings[fn] = func_op
return func_op


Expand Down Expand Up @@ -606,11 +602,11 @@ def _func_lowering(ctx, *args, call_jaxpr, fn, call=True):
call_jaxpr: the jaxpr representation of the fn
fn: the function being compiled
"""
if fn in mlir_fn_cache:
func_op = mlir_fn_cache[fn]
if fn in ctx.module_context.cached_primitive_lowerings:
func_op = ctx.module_context.cached_primitive_lowerings[fn]
else:
func_op = _func_def_lowering(ctx.module_context, fn, call_jaxpr, name_stack=ctx.name_stack)
mlir_fn_cache[fn] = func_op
ctx.module_context.cached_primitive_lowerings[fn] = func_op

symbol_name = func_op.name.value

Expand Down Expand Up @@ -690,7 +686,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
func_call_jaxpr = _get_call_jaxpr(jaxpr)
_func_lowering(ctx, *args, call_jaxpr=func_call_jaxpr, fn=fn, call=False)
func_op = mlir_fn_cache[fn]
func_op = ctx.module_context.cached_primitive_lowerings[fn]
symbol_name = func_op.name.value
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
Expand Down Expand Up @@ -772,7 +768,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
call=False,
)

func_op = mlir_fn_cache[fn]
func_op = ctx.module_context.cached_primitive_lowerings[fn]
symbol_name = func_op.name.value
return ValueAndGradOp(
val_result_types,
Expand Down Expand Up @@ -832,7 +828,7 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params):
assert (
len(flat_output_types) % 2 == 0
), f"The total number of result tensors is expected to be even, not {len(flat_output_types)}"
func_op = mlir_fn_cache[fn]
func_op = ctx.module_context.cached_primitive_lowerings[fn]
symbol_name = func_op.name.value
return JVPOp(
flat_output_types[: len(flat_output_types) // 2],
Expand Down Expand Up @@ -889,7 +885,7 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
call=False,
)

func_op = mlir_fn_cache[fn]
func_op = ctx.module_context.cached_primitive_lowerings[fn]
symbol_name = func_op.name.value
return VJPOp(
func_result_types,
Expand Down Expand Up @@ -941,7 +937,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
"""
func_call_jaxpr = _get_call_jaxpr(jaxpr)
_func_lowering(ctx, *args, call_jaxpr=func_call_jaxpr, fn=fn, call=False)
func_op = mlir_fn_cache[fn]
func_op = ctx.module_context.cached_primitive_lowerings[fn]
symbol_name = func_op.name.value
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
Expand Down
8 changes: 0 additions & 8 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
wrap_init,
)
from catalyst.jax_primitives import (
CALLBACK_OP_CACHE,
AbstractQreg,
compbasis_p,
counts_p,
Expand All @@ -71,7 +70,6 @@
gphase_p,
hamiltonian_p,
hermitian_p,
mlir_fn_cache,
namedobs_p,
probs_p,
qalloc_p,
Expand Down Expand Up @@ -548,13 +546,7 @@ def lower_jaxpr_to_mlir(jaxpr, func_name):
ir.Context: the MLIR context
"""

# The compilation cache must be clear for each translation unit. Otherwise, MLIR functions
# which do not exist in the current translation unit will be assumed to exist if an equivalent
# python function is seen in the cache. This happens during testing or if we wanted to compile a
# single python function multiple times with different options.
mlir_fn_cache.clear()
MemrefCallable.clearcache()
CALLBACK_OP_CACHE.clear()

with transient_jax_config({"jax_dynamic_shapes": True}):
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
Expand Down

0 comments on commit 9a3e64b

Please sign in to comment.