diff --git a/doc/changelog.md b/doc/changelog.md
index 738f79e004..61dec29de4 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -51,6 +51,44 @@
Improvements
+* Catalyst will now remember previously compiled functions when the PyTree metadata of arguments
+ changes, in addition to already rememebering compiled functions when static arguments change.
+ [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531)
+
+ The following example will no longer trigger a third compilation:
+ ```py
+ @qjit
+ def func(x):
+ print("compiling")
+ return x
+ ```
+ ```pycon
+ >>> func([1,]); # list
+ compiling
+ >>> func((2,)); # tuple
+ compiling
+ >>> func([3,]); # list
+
+ ```
+
+ Note however that in order to keep overheads low, changing the argument *type* or *shape* (in a
+ promotion incompatible way) may override a previously stored function (with identical PyTree
+ metadata and static argument values):
+ ```py
+ @qjit
+ def func(x):
+ print("compiling")
+ return x
+ ```
+ ```pycon
+ >>> func(jnp.array(1)); # scalar
+ compiling
+ >>> func(jnp.array([2.])); # 1-D array
+ compiling
+ >>> func(jnp.array(3)); # scalar
+ compiling
+ ```
+
* Keep the structure of the function return when taking the derivatives, JVP and VJP (pytrees support).
[(#500)](https://github.com/PennyLaneAI/catalyst/pull/500)
[(#501)](https://github.com/PennyLaneAI/catalyst/pull/501)
@@ -143,6 +181,24 @@
Breaking changes
+* The Catalyst Python frontend has been partially refactored. The impact on user-facing
+ functionality is minimal, but the location of certain classes and methods used by the package
+ may have changed.
+ [(#529)](https://github.com/PennyLaneAI/catalyst/pull/529)
+ [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531)
+
+ The following changes have been made:
+ * Some debug methods and features on the QJIT class have been turned into free functions and moved
+ to the `catalyst.debug` module, which will now appear in the public documention. This includes
+ compiling a program from ir, obtaining a C program to invoke a compiled function from, and
+ printing fine-grained MLIR compilation stages.
+ * The `compilation_pipelines.py` module has been renamed to `jit.py`, and certain functionality
+ has been moved out (see following items).
+ * A new module `compiled_functions.py` now manages low-level access to compiled functions.
+ * A new module `tracing/type_signatures.py` handles functionality related managing arguments
+ and type signatures during the tracing process.
+ * The `contexts.py` module has been moved from `utils` to the new `tracing` sub-module.
+
* `QCtrl` is overriden and never used.
[(#522)](https://github.com/PennyLaneAI/catalyst/pull/522)
@@ -247,6 +303,11 @@
Bug fixes
+* Catalyst will no longer print a warning that recompilation is triggered when a `@qjit` decorated
+ function with no arguments is invoke without having been compiled first, for example via the use
+ of `target="mlir"`.
+ [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531)
+
* Only set `JAX_DYNAMIC_SHAPES` configuration option during `trace_to_mlir()`.
[(#526)](https://github.com/PennyLaneAI/catalyst/pull/526)
diff --git a/doc/dev/sharp_bits.rst b/doc/dev/sharp_bits.rst
index 800fc1207c..e25c19b12a 100644
--- a/doc/dev/sharp_bits.rst
+++ b/doc/dev/sharp_bits.rst
@@ -310,7 +310,6 @@ array(0.16996714)
However, deviating from this will result in recompilation and a warning message:
>>> circuit(jnp.array([1.4, 1.4, 0.3, 0.1]))
-catalyst/compilation_pipelines.py:592:
UserWarning: Provided arguments did not match declared signature, recompiling...
Tracing occurring
array(0.16996714)
diff --git a/frontend/README.rst b/frontend/README.rst
index 7cb687e50b..388fbad36e 100644
--- a/frontend/README.rst
+++ b/frontend/README.rst
@@ -38,7 +38,7 @@ The ``catalyst`` Python package is a mixed Python package which relies on some C
and the following modules:
-- `compilation_pipelines.py `_:
+- `jit.py `_:
This module contains classes and decorators for just-in-time and ahead-of-time compilation of
hybrid quantum-classical functions using Catalyst.
diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py
index 6b7b65f234..20cc2f8223 100644
--- a/frontend/catalyst/__init__.py
+++ b/frontend/catalyst/__init__.py
@@ -67,7 +67,8 @@
from catalyst import debug
from catalyst.ag_utils import AutoGraphError, autograph_source
-from catalyst.compilation_pipelines import QJIT, CompileOptions, qjit
+from catalyst.compiler import CompileOptions
+from catalyst.jit import QJIT, qjit
from catalyst.pennylane_extensions import (
adjoint,
cond,
diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py
index 1809e52d5c..ad1b9ca7e3 100644
--- a/frontend/catalyst/compiled_functions.py
+++ b/frontend/catalyst/compiled_functions.py
@@ -15,10 +15,12 @@
"""This module contains classes to manage compiled functions and their underlying resources."""
import ctypes
+from dataclasses import dataclass
+from typing import Tuple
import numpy as np
from jax.interpreters import mlir
-from jax.tree_util import tree_flatten, tree_unflatten
+from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from mlir_quantum.runtime import (
as_ctype,
get_ranked_memref_descriptor,
@@ -27,9 +29,15 @@
)
from catalyst.jax_extras import get_implicit_and_explicit_flat_args
-from catalyst.tracing.type_signatures import filter_static_args
+from catalyst.tracing.type_signatures import (
+ TypeCompatibility,
+ filter_static_args,
+ get_decomposed_signature,
+ typecheck_signatures,
+)
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type
+from catalyst.utils.filesystem import Directory
class SharedObjectManager:
@@ -43,17 +51,19 @@ class SharedObjectManager:
"""
def __init__(self, shared_object_file, func_name):
+ self.shared_object_file = shared_object_file
self.shared_object = None
+ self.func_name = func_name
self.function = None
self.setup = None
self.teardown = None
self.mem_transfer = None
- self.open(shared_object_file, func_name)
+ self.open()
- def open(self, shared_object_file, func_name):
+ def open(self):
"""Open the sharead object and load symbols."""
- self.shared_object = ctypes.CDLL(shared_object_file)
- self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols(func_name)
+ self.shared_object = ctypes.CDLL(self.shared_object_file)
+ self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols()
def close(self):
"""Close the shared object"""
@@ -66,17 +76,14 @@ def close(self):
# pylint: disable=protected-access
dlclose(self.shared_object._handle)
- def load_symbols(self, func_name):
+ def load_symbols(self):
"""Load symbols necessary for for execution of the compiled function.
- Args:
- func_name: name of compiled function to be executed
-
Returns:
- function: function handle
- setup: handle to the setup function, which initializes the device
- teardown: handle to the teardown function, which tears down the device
- mem_transfer: memory transfer shared object
+ CFuncPtr: handle to the main function of the program
+ CFuncPtr: handle to the setup function, which initializes the device
+ CFuncPtr: handle to the teardown function, which tears down the device
+ CFuncPtr: handle to the memory transfer function for program results
"""
setup = self.shared_object.setup
@@ -88,7 +95,7 @@ def load_symbols(self, func_name):
teardown.restypes = None
# We are calling the c-interface
- function = self.shared_object["_catalyst_pyface_" + func_name]
+ function = self.shared_object["_catalyst_pyface_" + self.func_name]
# Guaranteed from _mlir_ciface specification
function.restypes = None
# Not needed, computed from the arguments.
@@ -122,7 +129,6 @@ class CompiledFunction:
"""
def __init__(self, shared_object_file, func_name, restype, compile_options):
- self.shared_object_file = shared_object_file
self.shared_object = SharedObjectManager(shared_object_file, func_name)
self.compile_options = compile_options
self.return_type_c_abi = None
@@ -140,7 +146,7 @@ def _exec(shared_object, has_return, numpy_dict, *args):
*args: arguments to the function
Returns:
- retval: the value computed by the function or None if the function has no return value
+ the return values computed by the function or None if the function has no results
"""
with shared_object as lib:
@@ -256,8 +262,8 @@ def args_to_memref_descs(self, restype, args):
args: the JAX arrays to be used as arguments to the function
Returns:
- c_abi_args: a list of memref descriptor pointers to return values and parameters
- numpy_arg_buffer: A list to the return values. It must be kept around until the function
+ List: a list of memref descriptor pointers to return values and parameters
+ List: A list to the return values. It must be kept around until the function
finishes execution as the memref descriptors will point to memory locations inside
numpy arrays.
@@ -327,3 +333,126 @@ def __call__(self, *args, **kwargs):
)
return result
+
+
+@dataclass
+class CacheKey:
+ """A key by which to identify entries in the compiled function cache.
+
+ The cache only rembers one compiled function for each possible combination of:
+ - dynamic argument PyTree metadata
+ - static argument values
+ """
+
+ treedef: PyTreeDef
+ static_args: Tuple
+
+ def __hash__(self) -> int:
+ if not hasattr(self, "_hash"):
+ # pylint: disable=attribute-defined-outside-init
+ self._hash = hash((self.treedef, self.static_args))
+ return self._hash
+
+
+@dataclass
+class CacheEntry:
+ """An entry in the compiled function cache.
+
+ For each compiled function, the cache stores the dynamic argument signature, the output PyTree
+ definition, as well as the workspace in which compilation takes place.
+ """
+
+ compiled_fn: CompiledFunction
+ signature: Tuple
+ out_treedef: PyTreeDef
+ workspace: Directory
+
+
+class CompilationCache:
+ """Class to manage CompiledFunction instances and retrieving previously compiled versions of
+ a function.
+
+ A full match requires the following properties to match:
+ - dynamic argument signature (the shape and dtype of flattened arrays)
+ - dynamic argument PyTree definitions
+ - static argument values
+
+ In order to allow some flexibility in the type of arguments provided by the user, a match is
+ also produced if the dtype of dynamic arguments can be promoted to the dtype of an existing
+ signature via JAX type promotion rules. Additional leniency is provided in the shape of
+ dynamic arguments in accordance with the abstracted axis specification.
+
+ To avoid excess promotion checks, only one function version is stored at a time for a given
+ combination of PyTreeDefs and static arguments.
+ """
+
+ def __init__(self, static_argnums, abstracted_axes):
+ self.static_argnums = static_argnums
+ self.abstracted_axes = abstracted_axes
+ self.cache = {}
+
+ def get_function_status_and_key(self, args):
+ """Check if the provided arguments match an existing function in the cache. The cache
+ status of the function is returned as a compilation action:
+ - no match: requires compilation
+ - partial match: requires argument promotion
+ - full match: skip promotion
+
+ Args:
+ args: arguments to match to existing functions
+
+ Returns:
+ TypeCompatibility
+ CacheKey | None
+ """
+ if not self.cache:
+ return TypeCompatibility.NEEDS_COMPILATION, None
+
+ flat_runtime_sig, treedef, static_args = get_decomposed_signature(args, self.static_argnums)
+ key = CacheKey(treedef, static_args)
+ if key not in self.cache:
+ return TypeCompatibility.NEEDS_COMPILATION, None
+
+ entry = self.cache[key]
+ runtime_signature = tree_unflatten(treedef, flat_runtime_sig)
+ action = typecheck_signatures(entry.signature, runtime_signature, self.abstracted_axes)
+ return action, key
+
+ def lookup(self, args):
+ """Get a function (if present) that matches the provided argument signature. Also computes
+ whether promotion is necessary.
+
+ Args:
+ args (Iterable): the arguments to match to existing functions
+
+ Returns:
+ CacheEntry | None: the matched cache entry
+ bool: whether the matched entry requires argument promotion
+ """
+ action, key = self.get_function_status_and_key(args)
+
+ if action == TypeCompatibility.NEEDS_COMPILATION:
+ return None, None
+ elif action == TypeCompatibility.NEEDS_PROMOTION:
+ return self.cache[key], True
+ else:
+ assert action == TypeCompatibility.CAN_SKIP_PROMOTION
+ return self.cache[key], False
+
+ def insert(self, fn, args, out_treedef, workspace):
+ """Inserts the provided function into the cache.
+
+ Args:
+ fn (CompiledFunction): compilation result
+ args (Iterable): arguments to determine cache key and additional metadata from
+ out_treedef (PyTreeDef): the output shape of the function
+ workspace (Directory): directory where compilation artifacts are stored
+ """
+ assert isinstance(fn, CompiledFunction)
+
+ flat_signature, treedef, static_args = get_decomposed_signature(args, self.static_argnums)
+ signature = tree_unflatten(treedef, flat_signature)
+
+ key = CacheKey(treedef, static_args)
+ entry = CacheEntry(fn, signature, out_treedef, workspace)
+ self.cache[key] = entry
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index c5e04d53bd..66d6067cb5 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -524,10 +524,3 @@ def get_output_of(self, pipeline) -> Optional[str]:
return None
return self.last_compiler_output.get_pipeline_output(pipeline)
-
- def print(self, pipeline):
- """Print the output IR of pass.
- Args:
- pipeline (str): name of pass class
- """
- print(self.get_output_of(pipeline)) # pragma: no cover
diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
index 58438bdf43..5557eb5f21 100644
--- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
+++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
@@ -840,12 +840,11 @@ def cudaq_backend_info(device):
# We could also pass abstract arguments here in *args
# the same way we do so in Catalyst.
# But I think that is redundant now given make_jaxpr2
- _, jaxpr, _, out_tree = trace_to_jaxpr(func, static_args, abs_axes, *args)
+ jaxpr, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {})
# TODO(@erick-xanadu):
# What about static_args?
- # We could return _out_type2 as well
- return jaxpr, out_tree
+ return jaxpr, out_treedef
def interpret(fun):
diff --git a/frontend/catalyst/debug.py b/frontend/catalyst/debug.py
index 5622b5bcc0..0e29ce9403 100644
--- a/frontend/catalyst/debug.py
+++ b/frontend/catalyst/debug.py
@@ -25,6 +25,7 @@
from catalyst.compiler import Compiler
from catalyst.jax_primitives import print_p
from catalyst.tracing.contexts import EvaluationContext
+from catalyst.tracing.type_signatures import filter_static_args, promote_arguments
from catalyst.utils.filesystem import WorkspaceManager
@@ -74,6 +75,36 @@ def func(x: float):
builtins.print(x)
+def print_compilation_stage(fn, stage):
+ """Print one of the recorded compilation stages for a JIT-compiled function.
+
+ The stages are indexed by their Catalyst compilation pipeline name, which are either provided
+ by the user as a compilation option, or predefined in ``catalyst.compiler``.
+
+ Requires ``keep_intermediate=True``.
+
+ Args:
+ fn (QJIT): a qjit-decorated function
+ stage (str): string corresponding with the name of the stage to be printed
+
+ **Example**
+
+ .. code-block:: python
+
+ @qjit(keep_intermediate=True)
+ def func(x: float):
+ return x
+
+ debug.print_compilation_stage(func, "HLOLoweringPass")
+ """
+ EvaluationContext.check_is_not_tracing("C interface cannot be generated from tracing context.")
+
+ if not isinstance(fn, catalyst.QJIT):
+ raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")
+
+ print(fn.compiler.get_output_of(stage))
+
+
def get_cmain(fn, *args):
"""Return a C program that calls a jitted function with the provided arguments.
@@ -89,13 +120,13 @@ def get_cmain(fn, *args):
if not isinstance(fn, catalyst.QJIT):
raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")
- # TODO: will be removed in part 2 of the refactor
- # pylint: disable=protected-access
- complied_function, args = fn._ensure_real_arguments_and_formal_parameters_are_compatible(
- fn.compiled_function, *args
- )
+ requires_promotion = fn.jit_compile(args)
+
+ if requires_promotion:
+ dynamic_args = filter_static_args(args, fn.compile_options.static_argnums)
+ args = promote_arguments(fn.c_sig, dynamic_args)
- return complied_function.get_cmain(*args)
+ return fn.compiled_function.get_cmain(*args)
# pylint: disable=line-too-long
@@ -132,7 +163,7 @@ def compile_from_mlir(ir, compiler=None, compile_options=None):
}
\"""
- compiled_function = compile_from_mlir(ir)
+ compiled_function = debug.compile_from_mlir(ir)
>>> compiled_function(0.1)
[0.1]
diff --git a/frontend/catalyst/jax_extras/__init__.py b/frontend/catalyst/jax_extras/__init__.py
index 5b42c0174d..e9de9e51c4 100644
--- a/frontend/catalyst/jax_extras/__init__.py
+++ b/frontend/catalyst/jax_extras/__init__.py
@@ -23,6 +23,7 @@
ClosedJaxpr,
DynamicJaxprTrace,
DynamicJaxprTracer,
+ DynshapedClosedJaxpr,
Jaxpr,
PyTreeDef,
PyTreeRegistry,
@@ -40,7 +41,6 @@
infer_lambda_input_type,
initial_style_jaxprs_with_common_consts1,
initial_style_jaxprs_with_common_consts2,
- jaxpr_remove_implicit,
make_jaxpr2,
make_jaxpr_effects,
new_dynamic_main2,
diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py
index e10ce69ba8..aad3e376f8 100644
--- a/frontend/catalyst/jax_extras/tracing.py
+++ b/frontend/catalyst/jax_extras/tracing.py
@@ -44,6 +44,7 @@
from jax.core import ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType
from jax.core import Primitive as JaxprPrimitive
from jax.core import ShapedArray, Trace, eval_jaxpr, gensym, thread_local_state
+from jax.extend.linear_util import wrap_init
from jax.interpreters.partial_eval import (
DynamicJaxprTrace,
DynamicJaxprTracer,
@@ -51,7 +52,6 @@
make_jaxpr_effects,
)
from jax.lax import convert_element_type
-from jax.linear_util import wrap_init
from jax.tree_util import (
PyTreeDef,
tree_flatten,
@@ -68,6 +68,7 @@
__all__ = (
"ClosedJaxpr",
+ "DynshapedClosedJaxpr",
"DynamicJaxprTrace",
"DynamicJaxprTracer",
"Jaxpr",
@@ -84,7 +85,6 @@
"_initial_style_jaxpr",
"_input_type_to_tracers",
"_module_name_regex",
- "jaxpr_remove_implicit",
"make_jaxpr_effects",
"make_jaxpr2",
"new_dynamic_main2",
@@ -102,6 +102,35 @@
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
+class DynshapedClosedJaxpr(ClosedJaxpr):
+ """A wrapper class to handle implicit/explicit result information used by JAX for dynamically
+ shaped arrays. Can be used inplace of any other ClosedJaxpr instance."""
+
+ def __init__(self, jaxpr: Jaxpr, consts: Sequence, output_type: OutputType):
+ super().__init__(jaxpr, consts)
+ self.output_type = output_type
+
+ def remove_implicit_results(self):
+ """Remove all implicit result values from this JAXPR.
+
+ Returns:
+ ClosedJaxpr
+ """
+ # Note: a more idiomatic way of doing this would be to re-trace the jaxpr and drop the
+ # unneeded tracers.
+ if not self.output_type:
+ return self
+
+ jaxpr = self.jaxpr
+ out_keep = tuple(zip(*self.output_type))[1]
+ outvars = [o for o, keep in zip(jaxpr._outvars, out_keep) if keep]
+ filtered_jaxpr = Jaxpr(
+ jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, jaxpr.effects, jaxpr.debug_info
+ )
+
+ return ClosedJaxpr(filtered_jaxpr, self.consts)
+
+
@contextmanager
def transient_jax_config() -> Generator[None, None, None]:
"""Context manager which updates transient JAX configuration options,
@@ -325,22 +354,6 @@ def get_implicit_and_explicit_flat_args(abstracted_axes, *args, **kwargs):
return args_flat
-def jaxpr_remove_implicit(
- closed_jaxpr: ClosedJaxpr, out_type: OutputType
-) -> tuple[ClosedJaxpr, OutputType]:
- """Remove all the implicit result values of the ``closed_jaxpr``."""
- # Note: a more idiomatic way of doing this would be to re-trace the jaxpr and drop the unneeded
- # tracers.
- jaxpr = closed_jaxpr.jaxpr
- out_keep = list(tuple(zip(*out_type))[1]) if len(out_type) > 0 else []
- outvars = [o for o, keep in zip(jaxpr._outvars, out_keep) if keep]
- out_type2 = [o for o, keep in zip(out_type, out_keep) if keep]
- jaxpr2 = Jaxpr(
- jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, jaxpr.effects, jaxpr.debug_info
- )
- return ClosedJaxpr(jaxpr2, closed_jaxpr.consts), out_type2
-
-
def make_jaxpr2(
fun: Callable,
static_argnums: Any | None = None,
@@ -379,9 +392,9 @@ def make_jaxpr_f(*args, **kwargs):
in_type, in_tree = abstractify(args, kwargs)
f, out_tree_promise = flatten_fun(f, in_tree)
f = annotate(f, in_type)
- jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
- closed_jaxpr = ClosedJaxpr(jaxpr, consts)
- return closed_jaxpr, out_type, out_tree_promise()
+ jaxpr, output_type, consts = trace_to_jaxpr_dynamic2(f)
+ closed_jaxpr = DynshapedClosedJaxpr(jaxpr, consts, output_type)
+ return closed_jaxpr, out_tree_promise()
make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})"
return make_jaxpr_f
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index e7e97175e4..8254603661 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -31,6 +31,7 @@
ClosedJaxpr,
DynamicJaxprTrace,
DynamicJaxprTracer,
+ DynshapedClosedJaxpr,
PyTreeDef,
PyTreeRegistry,
ShapedArray,
@@ -39,7 +40,6 @@
convert_element_type,
deduce_avals,
eval_jaxpr,
- jaxpr_remove_implicit,
jaxpr_to_mlir,
make_jaxpr2,
sort_eqns,
@@ -100,7 +100,7 @@ def __init__(self, fn):
self.__name__ = fn.__name__
def __call__(self, *args, **kwargs):
- jaxpr, _, out_tree = make_jaxpr2(self.fn)(*args)
+ jaxpr, out_tree = make_jaxpr2(self.fn)(*args)
def _eval_jaxpr(*args):
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
@@ -332,11 +332,12 @@ def has_nested_tapes(op: Operation) -> bool:
)
-def trace_to_jaxpr(func, static_argnums, abstracted_axes, *args, **kwargs):
- """Trace a function to JAXPR.
+def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs):
+ """Trace a Python function to JAXPR.
Args:
- func: python function to be lowered
+ func: python function to be traced
+ static_argnums: indices of static arguments.
abstracted_axes: abstracted axes specification. Necessary for JAX to use dynamic tensor
sizes.
args: arguments to ``func``
@@ -344,44 +345,32 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, *args, **kwargs):
Returns:
ClosedJaxpr: the Jaxpr program corresponding to ``func``
- ClosedJaxpr: the Jaxpr program corresponding to ``func`` without implicit result values.
- jax.OutputType: Jaxpr output type (a list of abstract values paired with
- explicintess flags).
PyTreeDef: PyTree-shape of the return values in ``PyTreeDef``
"""
- with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
- make_jaxpr_kwargs = {
- "abstracted_axes": abstracted_axes,
- "static_argnums": static_argnums,
- }
- jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
-
- # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
- jaxpr2, out_type2 = jaxpr_remove_implicit(jaxpr, out_type)
+ with transient_jax_config():
+ with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
+ make_jaxpr_kwargs = {
+ "static_argnums": static_argnums,
+ "abstracted_axes": abstracted_axes,
+ }
+ jaxpr, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
- return jaxpr, jaxpr2, out_type2, out_tree
+ return jaxpr, out_treedef
-def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
- """Lower a Python function into an MLIR module.
+def lower_jaxpr_to_mlir(jaxpr, func_name):
+ """Lower a JAXPR to MLIR.
Args:
- func: python function to be lowered
- static_argnums: indices of static arguments.
- abstracted_axes: abstracted axes specification. Necessary for JAX to use dynamic tensor
- sizes.
- args: arguments to ``func``
- kwargs: keyword arguments to ``func``
+ ClosedJaxpr: the JAXPR to lower to MLIR
+ func_name: a name to use for the MLIR function
Returns:
- MLIR module: the MLIR module corresponding to ``func``
- MLIR context: the MLIR context
- ClosedJaxpr: the Jaxpr program corresponding to ``func``
- jax.OutputType: Jaxpr output type (a list of abstract values paired with
- explicintess flags).
- PyTreeDef: PyTree-shape of the return values in ``PyTreeDef``
+ ir.Module: the MLIR module coontaining the JAX program
+ ir.Context: the MLIR context
"""
+
# The compilation cache must be clear for each translation unit. Otherwise, MLIR functions
# which do not exist in the current translation unit will be assumed to exist if an equivalent
# python function is seen in the cache. This happens during testing or if we wanted to compile a
@@ -389,12 +378,13 @@ def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
mlir_fn_cache.clear()
with transient_jax_config():
- jaxpr, postprocessed_jaxpr, out_type, out_tree = trace_to_jaxpr(
- func, static_argnums, abstracted_axes, *args, **kwargs
- )
- module, context = jaxpr_to_mlir(func.__name__, postprocessed_jaxpr)
+ # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
+ if isinstance(jaxpr, DynshapedClosedJaxpr):
+ jaxpr = jaxpr.remove_implicit_results()
+
+ mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
- return module, context, jaxpr, out_type, out_tree
+ return mlir_module, ctx
def trace_quantum_tape(
diff --git a/frontend/catalyst/compilation_pipelines.py b/frontend/catalyst/jit.py
similarity index 70%
rename from frontend/catalyst/compilation_pipelines.py
rename to frontend/catalyst/jit.py
index 958c64aedc..f8b7941dd8 100644
--- a/frontend/catalyst/compilation_pipelines.py
+++ b/frontend/catalyst/jit.py
@@ -13,36 +13,33 @@
# limitations under the License.
"""This module contains classes and decorators for just-in-time and ahead-of-time
-compiling of hybrid quantum-classical functions using Catalyst.
+compilation of hybrid quantum-classical functions using Catalyst.
"""
+import copy
import functools
import inspect
-import pathlib
+import os
import warnings
-from copy import deepcopy
import jax
import jax.numpy as jnp
import pennylane as qml
-from jax.interpreters.mlir import ir
+from jax.interpreters import mlir
from jax.tree_util import tree_flatten, tree_unflatten
import catalyst
from catalyst.ag_utils import run_autograph
-from catalyst.compiled_functions import CompiledFunction
+from catalyst.compiled_functions import CompilationCache, CompiledFunction
from catalyst.compiler import CompileOptions, Compiler
-from catalyst.jax_tracer import trace_to_mlir
-from catalyst.pennylane_extensions import QFunc
+from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import (
- TypeCompatibility,
filter_static_args,
get_abstract_signature,
get_type_annotations,
merge_static_args,
promote_arguments,
- typecheck_signatures,
)
from catalyst.utils.c_template import mlir_type_to_numpy_type
from catalyst.utils.exceptions import CompileError
@@ -69,150 +66,198 @@ class QJIT:
the :func:`~.qjit` documentation for more details.
Args:
- fn (Callable): the quantum or classical function
- compile_options (Optional[CompileOptions]): common compilation options
+ fn (Callable): the quantum or classical function to compile
+ compile_options (CompileOptions): compilation options to use
"""
def __init__(self, fn, compile_options):
+ self.original_function = fn
self.compile_options = compile_options
self.compiler = Compiler(compile_options)
- self.original_function = fn
- self.user_function = fn
- self.jaxed_function = None
+ self.fn_cache = CompilationCache(
+ compile_options.static_argnums, compile_options.abstracted_axes
+ )
+ # Active state of the compiler.
+ # TODO: rework ownership of workspace, possibly CompiledFunction
+ self.workspace = None
+ self.c_sig = None
+ self.out_treedef = None
self.compiled_function = None
+ self.jaxed_function = None
+ # IRs are only available for the most recently traced function.
+ self.jaxpr = None
+ self.mlir = None # string form (historic presence)
self.mlir_module = None
- self.user_typed = False
- self.c_sig = None
- self.out_tree = None
- self._jaxpr = None
- self._mlir = None
- self._llvmir = None
- self.function_name = None
- self.preferred_workspace_dir = None
- self.stored_compiled_functions = {}
- self.workspace_used = False
+ self.qir = None
functools.update_wrapper(self, fn)
+ self.user_sig = get_type_annotations(fn)
+ self._validate_configuration()
- if compile_options.autograph:
- self.user_function = run_autograph(fn)
+ self.user_function = self.pre_compilation()
- # QJIT is the owner of workspace.
- # do not move to compiler.
- preferred_workspace_dir = (
- pathlib.Path.cwd() if self.compile_options.keep_intermediate else None
- )
- self.preferred_workspace_dir = preferred_workspace_dir
+ # Static arguments require values, so we cannot AOT compile.
+ if self.user_sig is not None and not self.compile_options.static_argnums:
+ self.aot_compile()
- # pylint: disable=no-member
- # Guaranteed to exist after functools.update_wrapper
- self.function_name = self.__name__
+ def __call__(self, *args, **kwargs):
+ # Transparantly call Python function in case of nested QJIT calls.
+ if EvaluationContext.is_tracing():
+ return self.user_function(*args, **kwargs)
- self.workspace = WorkspaceManager.get_or_create_workspace(
- self.function_name, preferred_workspace_dir
- )
+ requires_promotion = self.jit_compile(args)
- parameter_types = get_type_annotations(self.user_function)
+ # If we receive tracers as input, dispatch to the JAX integration.
+ if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):
+ if self.jaxed_function is None:
+ self.jaxed_function = JAX_QJIT(self) # lazy gradient compilation
+ return self.jaxed_function(*args, **kwargs)
- for argnum in self.compile_options.static_argnums:
- if argnum < 0 or argnum >= len(parameter_types):
- msg = f"argnum {argnum} is beyond the valid range of [0, {len(parameter_types)})."
- raise CompileError(msg)
+ elif requires_promotion:
+ dynamic_args = filter_static_args(args, self.compile_options.static_argnums)
+ args = promote_arguments(self.c_sig, dynamic_args)
- if parameter_types is not None and not self.compile_options.static_argnums:
- self.user_typed = True
- self.mlir_module = self.get_mlir(*parameter_types)
- if self.compile_options.target == "binary":
- self.compiled_function = self.compile()
+ return self.run(args, kwargs)
- def print_stage(self, stage):
- """Print one of the recorded stages.
+ def aot_compile(self):
+ """Compile Python function on initialization using the type hint signature."""
- Args:
- stage: string corresponding with the name of the stage to be printed
- """
- self.compiler.print(stage) # pragma: nocover
+ self.workspace = self._get_workspace()
- @property
- def mlir(self):
- """str: Returns the MLIR intermediate representation
- of the quantum program.
- """
- return self._mlir
+ # TODO: awkward, refactor or redesign the target feature
+ if self.compile_options.target in ("jaxpr", "mlir", "binary"):
+ self.jaxpr, self.out_treedef, self.c_sig = self.capture(self.user_sig or ())
- @property
- def jaxpr(self):
- """str: Returns the JAXPR intermediate representation
- of the quantum program.
- """
- return self._jaxpr
+ if self.compile_options.target in ("mlir", "binary"):
+ self.mlir_module, self.mlir = self.generate_ir()
- @property
- def qir(self):
- """str: Returns the LLVM and QIR intermediate representation
- of the quantum program. Only available if the function was compiled to binary.
- """
- return self._llvmir
+ if self.compile_options.target in ("binary",):
+ self.compiled_function, self.qir = self.compile()
+ self.fn_cache.insert(
+ self.compiled_function, self.user_sig, self.out_treedef, self.workspace
+ )
- def get_static_args_hash(self, *args):
- """Get hash values of all static arguments.
+ def jit_compile(self, args):
+ """Compile Python function on invocation using the provided arguments.
Args:
- args: arguments to the compiled function.
+ args (Iterable): arguments to use for program capture
+
Returns:
- a tuple of hash values of all static arguments.
+ bool: whether the provided arguments will require promotion to be used with the compiled
+ function
"""
- static_argnums = self.compile_options.static_argnums
- static_args_hash = tuple(
- hash(args[idx]) for idx in range(len(args)) if idx in static_argnums
- )
- return static_args_hash
- def get_mlir(self, *args):
- """Trace :func:`~.user_function`
+ cached_fn, requires_promotion = self.fn_cache.lookup(args)
+
+ if cached_fn is None:
+ if self.user_sig and not self.compile_options.static_argnums:
+ msg = "Provided arguments did not match declared signature, recompiling..."
+ warnings.warn(msg, UserWarning)
+
+ # Cleanup before recompilation:
+ # - recompilation should always happen in new workspace
+ # - compiled functions for jax integration are not yet cached
+ # - close existing shared library
+ self.workspace = self._get_workspace()
+ self.jaxed_function = None
+ if self.compiled_function and self.compiled_function.shared_object:
+ self.compiled_function.shared_object.close()
+
+ self.jaxpr, self.out_treedef, self.c_sig = self.capture(args)
+ self.mlir_module, self.mlir = self.generate_ir()
+ self.compiled_function, self.qir = self.compile()
+
+ self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)
+
+ elif self.compiled_function is not cached_fn.compiled_fn:
+ # Restore active state from cache.
+ self.workspace = cached_fn.workspace
+ self.compiled_function = cached_fn.compiled_fn
+ self.out_treedef = cached_fn.out_treedef
+ self.c_sig = cached_fn.signature
+ self.jaxed_function = None
+
+ self.compiled_function.shared_object.open()
+
+ return requires_promotion
+
+ # Processing Stages #
+
+ def pre_compilation(self):
+ """Perform pre-processing tasks on the Python function, such as AST transformations."""
+ processed_fn = self.original_function
+
+ if self.compile_options.autograph:
+ processed_fn = run_autograph(self.original_function)
+
+ return processed_fn
+
+ def capture(self, args):
+ """Capture the JAX program representation (JAXPR) of the wrapped function.
Args:
- *args: either the concrete values to be passed as arguments to ``fn`` or abstract values
+ args (Iterable): arguments to use for program capture
Returns:
- an MLIR module
+ ClosedJaxpr: captured JAXPR
+ PyTreeDef: PyTree metadata of the function output
+ Tuple[Any]: the dynamic argument signature
"""
+
static_argnums = self.compile_options.static_argnums
+ abstracted_axes = self.compile_options.abstracted_axes
+
dynamic_args = filter_static_args(args, static_argnums)
- self.c_sig = get_abstract_signature(dynamic_args)
+ dynamic_sig = get_abstract_signature(dynamic_args)
+ full_sig = merge_static_args(dynamic_sig, args, static_argnums)
with Patcher(
- (qml.QNode, "__call__", QFunc.__call__),
+ (qml.QNode, "__call__", catalyst.pennylane_extensions.QFunc.__call__),
):
- func = self.user_function
- sig = merge_static_args(self.c_sig, args, static_argnums)
- abstracted_axes = self.compile_options.abstracted_axes
- mlir_module, ctx, jaxpr, _, self.out_tree = trace_to_mlir(
- func, static_argnums, abstracted_axes, *sig
+ # TODO: improve PyTree handling
+ jaxpr, treedef = trace_to_jaxpr(
+ self.user_function, static_argnums, abstracted_axes, full_sig, {}
)
+ return jaxpr, treedef, dynamic_sig
+
+ def generate_ir(self):
+ """Generate Catalyst's intermediate representation (IR) as an MLIR module.
+
+ Returns:
+ Tuple[ir.Module, str]: the in-memory MLIR module and its string representation
+ """
+
+ mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
+
+ # Inject Runtime Library-specific functions (e.g. setup/teardown).
inject_functions(mlir_module, ctx)
- self._jaxpr = jaxpr
- canonicalizer_options = deepcopy(self.compile_options)
- canonicalizer_options.pipelines = [("0_canonicalize", ["canonicalize"])]
- canonicalizer_options.lower_to_llvm = False
- canonicalizer = Compiler(canonicalizer_options)
- _, self._mlir, _ = canonicalizer.run(mlir_module, self.workspace)
- return mlir_module
+
+ # Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
+ options = copy.deepcopy(self.compile_options)
+ options.pipelines = [("0_canonicalize", ["canonicalize"])]
+ options.lower_to_llvm = False
+ canonicalizer = Compiler(options)
+
+ # TODO: the in-memory and textual form are different after this, consider unification
+ _, mlir_string, _ = canonicalizer.run(mlir_module, self.workspace)
+
+ return mlir_module, mlir_string
def compile(self):
- """Compile the current MLIR module."""
+ """Compile an MLIR module to LLVMIR and shared library code.
- if self.compiled_function and self.compiled_function.shared_object:
- self.compiled_function.shared_object.close()
+ Returns:
+ Tuple[CompiledFunction, str]: the compilation result and LLVMIR
+ """
- # WARNING: assumption is that the first function
- # is the entry point to the compiled program.
+ # WARNING: assumption is that the first function is the entry point to the compiled program.
entry_point_func = self.mlir_module.body.operations[0]
restype = entry_point_func.type.results
for res in restype:
- baseType = ir.RankedTensorType(res).element_type
+ baseType = mlir.ir.RankedTensorType(res).element_type
# This will make a check before sending it to the compiler that the return type
# is actually available in most systems. f16 needs a special symbol and linking
# will fail if it is not available.
@@ -221,118 +266,51 @@ def compile(self):
# The function name out of MLIR has quotes around it, which we need to remove.
# The MLIR function name is actually a derived type from string which has no
# `replace` method, so we need to get a regular Python string out of it.
- qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
+ func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
+ shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
- shared_object, llvm_ir, _inferred_func_data = self.compiler.run(
- self.mlir_module, self.workspace
- )
+ shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
+ compiled_fn = CompiledFunction(shared_object, func_name, restype, self.compile_options)
+
+ return compiled_fn, llvm_ir
- self._llvmir = llvm_ir
- options = self.compile_options
- compiled_function = CompiledFunction(shared_object, qfunc_name, restype, options)
- return compiled_function
-
- def _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args):
- """Logic to decide whether the function needs to be recompiled
- given ``*args`` and whether ``*args`` need to be promoted.
- A function may need to be compiled if:
- 1. It was not compiled before. Without static arguments, the compiled function
- should be stored in ``self.compiled_function``. With static arguments,
- ``self.compile_options.static_argnums`` stores all previous compiled ones.
- 2. The real arguments sent to the function are not promotable to the type of the
- formal parameters.
+ def run(self, args, kwargs):
+ """Invoke a previously compiled function with the supplied arguments.
Args:
- function: an instance of ``CompiledFunction`` that may need recompilation
- *args: arguments that may be promoted.
+ args (Iterable): the positional arguments to the compiled function
+ kwargs: the keyword arguments to the compiled function
Returns:
- function: an instance of ``CompiledFunction`` that may have been recompiled
- *args: arguments that may have been promoted
+ Any: results of the execution arranged into the original function's output PyTrees
"""
- static_argnums = self.compile_options.static_argnums
- dynamic_args = filter_static_args(args, static_argnums)
- r_sig = get_abstract_signature(dynamic_args)
-
- has_been_compiled = self.compiled_function is not None
- if static_argnums:
- static_args_hash = self.get_static_args_hash(*args)
- prev_function, _ = self.stored_compiled_functions.get(static_args_hash, (None, None))
- has_been_compiled = False
- if prev_function:
- function = prev_function
- has_been_compiled = True
- function.shared_object.open(function.shared_object_file, function.func_name)
- elif self.workspace_used:
- # Create a new space for the new function if the workspace is used.
- self.workspace = WorkspaceManager.get_or_create_workspace(
- self.function_name, self.preferred_workspace_dir
- )
- # The workspace is unused only for first compilation with static arguments.
- self.workspace_used = True
-
- next_action = TypeCompatibility.UNKNOWN
- if not has_been_compiled:
- next_action = TypeCompatibility.NEEDS_COMPILATION
- else:
- abstracted_axes = self.compile_options.abstracted_axes
- next_action = typecheck_signatures(self.c_sig, r_sig, abstracted_axes)
-
- if next_action == TypeCompatibility.NEEDS_PROMOTION:
- args = promote_arguments(self.c_sig, dynamic_args)
- elif next_action == TypeCompatibility.NEEDS_COMPILATION:
- if self.user_typed:
- msg = "Provided arguments did not match declared signature, recompiling..."
- warnings.warn(msg, UserWarning)
- sig = merge_static_args(r_sig, args, static_argnums)
- self.mlir_module = self.get_mlir(*sig)
- function = self.compile()
- else:
- assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION
-
- return function, args
-
- def __call__(self, *args, **kwargs):
- static_argnums = self.compile_options.static_argnums
- if EvaluationContext.is_tracing():
- return self.user_function(*args, **kwargs)
+ results = self.compiled_function(*args, **kwargs)
- function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
- self.compiled_function, *args
- )
+ # TODO: Move this to the compiled function object.
+ return tree_unflatten(self.out_treedef, results)
- recompilation_happened = (
- function != self.compiled_function
- ) and function not in self.stored_compiled_functions.values()
+ # Helper Methods #
- # Check if a function is created and add newly created ones into the hash table.
- if static_argnums and recompilation_happened:
- static_args_hash = self.get_static_args_hash(*args)
- workspace = self.workspace
- self.stored_compiled_functions[static_args_hash] = (function, workspace)
+ def _validate_configuration(self):
+ """Run validations on the supplied options and parameters."""
+ if not hasattr(self.original_function, "__name__"):
+ self.__name__ = "unknown" # allow these cases anyways?
- self.compiled_function = function
-
- args_data, _args_shape = tree_flatten(args)
- if any(isinstance(arg, jax.core.Tracer) for arg in args_data):
- # Only compile a derivative version of the compiled function when needed.
- if self.jaxed_function is None or recompilation_happened:
- self.jaxed_function = JAX_QJIT(self)
-
- return self.jaxed_function(*args, **kwargs)
-
- data = self.compiled_function(*args, **kwargs)
+ # TODO: remove bug since this only works with user-provided signatures
+ parameter_types = self.user_sig or ()
+ for argnum in self.compile_options.static_argnums:
+ if argnum < 0 or argnum >= len(parameter_types):
+ msg = f"argnum {argnum} is beyond the valid range of [0, {len(parameter_types)})."
+ raise CompileError(msg)
- # Unflatten the return value w.r.t. the original PyTree definition if available
- if self.out_tree is not None:
- data = tree_unflatten(self.out_tree, data)
+ def _get_workspace(self):
+ """Get or create a workspace to use for compilation."""
- # For the classical and pennylane_extensions compilation path,
- if isinstance(data, (list, tuple)) and len(data) == 1:
- data = data[0]
+ workspace_name = self.__name__
+ preferred_workspace_dir = os.getcwd() if self.compile_options.keep_intermediate else None
- return data
+ return WorkspaceManager.get_or_create_workspace(workspace_name, preferred_workspace_dir)
class JAX_QJIT:
@@ -366,8 +344,8 @@ def wrap_callback(qjit_function, *args, **kwargs):
)
# Unflatten the return value w.r.t. the original PyTree definition if available
- assert qjit_function.out_tree is not None, "PyTree shape must not be none."
- return tree_unflatten(qjit_function.out_tree, data)
+ assert qjit_function.out_treedef is not None, "PyTree shape must not be none."
+ return tree_unflatten(qjit_function.out_treedef, data)
def get_derivative_qjit(self, argnums):
"""Compile a function computing the derivative of the wrapped QJIT for the given argnums."""
diff --git a/frontend/catalyst/pennylane_extensions.py b/frontend/catalyst/pennylane_extensions.py
index b73932ecf2..e141439359 100644
--- a/frontend/catalyst/pennylane_extensions.py
+++ b/frontend/catalyst/pennylane_extensions.py
@@ -202,14 +202,14 @@ def dec_no_params(fn):
Differentiable = Union[Function, QNode]
-DifferentiableLike = Union[Differentiable, Callable, "catalyst.compilation_pipelines.QJIT"]
+DifferentiableLike = Union[Differentiable, Callable, "catalyst.QJIT"]
def _ensure_differentiable(f: DifferentiableLike) -> Differentiable:
"""Narrows down the set of the supported differentiable objects."""
# Unwrap the function from an existing QJIT object.
- if isinstance(f, catalyst.compilation_pipelines.QJIT):
+ if isinstance(f, catalyst.QJIT):
f = f.user_function
if isinstance(f, (Function, QNode)):
diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py
index c212d61b4e..0396a95592 100644
--- a/frontend/catalyst/tracing/type_signatures.py
+++ b/frontend/catalyst/tracing/type_signatures.py
@@ -77,8 +77,7 @@ def filter_static_args(args, static_argnums):
return tuple(args[idx] for idx in range(len(args)) if idx not in static_argnums)
-# TODO: remove pragma in part 2
-def split_static_args(args, static_argnums): # pragma: nocover
+def split_static_args(args, static_argnums):
"""Split arguments into static and dynamic values using the provided index list.
Args:
@@ -121,8 +120,7 @@ def merge_static_args(signature, args, static_argnums):
return tuple(merged_sig)
-# TODO: remove pragma in part 2
-def get_decomposed_signature(args, static_argnums): # pragma: nocover
+def get_decomposed_signature(args, static_argnums):
"""Decompose function arguments into dynamic and static arguments, where the dynamic arguments
are further processed into abstract values and PyTree metadata. All values returned by this
function are hashable.
diff --git a/frontend/test/lit/test_tensor_ops.mlir.py b/frontend/test/lit/test_tensor_ops.mlir.py
index 9ba48b914d..fc8a9c8542 100644
--- a/frontend/test/lit/test_tensor_ops.mlir.py
+++ b/frontend/test/lit/test_tensor_ops.mlir.py
@@ -18,6 +18,7 @@
from jax import numpy as jnp
from catalyst import measure, qjit
+from catalyst.debug import print_compilation_stage
# Test methodology:
# Each mathematical function found in numpy
@@ -42,7 +43,7 @@ def test_ewise_arctan2(x, y):
test_ewise_arctan2(jnp.array(1.0), jnp.array(2.0))
-test_ewise_arctan2.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_arctan2, "BufferizationPass")
# Need more time to test
# jnp.ldexp
@@ -75,7 +76,7 @@ def test_ewise_add(x, y):
test_ewise_add(jnp.array(1.0), jnp.array(2.0))
-test_ewise_add.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_add, "BufferizationPass")
# CHECK-LABEL: test_ewise_mult
@@ -90,7 +91,7 @@ def test_ewise_mult(x, y):
test_ewise_mult(jnp.array(1.0), jnp.array(2.0))
-test_ewise_mult.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_mult, "BufferizationPass")
# CHECK-LABEL: test_ewise_div
@@ -105,7 +106,7 @@ def test_ewise_div(x, y):
test_ewise_div(jnp.array(1.0), jnp.array(2.0))
-test_ewise_div.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_div, "BufferizationPass")
# CHECK-LABEL: test_ewise_power
@@ -120,7 +121,7 @@ def test_ewise_power(x, y):
test_ewise_power(jnp.array(1.0), jnp.array(2.0))
-test_ewise_power.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_power, "BufferizationPass")
# CHECK-LABEL: test_ewise_sub
@@ -135,7 +136,7 @@ def test_ewise_sub(x, y):
test_ewise_sub(jnp.array(1.0), jnp.array(2.0))
-test_ewise_sub.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_sub, "BufferizationPass")
@qjit(keep_intermediate=True)
@@ -150,7 +151,7 @@ def test_ewise_true_div(x, y):
test_ewise_true_div(jnp.array(1.0), jnp.array(2.0))
-test_ewise_true_div.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_true_div, "BufferizationPass")
# Not sure why the following ops are not working
# perhaps they rely on another function?
@@ -169,7 +170,7 @@ def test_ewise_float_power(x, y):
test_ewise_float_power(jnp.array(1.0), jnp.array(2.0))
-test_ewise_float_power.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_float_power, "BufferizationPass")
# Not sure why the following ops are not working
@@ -194,7 +195,7 @@ def test_ewise_maximum(x, y):
test_ewise_maximum(jnp.array(1.0), jnp.array(2.0))
-test_ewise_maximum.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_maximum, "BufferizationPass")
# Only single function support
# * jnp.fmax
@@ -212,7 +213,7 @@ def test_ewise_minimum(x, y):
test_ewise_minimum(jnp.array(1.0), jnp.array(2.0))
-test_ewise_minimum.print_stage("BufferizationPass")
+print_compilation_stage(test_ewise_minimum, "BufferizationPass")
# Only single function support
# * jnp.fmin
diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py
index f3afd75fc5..b7cc3e5f36 100644
--- a/frontend/test/pytest/test_autograph.py
+++ b/frontend/test/pytest/test_autograph.py
@@ -169,6 +169,8 @@ def test_unsupported_object(self):
class FN:
"""Test object."""
+ __name__ = "unknown"
+
def __call__(self, x):
return x**2
diff --git a/frontend/test/pytest/test_c_template.py b/frontend/test/pytest/test_c_template.py
index 6827f9eb0e..742125db56 100644
--- a/frontend/test/pytest/test_c_template.py
+++ b/frontend/test/pytest/test_c_template.py
@@ -15,13 +15,8 @@
"""Unit tests for contents of c_template
"""
import numpy as np
-import pennylane as qml
-import pytest
-from catalyst import qjit
-from catalyst.debug import get_cmain
from catalyst.utils.c_template import CType, CVariable
-from catalyst.utils.exceptions import CompileError
class TestCType:
@@ -113,68 +108,3 @@ def test_get_strides_nonzero_rank(self):
x = np.array([1])
# pylint: disable=protected-access
assert CVariable._get_strides(x) == "1"
-
-
-# pylint: disable=too-few-public-methods
-class TestCProgramGeneration:
- """Test C Program generation"""
-
- def test_program_generation(self):
- """Test C Program generation"""
- dev = qml.device("lightning.qubit", wires=2)
-
- @qjit
- @qml.qnode(dev)
- def f(x: float):
- """Returns two states."""
- qml.RX(x, wires=1)
- return qml.state(), qml.state()
-
- template = get_cmain(f, 4.0)
- assert "main" in template
- assert "struct result_t result_val;" in template
- assert "buff_0 = 4.0" in template
- assert "arg_0 = { &buff_0, &buff_0, 0 }" in template
- assert "_catalyst_ciface_jit_f(&result_val, &arg_0);" in template
-
- def test_program_without_return_nor_arguments(self):
- """Test program without return value nor arguments."""
-
- @qjit
- def f():
- """No-op function."""
- return None
-
- template = get_cmain(f)
- assert "struct result_t result_val;" not in template
- assert "buff_0" not in template
- assert "arg_0" not in template
-
-
-class TestCProgramGenerationErrors:
- """Test errors raised from the c program generation feature."""
-
- def test_raises_error_if_tracing(self):
- """Test errors if c program generation requested during tracing."""
-
- @qjit
- def f(x: float):
- """Identity function."""
- return x
-
- with pytest.raises(CompileError, match="C interface cannot be generated"):
-
- @qjit
- def error_fn(x: float):
- """Should raise an error as we try to generate the C template during tracing."""
- return get_cmain(f, x)
-
- def test_error_non_qjit_object(self):
- """An error should be raised if the object supplied to the debug function is not a QJIT."""
-
- def f(x: float):
- """Identity function."""
- return x
-
- with pytest.raises(TypeError, match="First argument needs to be a 'QJIT' object"):
- get_cmain(f, 0.5)
diff --git a/frontend/test/pytest/test_compiler.py b/frontend/test/pytest/test_compiler.py
index 769e05b8aa..01d1ab702b 100644
--- a/frontend/test/pytest/test_compiler.py
+++ b/frontend/test/pytest/test_compiler.py
@@ -160,13 +160,13 @@ def run_from_ir(self, *_args, **_kwargs):
filename = str(pathlib.Path(output).absolute())
return filename, "", ["", ""]
- @qjit(target="fake_binary")
+ @qjit(target="mlir")
@qml.qnode(qml.device(backend, wires=1))
def cpp_exception_test():
return None
cpp_exception_test.compiler = MockCompiler(cpp_exception_test.compiler.options)
- compiled_function = cpp_exception_test.compile()
+ compiled_function, _ = cpp_exception_test.compile()
with pytest.raises(RuntimeError, match="Hello world"):
compiled_function()
@@ -180,6 +180,31 @@ def test_linker_driver_invalid_file(self):
class TestCompilerState:
"""Test states that the compiler can reach."""
+ def test_invalid_target(self):
+ """Test that nothing happens in AOT compilation when an invalid target is provided."""
+
+ @qjit(target="hello")
+ def f():
+ return 0
+
+ assert f.jaxpr is None
+ assert f.mlir is None
+ assert f.qir is None
+ assert f.compiled_function is None
+
+ def test_callable_without_name(self):
+ """Test that a callable without __name__ property can be compiled, if it is otherwise
+ traceable."""
+
+ class NoNameClass:
+ def __call__(self, x):
+ return x
+
+ f = qjit(NoNameClass())
+
+ assert f(3) == 3
+ assert f.__name__ == "unknown"
+
def test_print_stages(self, backend):
"""Test that after compiling the intermediate files exist."""
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index c03f358da8..4e1d79e148 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -51,7 +51,7 @@ def cond_fn():
def asline(text):
return " ".join(map(lambda x: x.strip(), str(text).split("\n")))
- assert asline(expected) == asline(circuit._jaxpr) # pylint: disable=protected-access
+ assert asline(expected) == asline(circuit.jaxpr)
class TestCond:
diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py
index ec26ac3045..5506b33d53 100644
--- a/frontend/test/pytest/test_debug.py
+++ b/frontend/test/pytest/test_debug.py
@@ -19,7 +19,7 @@
from catalyst import debug, for_loop, qjit
from catalyst.compiler import CompileOptions, Compiler
-from catalyst.debug import compile_from_mlir
+from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage
from catalyst.utils.exceptions import CompileError
from catalyst.utils.runtime import get_lib_path
@@ -189,6 +189,34 @@ def func2():
assert out == "hello\ngoodbye\n"
+class TestPrintStage:
+ """Test that compilation pipeline results can be printed."""
+
+ def test_hlo_lowering_stage(self, capsys):
+ """Test that the IR can be printed after the HLO lowering pipeline."""
+
+ @qjit(keep_intermediate=True)
+ def func():
+ return 0
+
+ print_compilation_stage(func, "HLOLoweringPass")
+
+ out, err = capsys.readouterr()
+ assert "@jit_func() -> tensor" in out
+ assert "stablehlo.constant" not in out
+
+ func.workspace.cleanup()
+
+ def test_invalid_object(self):
+ """Test the function on a non-QJIT object."""
+
+ def func():
+ return 0
+
+ with pytest.raises(TypeError, match="needs to be a 'QJIT' object"):
+ print_compilation_stage(func, "HLOLoweringPass")
+
+
class TestCompileFromIR:
"""Test the debug feature that compiles from a string representation of the IR."""
@@ -287,5 +315,78 @@ def test_parsing_errors(self):
assert "Failed to parse module as LLVM source" in e.value.args[0]
+class TestCProgramGeneration:
+ """Test C Program generation"""
+
+ def test_program_generation(self):
+ """Test C Program generation"""
+ dev = qml.device("lightning.qubit", wires=2)
+
+ @qjit
+ @qml.qnode(dev)
+ def f(x: float):
+ """Returns two states."""
+ qml.RX(x, wires=1)
+ return qml.state(), qml.state()
+
+ template = get_cmain(f, 4.0)
+ assert "main" in template
+ assert "struct result_t result_val;" in template
+ assert "buff_0 = 4.0" in template
+ assert "arg_0 = { &buff_0, &buff_0, 0 }" in template
+ assert "_catalyst_ciface_jit_f(&result_val, &arg_0);" in template
+
+ def test_program_without_return_nor_arguments(self):
+ """Test program without return value nor arguments."""
+
+ @qjit
+ def f():
+ """No-op function."""
+ return None
+
+ template = get_cmain(f)
+ assert "struct result_t result_val;" not in template
+ assert "buff_0" not in template
+ assert "arg_0" not in template
+
+ def test_generation_with_promotion(self):
+ """Test that C program generation works on QJIT objects and args that require promotion."""
+
+ @qjit
+ def f(x: float):
+ """Identity function."""
+ return x
+
+ template = get_cmain(f, 1)
+
+ assert "main" in template
+ assert "buff_0 = 1.0" in template # argument was automaatically promoted
+
+ def test_raises_error_if_tracing(self):
+ """Test errors if c program generation requested during tracing."""
+
+ @qjit
+ def f(x: float):
+ """Identity function."""
+ return x
+
+ with pytest.raises(CompileError, match="C interface cannot be generated"):
+
+ @qjit
+ def error_fn(x: float):
+ """Should raise an error as we try to generate the C template during tracing."""
+ return get_cmain(f, x)
+
+ def test_error_non_qjit_object(self):
+ """An error should be raised if the object supplied to the debug function is not a QJIT."""
+
+ def f(x: float):
+ """Identity function."""
+ return x
+
+ with pytest.raises(TypeError, match="First argument needs to be a 'QJIT' object"):
+ get_cmain(f, 0.5)
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py
index 7ca44e22a5..36c8e3c3e6 100644
--- a/frontend/test/pytest/test_jax_integration.py
+++ b/frontend/test/pytest/test_jax_integration.py
@@ -22,7 +22,7 @@
import pytest
from catalyst import for_loop, measure, qjit
-from catalyst.compilation_pipelines import JAX_QJIT
+from catalyst.jit import JAX_QJIT
class TestJAXJIT:
diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py
index d19ef6ffae..e1646357ad 100644
--- a/frontend/test/pytest/test_static_arguments.py
+++ b/frontend/test/pytest/test_static_arguments.py
@@ -25,7 +25,7 @@
class TestStaticArguments:
"""Test QJIT with static arguments."""
- @pytest.mark.parametrize("argnums", [(()), (None)])
+ @pytest.mark.parametrize("argnums", [(), None])
def test_zero_static_argument(self, argnums):
"""Test QJIT with zero static argument."""
@@ -37,7 +37,7 @@ def f(
assert f(1) == 1
- @pytest.mark.parametrize("argnums", [(-1), (100)])
+ @pytest.mark.parametrize("argnums", [-1, 100])
def test_out_of_bounds_static_argument(self, argnums):
"""Test QJIT with invalid static argument index."""
diff --git a/frontend/test/pytest/test_tracing.py b/frontend/test/pytest/test_tracing.py
new file mode 100644
index 0000000000..9b7b84a38f
--- /dev/null
+++ b/frontend/test/pytest/test_tracing.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for Catalyst's tracing module."""
+
+import jax
+import pytest
+
+from catalyst.jax_tracer import lower_jaxpr_to_mlir
+
+
+def test_jaxpr_lowering_without_dynshapes():
+ """Test that the lowering function can be used without Catalyst's dynamic shape support."""
+
+ def f():
+ return 0
+
+ jaxpr = jax.make_jaxpr(f)()
+ result, _ = lower_jaxpr_to_mlir(jaxpr, "test_fn")
+
+ assert "@jit_test_fn() -> tensor" in str(result)
+
+
+if __name__ == "__main__":
+ pytest.main(["-x", __file__])