diff --git a/doc/changelog.md b/doc/changelog.md index 738f79e004..61dec29de4 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -51,6 +51,44 @@

Improvements

+* Catalyst will now remember previously compiled functions when the PyTree metadata of arguments + changes, in addition to already rememebering compiled functions when static arguments change. + [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531) + + The following example will no longer trigger a third compilation: + ```py + @qjit + def func(x): + print("compiling") + return x + ``` + ```pycon + >>> func([1,]); # list + compiling + >>> func((2,)); # tuple + compiling + >>> func([3,]); # list + + ``` + + Note however that in order to keep overheads low, changing the argument *type* or *shape* (in a + promotion incompatible way) may override a previously stored function (with identical PyTree + metadata and static argument values): + ```py + @qjit + def func(x): + print("compiling") + return x + ``` + ```pycon + >>> func(jnp.array(1)); # scalar + compiling + >>> func(jnp.array([2.])); # 1-D array + compiling + >>> func(jnp.array(3)); # scalar + compiling + ``` + * Keep the structure of the function return when taking the derivatives, JVP and VJP (pytrees support). [(#500)](https://github.com/PennyLaneAI/catalyst/pull/500) [(#501)](https://github.com/PennyLaneAI/catalyst/pull/501) @@ -143,6 +181,24 @@

Breaking changes

+* The Catalyst Python frontend has been partially refactored. The impact on user-facing + functionality is minimal, but the location of certain classes and methods used by the package + may have changed. + [(#529)](https://github.com/PennyLaneAI/catalyst/pull/529) + [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531) + + The following changes have been made: + * Some debug methods and features on the QJIT class have been turned into free functions and moved + to the `catalyst.debug` module, which will now appear in the public documention. This includes + compiling a program from ir, obtaining a C program to invoke a compiled function from, and + printing fine-grained MLIR compilation stages. + * The `compilation_pipelines.py` module has been renamed to `jit.py`, and certain functionality + has been moved out (see following items). + * A new module `compiled_functions.py` now manages low-level access to compiled functions. + * A new module `tracing/type_signatures.py` handles functionality related managing arguments + and type signatures during the tracing process. + * The `contexts.py` module has been moved from `utils` to the new `tracing` sub-module. + * `QCtrl` is overriden and never used. [(#522)](https://github.com/PennyLaneAI/catalyst/pull/522) @@ -247,6 +303,11 @@

Bug fixes

+* Catalyst will no longer print a warning that recompilation is triggered when a `@qjit` decorated + function with no arguments is invoke without having been compiled first, for example via the use + of `target="mlir"`. + [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531) + * Only set `JAX_DYNAMIC_SHAPES` configuration option during `trace_to_mlir()`. [(#526)](https://github.com/PennyLaneAI/catalyst/pull/526) diff --git a/doc/dev/sharp_bits.rst b/doc/dev/sharp_bits.rst index 800fc1207c..e25c19b12a 100644 --- a/doc/dev/sharp_bits.rst +++ b/doc/dev/sharp_bits.rst @@ -310,7 +310,6 @@ array(0.16996714) However, deviating from this will result in recompilation and a warning message: >>> circuit(jnp.array([1.4, 1.4, 0.3, 0.1])) -catalyst/compilation_pipelines.py:592: UserWarning: Provided arguments did not match declared signature, recompiling... Tracing occurring array(0.16996714) diff --git a/frontend/README.rst b/frontend/README.rst index 7cb687e50b..388fbad36e 100644 --- a/frontend/README.rst +++ b/frontend/README.rst @@ -38,7 +38,7 @@ The ``catalyst`` Python package is a mixed Python package which relies on some C and the following modules: -- `compilation_pipelines.py `_: +- `jit.py `_: This module contains classes and decorators for just-in-time and ahead-of-time compilation of hybrid quantum-classical functions using Catalyst. diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 6b7b65f234..20cc2f8223 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -67,7 +67,8 @@ from catalyst import debug from catalyst.ag_utils import AutoGraphError, autograph_source -from catalyst.compilation_pipelines import QJIT, CompileOptions, qjit +from catalyst.compiler import CompileOptions +from catalyst.jit import QJIT, qjit from catalyst.pennylane_extensions import ( adjoint, cond, diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py index 1809e52d5c..ad1b9ca7e3 100644 --- a/frontend/catalyst/compiled_functions.py +++ b/frontend/catalyst/compiled_functions.py @@ -15,10 +15,12 @@ """This module contains classes to manage compiled functions and their underlying resources.""" import ctypes +from dataclasses import dataclass +from typing import Tuple import numpy as np from jax.interpreters import mlir -from jax.tree_util import tree_flatten, tree_unflatten +from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten from mlir_quantum.runtime import ( as_ctype, get_ranked_memref_descriptor, @@ -27,9 +29,15 @@ ) from catalyst.jax_extras import get_implicit_and_explicit_flat_args -from catalyst.tracing.type_signatures import filter_static_args +from catalyst.tracing.type_signatures import ( + TypeCompatibility, + filter_static_args, + get_decomposed_signature, + typecheck_signatures, +) from catalyst.utils import wrapper # pylint: disable=no-name-in-module from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type +from catalyst.utils.filesystem import Directory class SharedObjectManager: @@ -43,17 +51,19 @@ class SharedObjectManager: """ def __init__(self, shared_object_file, func_name): + self.shared_object_file = shared_object_file self.shared_object = None + self.func_name = func_name self.function = None self.setup = None self.teardown = None self.mem_transfer = None - self.open(shared_object_file, func_name) + self.open() - def open(self, shared_object_file, func_name): + def open(self): """Open the sharead object and load symbols.""" - self.shared_object = ctypes.CDLL(shared_object_file) - self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols(func_name) + self.shared_object = ctypes.CDLL(self.shared_object_file) + self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols() def close(self): """Close the shared object""" @@ -66,17 +76,14 @@ def close(self): # pylint: disable=protected-access dlclose(self.shared_object._handle) - def load_symbols(self, func_name): + def load_symbols(self): """Load symbols necessary for for execution of the compiled function. - Args: - func_name: name of compiled function to be executed - Returns: - function: function handle - setup: handle to the setup function, which initializes the device - teardown: handle to the teardown function, which tears down the device - mem_transfer: memory transfer shared object + CFuncPtr: handle to the main function of the program + CFuncPtr: handle to the setup function, which initializes the device + CFuncPtr: handle to the teardown function, which tears down the device + CFuncPtr: handle to the memory transfer function for program results """ setup = self.shared_object.setup @@ -88,7 +95,7 @@ def load_symbols(self, func_name): teardown.restypes = None # We are calling the c-interface - function = self.shared_object["_catalyst_pyface_" + func_name] + function = self.shared_object["_catalyst_pyface_" + self.func_name] # Guaranteed from _mlir_ciface specification function.restypes = None # Not needed, computed from the arguments. @@ -122,7 +129,6 @@ class CompiledFunction: """ def __init__(self, shared_object_file, func_name, restype, compile_options): - self.shared_object_file = shared_object_file self.shared_object = SharedObjectManager(shared_object_file, func_name) self.compile_options = compile_options self.return_type_c_abi = None @@ -140,7 +146,7 @@ def _exec(shared_object, has_return, numpy_dict, *args): *args: arguments to the function Returns: - retval: the value computed by the function or None if the function has no return value + the return values computed by the function or None if the function has no results """ with shared_object as lib: @@ -256,8 +262,8 @@ def args_to_memref_descs(self, restype, args): args: the JAX arrays to be used as arguments to the function Returns: - c_abi_args: a list of memref descriptor pointers to return values and parameters - numpy_arg_buffer: A list to the return values. It must be kept around until the function + List: a list of memref descriptor pointers to return values and parameters + List: A list to the return values. It must be kept around until the function finishes execution as the memref descriptors will point to memory locations inside numpy arrays. @@ -327,3 +333,126 @@ def __call__(self, *args, **kwargs): ) return result + + +@dataclass +class CacheKey: + """A key by which to identify entries in the compiled function cache. + + The cache only rembers one compiled function for each possible combination of: + - dynamic argument PyTree metadata + - static argument values + """ + + treedef: PyTreeDef + static_args: Tuple + + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + # pylint: disable=attribute-defined-outside-init + self._hash = hash((self.treedef, self.static_args)) + return self._hash + + +@dataclass +class CacheEntry: + """An entry in the compiled function cache. + + For each compiled function, the cache stores the dynamic argument signature, the output PyTree + definition, as well as the workspace in which compilation takes place. + """ + + compiled_fn: CompiledFunction + signature: Tuple + out_treedef: PyTreeDef + workspace: Directory + + +class CompilationCache: + """Class to manage CompiledFunction instances and retrieving previously compiled versions of + a function. + + A full match requires the following properties to match: + - dynamic argument signature (the shape and dtype of flattened arrays) + - dynamic argument PyTree definitions + - static argument values + + In order to allow some flexibility in the type of arguments provided by the user, a match is + also produced if the dtype of dynamic arguments can be promoted to the dtype of an existing + signature via JAX type promotion rules. Additional leniency is provided in the shape of + dynamic arguments in accordance with the abstracted axis specification. + + To avoid excess promotion checks, only one function version is stored at a time for a given + combination of PyTreeDefs and static arguments. + """ + + def __init__(self, static_argnums, abstracted_axes): + self.static_argnums = static_argnums + self.abstracted_axes = abstracted_axes + self.cache = {} + + def get_function_status_and_key(self, args): + """Check if the provided arguments match an existing function in the cache. The cache + status of the function is returned as a compilation action: + - no match: requires compilation + - partial match: requires argument promotion + - full match: skip promotion + + Args: + args: arguments to match to existing functions + + Returns: + TypeCompatibility + CacheKey | None + """ + if not self.cache: + return TypeCompatibility.NEEDS_COMPILATION, None + + flat_runtime_sig, treedef, static_args = get_decomposed_signature(args, self.static_argnums) + key = CacheKey(treedef, static_args) + if key not in self.cache: + return TypeCompatibility.NEEDS_COMPILATION, None + + entry = self.cache[key] + runtime_signature = tree_unflatten(treedef, flat_runtime_sig) + action = typecheck_signatures(entry.signature, runtime_signature, self.abstracted_axes) + return action, key + + def lookup(self, args): + """Get a function (if present) that matches the provided argument signature. Also computes + whether promotion is necessary. + + Args: + args (Iterable): the arguments to match to existing functions + + Returns: + CacheEntry | None: the matched cache entry + bool: whether the matched entry requires argument promotion + """ + action, key = self.get_function_status_and_key(args) + + if action == TypeCompatibility.NEEDS_COMPILATION: + return None, None + elif action == TypeCompatibility.NEEDS_PROMOTION: + return self.cache[key], True + else: + assert action == TypeCompatibility.CAN_SKIP_PROMOTION + return self.cache[key], False + + def insert(self, fn, args, out_treedef, workspace): + """Inserts the provided function into the cache. + + Args: + fn (CompiledFunction): compilation result + args (Iterable): arguments to determine cache key and additional metadata from + out_treedef (PyTreeDef): the output shape of the function + workspace (Directory): directory where compilation artifacts are stored + """ + assert isinstance(fn, CompiledFunction) + + flat_signature, treedef, static_args = get_decomposed_signature(args, self.static_argnums) + signature = tree_unflatten(treedef, flat_signature) + + key = CacheKey(treedef, static_args) + entry = CacheEntry(fn, signature, out_treedef, workspace) + self.cache[key] = entry diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index c5e04d53bd..66d6067cb5 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -524,10 +524,3 @@ def get_output_of(self, pipeline) -> Optional[str]: return None return self.last_compiler_output.get_pipeline_output(pipeline) - - def print(self, pipeline): - """Print the output IR of pass. - Args: - pipeline (str): name of pass class - """ - print(self.get_output_of(pipeline)) # pragma: no cover diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py index 58438bdf43..5557eb5f21 100644 --- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py @@ -840,12 +840,11 @@ def cudaq_backend_info(device): # We could also pass abstract arguments here in *args # the same way we do so in Catalyst. # But I think that is redundant now given make_jaxpr2 - _, jaxpr, _, out_tree = trace_to_jaxpr(func, static_args, abs_axes, *args) + jaxpr, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {}) # TODO(@erick-xanadu): # What about static_args? - # We could return _out_type2 as well - return jaxpr, out_tree + return jaxpr, out_treedef def interpret(fun): diff --git a/frontend/catalyst/debug.py b/frontend/catalyst/debug.py index 5622b5bcc0..0e29ce9403 100644 --- a/frontend/catalyst/debug.py +++ b/frontend/catalyst/debug.py @@ -25,6 +25,7 @@ from catalyst.compiler import Compiler from catalyst.jax_primitives import print_p from catalyst.tracing.contexts import EvaluationContext +from catalyst.tracing.type_signatures import filter_static_args, promote_arguments from catalyst.utils.filesystem import WorkspaceManager @@ -74,6 +75,36 @@ def func(x: float): builtins.print(x) +def print_compilation_stage(fn, stage): + """Print one of the recorded compilation stages for a JIT-compiled function. + + The stages are indexed by their Catalyst compilation pipeline name, which are either provided + by the user as a compilation option, or predefined in ``catalyst.compiler``. + + Requires ``keep_intermediate=True``. + + Args: + fn (QJIT): a qjit-decorated function + stage (str): string corresponding with the name of the stage to be printed + + **Example** + + .. code-block:: python + + @qjit(keep_intermediate=True) + def func(x: float): + return x + + debug.print_compilation_stage(func, "HLOLoweringPass") + """ + EvaluationContext.check_is_not_tracing("C interface cannot be generated from tracing context.") + + if not isinstance(fn, catalyst.QJIT): + raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.") + + print(fn.compiler.get_output_of(stage)) + + def get_cmain(fn, *args): """Return a C program that calls a jitted function with the provided arguments. @@ -89,13 +120,13 @@ def get_cmain(fn, *args): if not isinstance(fn, catalyst.QJIT): raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.") - # TODO: will be removed in part 2 of the refactor - # pylint: disable=protected-access - complied_function, args = fn._ensure_real_arguments_and_formal_parameters_are_compatible( - fn.compiled_function, *args - ) + requires_promotion = fn.jit_compile(args) + + if requires_promotion: + dynamic_args = filter_static_args(args, fn.compile_options.static_argnums) + args = promote_arguments(fn.c_sig, dynamic_args) - return complied_function.get_cmain(*args) + return fn.compiled_function.get_cmain(*args) # pylint: disable=line-too-long @@ -132,7 +163,7 @@ def compile_from_mlir(ir, compiler=None, compile_options=None): } \""" - compiled_function = compile_from_mlir(ir) + compiled_function = debug.compile_from_mlir(ir) >>> compiled_function(0.1) [0.1] diff --git a/frontend/catalyst/jax_extras/__init__.py b/frontend/catalyst/jax_extras/__init__.py index 5b42c0174d..e9de9e51c4 100644 --- a/frontend/catalyst/jax_extras/__init__.py +++ b/frontend/catalyst/jax_extras/__init__.py @@ -23,6 +23,7 @@ ClosedJaxpr, DynamicJaxprTrace, DynamicJaxprTracer, + DynshapedClosedJaxpr, Jaxpr, PyTreeDef, PyTreeRegistry, @@ -40,7 +41,6 @@ infer_lambda_input_type, initial_style_jaxprs_with_common_consts1, initial_style_jaxprs_with_common_consts2, - jaxpr_remove_implicit, make_jaxpr2, make_jaxpr_effects, new_dynamic_main2, diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index e10ce69ba8..aad3e376f8 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -44,6 +44,7 @@ from jax.core import ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType from jax.core import Primitive as JaxprPrimitive from jax.core import ShapedArray, Trace, eval_jaxpr, gensym, thread_local_state +from jax.extend.linear_util import wrap_init from jax.interpreters.partial_eval import ( DynamicJaxprTrace, DynamicJaxprTracer, @@ -51,7 +52,6 @@ make_jaxpr_effects, ) from jax.lax import convert_element_type -from jax.linear_util import wrap_init from jax.tree_util import ( PyTreeDef, tree_flatten, @@ -68,6 +68,7 @@ __all__ = ( "ClosedJaxpr", + "DynshapedClosedJaxpr", "DynamicJaxprTrace", "DynamicJaxprTracer", "Jaxpr", @@ -84,7 +85,6 @@ "_initial_style_jaxpr", "_input_type_to_tracers", "_module_name_regex", - "jaxpr_remove_implicit", "make_jaxpr_effects", "make_jaxpr2", "new_dynamic_main2", @@ -102,6 +102,35 @@ map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin +class DynshapedClosedJaxpr(ClosedJaxpr): + """A wrapper class to handle implicit/explicit result information used by JAX for dynamically + shaped arrays. Can be used inplace of any other ClosedJaxpr instance.""" + + def __init__(self, jaxpr: Jaxpr, consts: Sequence, output_type: OutputType): + super().__init__(jaxpr, consts) + self.output_type = output_type + + def remove_implicit_results(self): + """Remove all implicit result values from this JAXPR. + + Returns: + ClosedJaxpr + """ + # Note: a more idiomatic way of doing this would be to re-trace the jaxpr and drop the + # unneeded tracers. + if not self.output_type: + return self + + jaxpr = self.jaxpr + out_keep = tuple(zip(*self.output_type))[1] + outvars = [o for o, keep in zip(jaxpr._outvars, out_keep) if keep] + filtered_jaxpr = Jaxpr( + jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, jaxpr.effects, jaxpr.debug_info + ) + + return ClosedJaxpr(filtered_jaxpr, self.consts) + + @contextmanager def transient_jax_config() -> Generator[None, None, None]: """Context manager which updates transient JAX configuration options, @@ -325,22 +354,6 @@ def get_implicit_and_explicit_flat_args(abstracted_axes, *args, **kwargs): return args_flat -def jaxpr_remove_implicit( - closed_jaxpr: ClosedJaxpr, out_type: OutputType -) -> tuple[ClosedJaxpr, OutputType]: - """Remove all the implicit result values of the ``closed_jaxpr``.""" - # Note: a more idiomatic way of doing this would be to re-trace the jaxpr and drop the unneeded - # tracers. - jaxpr = closed_jaxpr.jaxpr - out_keep = list(tuple(zip(*out_type))[1]) if len(out_type) > 0 else [] - outvars = [o for o, keep in zip(jaxpr._outvars, out_keep) if keep] - out_type2 = [o for o, keep in zip(out_type, out_keep) if keep] - jaxpr2 = Jaxpr( - jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, jaxpr.effects, jaxpr.debug_info - ) - return ClosedJaxpr(jaxpr2, closed_jaxpr.consts), out_type2 - - def make_jaxpr2( fun: Callable, static_argnums: Any | None = None, @@ -379,9 +392,9 @@ def make_jaxpr_f(*args, **kwargs): in_type, in_tree = abstractify(args, kwargs) f, out_tree_promise = flatten_fun(f, in_tree) f = annotate(f, in_type) - jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) - closed_jaxpr = ClosedJaxpr(jaxpr, consts) - return closed_jaxpr, out_type, out_tree_promise() + jaxpr, output_type, consts = trace_to_jaxpr_dynamic2(f) + closed_jaxpr = DynshapedClosedJaxpr(jaxpr, consts, output_type) + return closed_jaxpr, out_tree_promise() make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})" return make_jaxpr_f diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index e7e97175e4..8254603661 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -31,6 +31,7 @@ ClosedJaxpr, DynamicJaxprTrace, DynamicJaxprTracer, + DynshapedClosedJaxpr, PyTreeDef, PyTreeRegistry, ShapedArray, @@ -39,7 +40,6 @@ convert_element_type, deduce_avals, eval_jaxpr, - jaxpr_remove_implicit, jaxpr_to_mlir, make_jaxpr2, sort_eqns, @@ -100,7 +100,7 @@ def __init__(self, fn): self.__name__ = fn.__name__ def __call__(self, *args, **kwargs): - jaxpr, _, out_tree = make_jaxpr2(self.fn)(*args) + jaxpr, out_tree = make_jaxpr2(self.fn)(*args) def _eval_jaxpr(*args): return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) @@ -332,11 +332,12 @@ def has_nested_tapes(op: Operation) -> bool: ) -def trace_to_jaxpr(func, static_argnums, abstracted_axes, *args, **kwargs): - """Trace a function to JAXPR. +def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs): + """Trace a Python function to JAXPR. Args: - func: python function to be lowered + func: python function to be traced + static_argnums: indices of static arguments. abstracted_axes: abstracted axes specification. Necessary for JAX to use dynamic tensor sizes. args: arguments to ``func`` @@ -344,44 +345,32 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, *args, **kwargs): Returns: ClosedJaxpr: the Jaxpr program corresponding to ``func`` - ClosedJaxpr: the Jaxpr program corresponding to ``func`` without implicit result values. - jax.OutputType: Jaxpr output type (a list of abstract values paired with - explicintess flags). PyTreeDef: PyTree-shape of the return values in ``PyTreeDef`` """ - with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION): - make_jaxpr_kwargs = { - "abstracted_axes": abstracted_axes, - "static_argnums": static_argnums, - } - jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs) - - # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program. - jaxpr2, out_type2 = jaxpr_remove_implicit(jaxpr, out_type) + with transient_jax_config(): + with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION): + make_jaxpr_kwargs = { + "static_argnums": static_argnums, + "abstracted_axes": abstracted_axes, + } + jaxpr, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs) - return jaxpr, jaxpr2, out_type2, out_tree + return jaxpr, out_treedef -def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs): - """Lower a Python function into an MLIR module. +def lower_jaxpr_to_mlir(jaxpr, func_name): + """Lower a JAXPR to MLIR. Args: - func: python function to be lowered - static_argnums: indices of static arguments. - abstracted_axes: abstracted axes specification. Necessary for JAX to use dynamic tensor - sizes. - args: arguments to ``func`` - kwargs: keyword arguments to ``func`` + ClosedJaxpr: the JAXPR to lower to MLIR + func_name: a name to use for the MLIR function Returns: - MLIR module: the MLIR module corresponding to ``func`` - MLIR context: the MLIR context - ClosedJaxpr: the Jaxpr program corresponding to ``func`` - jax.OutputType: Jaxpr output type (a list of abstract values paired with - explicintess flags). - PyTreeDef: PyTree-shape of the return values in ``PyTreeDef`` + ir.Module: the MLIR module coontaining the JAX program + ir.Context: the MLIR context """ + # The compilation cache must be clear for each translation unit. Otherwise, MLIR functions # which do not exist in the current translation unit will be assumed to exist if an equivalent # python function is seen in the cache. This happens during testing or if we wanted to compile a @@ -389,12 +378,13 @@ def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs): mlir_fn_cache.clear() with transient_jax_config(): - jaxpr, postprocessed_jaxpr, out_type, out_tree = trace_to_jaxpr( - func, static_argnums, abstracted_axes, *args, **kwargs - ) - module, context = jaxpr_to_mlir(func.__name__, postprocessed_jaxpr) + # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program. + if isinstance(jaxpr, DynshapedClosedJaxpr): + jaxpr = jaxpr.remove_implicit_results() + + mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr) - return module, context, jaxpr, out_type, out_tree + return mlir_module, ctx def trace_quantum_tape( diff --git a/frontend/catalyst/compilation_pipelines.py b/frontend/catalyst/jit.py similarity index 70% rename from frontend/catalyst/compilation_pipelines.py rename to frontend/catalyst/jit.py index 958c64aedc..f8b7941dd8 100644 --- a/frontend/catalyst/compilation_pipelines.py +++ b/frontend/catalyst/jit.py @@ -13,36 +13,33 @@ # limitations under the License. """This module contains classes and decorators for just-in-time and ahead-of-time -compiling of hybrid quantum-classical functions using Catalyst. +compilation of hybrid quantum-classical functions using Catalyst. """ +import copy import functools import inspect -import pathlib +import os import warnings -from copy import deepcopy import jax import jax.numpy as jnp import pennylane as qml -from jax.interpreters.mlir import ir +from jax.interpreters import mlir from jax.tree_util import tree_flatten, tree_unflatten import catalyst from catalyst.ag_utils import run_autograph -from catalyst.compiled_functions import CompiledFunction +from catalyst.compiled_functions import CompilationCache, CompiledFunction from catalyst.compiler import CompileOptions, Compiler -from catalyst.jax_tracer import trace_to_mlir -from catalyst.pennylane_extensions import QFunc +from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import ( - TypeCompatibility, filter_static_args, get_abstract_signature, get_type_annotations, merge_static_args, promote_arguments, - typecheck_signatures, ) from catalyst.utils.c_template import mlir_type_to_numpy_type from catalyst.utils.exceptions import CompileError @@ -69,150 +66,198 @@ class QJIT: the :func:`~.qjit` documentation for more details. Args: - fn (Callable): the quantum or classical function - compile_options (Optional[CompileOptions]): common compilation options + fn (Callable): the quantum or classical function to compile + compile_options (CompileOptions): compilation options to use """ def __init__(self, fn, compile_options): + self.original_function = fn self.compile_options = compile_options self.compiler = Compiler(compile_options) - self.original_function = fn - self.user_function = fn - self.jaxed_function = None + self.fn_cache = CompilationCache( + compile_options.static_argnums, compile_options.abstracted_axes + ) + # Active state of the compiler. + # TODO: rework ownership of workspace, possibly CompiledFunction + self.workspace = None + self.c_sig = None + self.out_treedef = None self.compiled_function = None + self.jaxed_function = None + # IRs are only available for the most recently traced function. + self.jaxpr = None + self.mlir = None # string form (historic presence) self.mlir_module = None - self.user_typed = False - self.c_sig = None - self.out_tree = None - self._jaxpr = None - self._mlir = None - self._llvmir = None - self.function_name = None - self.preferred_workspace_dir = None - self.stored_compiled_functions = {} - self.workspace_used = False + self.qir = None functools.update_wrapper(self, fn) + self.user_sig = get_type_annotations(fn) + self._validate_configuration() - if compile_options.autograph: - self.user_function = run_autograph(fn) + self.user_function = self.pre_compilation() - # QJIT is the owner of workspace. - # do not move to compiler. - preferred_workspace_dir = ( - pathlib.Path.cwd() if self.compile_options.keep_intermediate else None - ) - self.preferred_workspace_dir = preferred_workspace_dir + # Static arguments require values, so we cannot AOT compile. + if self.user_sig is not None and not self.compile_options.static_argnums: + self.aot_compile() - # pylint: disable=no-member - # Guaranteed to exist after functools.update_wrapper - self.function_name = self.__name__ + def __call__(self, *args, **kwargs): + # Transparantly call Python function in case of nested QJIT calls. + if EvaluationContext.is_tracing(): + return self.user_function(*args, **kwargs) - self.workspace = WorkspaceManager.get_or_create_workspace( - self.function_name, preferred_workspace_dir - ) + requires_promotion = self.jit_compile(args) - parameter_types = get_type_annotations(self.user_function) + # If we receive tracers as input, dispatch to the JAX integration. + if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]): + if self.jaxed_function is None: + self.jaxed_function = JAX_QJIT(self) # lazy gradient compilation + return self.jaxed_function(*args, **kwargs) - for argnum in self.compile_options.static_argnums: - if argnum < 0 or argnum >= len(parameter_types): - msg = f"argnum {argnum} is beyond the valid range of [0, {len(parameter_types)})." - raise CompileError(msg) + elif requires_promotion: + dynamic_args = filter_static_args(args, self.compile_options.static_argnums) + args = promote_arguments(self.c_sig, dynamic_args) - if parameter_types is not None and not self.compile_options.static_argnums: - self.user_typed = True - self.mlir_module = self.get_mlir(*parameter_types) - if self.compile_options.target == "binary": - self.compiled_function = self.compile() + return self.run(args, kwargs) - def print_stage(self, stage): - """Print one of the recorded stages. + def aot_compile(self): + """Compile Python function on initialization using the type hint signature.""" - Args: - stage: string corresponding with the name of the stage to be printed - """ - self.compiler.print(stage) # pragma: nocover + self.workspace = self._get_workspace() - @property - def mlir(self): - """str: Returns the MLIR intermediate representation - of the quantum program. - """ - return self._mlir + # TODO: awkward, refactor or redesign the target feature + if self.compile_options.target in ("jaxpr", "mlir", "binary"): + self.jaxpr, self.out_treedef, self.c_sig = self.capture(self.user_sig or ()) - @property - def jaxpr(self): - """str: Returns the JAXPR intermediate representation - of the quantum program. - """ - return self._jaxpr + if self.compile_options.target in ("mlir", "binary"): + self.mlir_module, self.mlir = self.generate_ir() - @property - def qir(self): - """str: Returns the LLVM and QIR intermediate representation - of the quantum program. Only available if the function was compiled to binary. - """ - return self._llvmir + if self.compile_options.target in ("binary",): + self.compiled_function, self.qir = self.compile() + self.fn_cache.insert( + self.compiled_function, self.user_sig, self.out_treedef, self.workspace + ) - def get_static_args_hash(self, *args): - """Get hash values of all static arguments. + def jit_compile(self, args): + """Compile Python function on invocation using the provided arguments. Args: - args: arguments to the compiled function. + args (Iterable): arguments to use for program capture + Returns: - a tuple of hash values of all static arguments. + bool: whether the provided arguments will require promotion to be used with the compiled + function """ - static_argnums = self.compile_options.static_argnums - static_args_hash = tuple( - hash(args[idx]) for idx in range(len(args)) if idx in static_argnums - ) - return static_args_hash - def get_mlir(self, *args): - """Trace :func:`~.user_function` + cached_fn, requires_promotion = self.fn_cache.lookup(args) + + if cached_fn is None: + if self.user_sig and not self.compile_options.static_argnums: + msg = "Provided arguments did not match declared signature, recompiling..." + warnings.warn(msg, UserWarning) + + # Cleanup before recompilation: + # - recompilation should always happen in new workspace + # - compiled functions for jax integration are not yet cached + # - close existing shared library + self.workspace = self._get_workspace() + self.jaxed_function = None + if self.compiled_function and self.compiled_function.shared_object: + self.compiled_function.shared_object.close() + + self.jaxpr, self.out_treedef, self.c_sig = self.capture(args) + self.mlir_module, self.mlir = self.generate_ir() + self.compiled_function, self.qir = self.compile() + + self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace) + + elif self.compiled_function is not cached_fn.compiled_fn: + # Restore active state from cache. + self.workspace = cached_fn.workspace + self.compiled_function = cached_fn.compiled_fn + self.out_treedef = cached_fn.out_treedef + self.c_sig = cached_fn.signature + self.jaxed_function = None + + self.compiled_function.shared_object.open() + + return requires_promotion + + # Processing Stages # + + def pre_compilation(self): + """Perform pre-processing tasks on the Python function, such as AST transformations.""" + processed_fn = self.original_function + + if self.compile_options.autograph: + processed_fn = run_autograph(self.original_function) + + return processed_fn + + def capture(self, args): + """Capture the JAX program representation (JAXPR) of the wrapped function. Args: - *args: either the concrete values to be passed as arguments to ``fn`` or abstract values + args (Iterable): arguments to use for program capture Returns: - an MLIR module + ClosedJaxpr: captured JAXPR + PyTreeDef: PyTree metadata of the function output + Tuple[Any]: the dynamic argument signature """ + static_argnums = self.compile_options.static_argnums + abstracted_axes = self.compile_options.abstracted_axes + dynamic_args = filter_static_args(args, static_argnums) - self.c_sig = get_abstract_signature(dynamic_args) + dynamic_sig = get_abstract_signature(dynamic_args) + full_sig = merge_static_args(dynamic_sig, args, static_argnums) with Patcher( - (qml.QNode, "__call__", QFunc.__call__), + (qml.QNode, "__call__", catalyst.pennylane_extensions.QFunc.__call__), ): - func = self.user_function - sig = merge_static_args(self.c_sig, args, static_argnums) - abstracted_axes = self.compile_options.abstracted_axes - mlir_module, ctx, jaxpr, _, self.out_tree = trace_to_mlir( - func, static_argnums, abstracted_axes, *sig + # TODO: improve PyTree handling + jaxpr, treedef = trace_to_jaxpr( + self.user_function, static_argnums, abstracted_axes, full_sig, {} ) + return jaxpr, treedef, dynamic_sig + + def generate_ir(self): + """Generate Catalyst's intermediate representation (IR) as an MLIR module. + + Returns: + Tuple[ir.Module, str]: the in-memory MLIR module and its string representation + """ + + mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__) + + # Inject Runtime Library-specific functions (e.g. setup/teardown). inject_functions(mlir_module, ctx) - self._jaxpr = jaxpr - canonicalizer_options = deepcopy(self.compile_options) - canonicalizer_options.pipelines = [("0_canonicalize", ["canonicalize"])] - canonicalizer_options.lower_to_llvm = False - canonicalizer = Compiler(canonicalizer_options) - _, self._mlir, _ = canonicalizer.run(mlir_module, self.workspace) - return mlir_module + + # Canonicalize the MLIR since there can be a lot of redundancy coming from JAX. + options = copy.deepcopy(self.compile_options) + options.pipelines = [("0_canonicalize", ["canonicalize"])] + options.lower_to_llvm = False + canonicalizer = Compiler(options) + + # TODO: the in-memory and textual form are different after this, consider unification + _, mlir_string, _ = canonicalizer.run(mlir_module, self.workspace) + + return mlir_module, mlir_string def compile(self): - """Compile the current MLIR module.""" + """Compile an MLIR module to LLVMIR and shared library code. - if self.compiled_function and self.compiled_function.shared_object: - self.compiled_function.shared_object.close() + Returns: + Tuple[CompiledFunction, str]: the compilation result and LLVMIR + """ - # WARNING: assumption is that the first function - # is the entry point to the compiled program. + # WARNING: assumption is that the first function is the entry point to the compiled program. entry_point_func = self.mlir_module.body.operations[0] restype = entry_point_func.type.results for res in restype: - baseType = ir.RankedTensorType(res).element_type + baseType = mlir.ir.RankedTensorType(res).element_type # This will make a check before sending it to the compiler that the return type # is actually available in most systems. f16 needs a special symbol and linking # will fail if it is not available. @@ -221,118 +266,51 @@ def compile(self): # The function name out of MLIR has quotes around it, which we need to remove. # The MLIR function name is actually a derived type from string which has no # `replace` method, so we need to get a regular Python string out of it. - qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "") + func_name = str(self.mlir_module.body.operations[0].name).replace('"', "") + shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace) - shared_object, llvm_ir, _inferred_func_data = self.compiler.run( - self.mlir_module, self.workspace - ) + shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace) + compiled_fn = CompiledFunction(shared_object, func_name, restype, self.compile_options) + + return compiled_fn, llvm_ir - self._llvmir = llvm_ir - options = self.compile_options - compiled_function = CompiledFunction(shared_object, qfunc_name, restype, options) - return compiled_function - - def _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args): - """Logic to decide whether the function needs to be recompiled - given ``*args`` and whether ``*args`` need to be promoted. - A function may need to be compiled if: - 1. It was not compiled before. Without static arguments, the compiled function - should be stored in ``self.compiled_function``. With static arguments, - ``self.compile_options.static_argnums`` stores all previous compiled ones. - 2. The real arguments sent to the function are not promotable to the type of the - formal parameters. + def run(self, args, kwargs): + """Invoke a previously compiled function with the supplied arguments. Args: - function: an instance of ``CompiledFunction`` that may need recompilation - *args: arguments that may be promoted. + args (Iterable): the positional arguments to the compiled function + kwargs: the keyword arguments to the compiled function Returns: - function: an instance of ``CompiledFunction`` that may have been recompiled - *args: arguments that may have been promoted + Any: results of the execution arranged into the original function's output PyTrees """ - static_argnums = self.compile_options.static_argnums - dynamic_args = filter_static_args(args, static_argnums) - r_sig = get_abstract_signature(dynamic_args) - - has_been_compiled = self.compiled_function is not None - if static_argnums: - static_args_hash = self.get_static_args_hash(*args) - prev_function, _ = self.stored_compiled_functions.get(static_args_hash, (None, None)) - has_been_compiled = False - if prev_function: - function = prev_function - has_been_compiled = True - function.shared_object.open(function.shared_object_file, function.func_name) - elif self.workspace_used: - # Create a new space for the new function if the workspace is used. - self.workspace = WorkspaceManager.get_or_create_workspace( - self.function_name, self.preferred_workspace_dir - ) - # The workspace is unused only for first compilation with static arguments. - self.workspace_used = True - - next_action = TypeCompatibility.UNKNOWN - if not has_been_compiled: - next_action = TypeCompatibility.NEEDS_COMPILATION - else: - abstracted_axes = self.compile_options.abstracted_axes - next_action = typecheck_signatures(self.c_sig, r_sig, abstracted_axes) - - if next_action == TypeCompatibility.NEEDS_PROMOTION: - args = promote_arguments(self.c_sig, dynamic_args) - elif next_action == TypeCompatibility.NEEDS_COMPILATION: - if self.user_typed: - msg = "Provided arguments did not match declared signature, recompiling..." - warnings.warn(msg, UserWarning) - sig = merge_static_args(r_sig, args, static_argnums) - self.mlir_module = self.get_mlir(*sig) - function = self.compile() - else: - assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION - - return function, args - - def __call__(self, *args, **kwargs): - static_argnums = self.compile_options.static_argnums - if EvaluationContext.is_tracing(): - return self.user_function(*args, **kwargs) + results = self.compiled_function(*args, **kwargs) - function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible( - self.compiled_function, *args - ) + # TODO: Move this to the compiled function object. + return tree_unflatten(self.out_treedef, results) - recompilation_happened = ( - function != self.compiled_function - ) and function not in self.stored_compiled_functions.values() + # Helper Methods # - # Check if a function is created and add newly created ones into the hash table. - if static_argnums and recompilation_happened: - static_args_hash = self.get_static_args_hash(*args) - workspace = self.workspace - self.stored_compiled_functions[static_args_hash] = (function, workspace) + def _validate_configuration(self): + """Run validations on the supplied options and parameters.""" + if not hasattr(self.original_function, "__name__"): + self.__name__ = "unknown" # allow these cases anyways? - self.compiled_function = function - - args_data, _args_shape = tree_flatten(args) - if any(isinstance(arg, jax.core.Tracer) for arg in args_data): - # Only compile a derivative version of the compiled function when needed. - if self.jaxed_function is None or recompilation_happened: - self.jaxed_function = JAX_QJIT(self) - - return self.jaxed_function(*args, **kwargs) - - data = self.compiled_function(*args, **kwargs) + # TODO: remove bug since this only works with user-provided signatures + parameter_types = self.user_sig or () + for argnum in self.compile_options.static_argnums: + if argnum < 0 or argnum >= len(parameter_types): + msg = f"argnum {argnum} is beyond the valid range of [0, {len(parameter_types)})." + raise CompileError(msg) - # Unflatten the return value w.r.t. the original PyTree definition if available - if self.out_tree is not None: - data = tree_unflatten(self.out_tree, data) + def _get_workspace(self): + """Get or create a workspace to use for compilation.""" - # For the classical and pennylane_extensions compilation path, - if isinstance(data, (list, tuple)) and len(data) == 1: - data = data[0] + workspace_name = self.__name__ + preferred_workspace_dir = os.getcwd() if self.compile_options.keep_intermediate else None - return data + return WorkspaceManager.get_or_create_workspace(workspace_name, preferred_workspace_dir) class JAX_QJIT: @@ -366,8 +344,8 @@ def wrap_callback(qjit_function, *args, **kwargs): ) # Unflatten the return value w.r.t. the original PyTree definition if available - assert qjit_function.out_tree is not None, "PyTree shape must not be none." - return tree_unflatten(qjit_function.out_tree, data) + assert qjit_function.out_treedef is not None, "PyTree shape must not be none." + return tree_unflatten(qjit_function.out_treedef, data) def get_derivative_qjit(self, argnums): """Compile a function computing the derivative of the wrapped QJIT for the given argnums.""" diff --git a/frontend/catalyst/pennylane_extensions.py b/frontend/catalyst/pennylane_extensions.py index b73932ecf2..e141439359 100644 --- a/frontend/catalyst/pennylane_extensions.py +++ b/frontend/catalyst/pennylane_extensions.py @@ -202,14 +202,14 @@ def dec_no_params(fn): Differentiable = Union[Function, QNode] -DifferentiableLike = Union[Differentiable, Callable, "catalyst.compilation_pipelines.QJIT"] +DifferentiableLike = Union[Differentiable, Callable, "catalyst.QJIT"] def _ensure_differentiable(f: DifferentiableLike) -> Differentiable: """Narrows down the set of the supported differentiable objects.""" # Unwrap the function from an existing QJIT object. - if isinstance(f, catalyst.compilation_pipelines.QJIT): + if isinstance(f, catalyst.QJIT): f = f.user_function if isinstance(f, (Function, QNode)): diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index c212d61b4e..0396a95592 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -77,8 +77,7 @@ def filter_static_args(args, static_argnums): return tuple(args[idx] for idx in range(len(args)) if idx not in static_argnums) -# TODO: remove pragma in part 2 -def split_static_args(args, static_argnums): # pragma: nocover +def split_static_args(args, static_argnums): """Split arguments into static and dynamic values using the provided index list. Args: @@ -121,8 +120,7 @@ def merge_static_args(signature, args, static_argnums): return tuple(merged_sig) -# TODO: remove pragma in part 2 -def get_decomposed_signature(args, static_argnums): # pragma: nocover +def get_decomposed_signature(args, static_argnums): """Decompose function arguments into dynamic and static arguments, where the dynamic arguments are further processed into abstract values and PyTree metadata. All values returned by this function are hashable. diff --git a/frontend/test/lit/test_tensor_ops.mlir.py b/frontend/test/lit/test_tensor_ops.mlir.py index 9ba48b914d..fc8a9c8542 100644 --- a/frontend/test/lit/test_tensor_ops.mlir.py +++ b/frontend/test/lit/test_tensor_ops.mlir.py @@ -18,6 +18,7 @@ from jax import numpy as jnp from catalyst import measure, qjit +from catalyst.debug import print_compilation_stage # Test methodology: # Each mathematical function found in numpy @@ -42,7 +43,7 @@ def test_ewise_arctan2(x, y): test_ewise_arctan2(jnp.array(1.0), jnp.array(2.0)) -test_ewise_arctan2.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_arctan2, "BufferizationPass") # Need more time to test # jnp.ldexp @@ -75,7 +76,7 @@ def test_ewise_add(x, y): test_ewise_add(jnp.array(1.0), jnp.array(2.0)) -test_ewise_add.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_add, "BufferizationPass") # CHECK-LABEL: test_ewise_mult @@ -90,7 +91,7 @@ def test_ewise_mult(x, y): test_ewise_mult(jnp.array(1.0), jnp.array(2.0)) -test_ewise_mult.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_mult, "BufferizationPass") # CHECK-LABEL: test_ewise_div @@ -105,7 +106,7 @@ def test_ewise_div(x, y): test_ewise_div(jnp.array(1.0), jnp.array(2.0)) -test_ewise_div.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_div, "BufferizationPass") # CHECK-LABEL: test_ewise_power @@ -120,7 +121,7 @@ def test_ewise_power(x, y): test_ewise_power(jnp.array(1.0), jnp.array(2.0)) -test_ewise_power.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_power, "BufferizationPass") # CHECK-LABEL: test_ewise_sub @@ -135,7 +136,7 @@ def test_ewise_sub(x, y): test_ewise_sub(jnp.array(1.0), jnp.array(2.0)) -test_ewise_sub.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_sub, "BufferizationPass") @qjit(keep_intermediate=True) @@ -150,7 +151,7 @@ def test_ewise_true_div(x, y): test_ewise_true_div(jnp.array(1.0), jnp.array(2.0)) -test_ewise_true_div.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_true_div, "BufferizationPass") # Not sure why the following ops are not working # perhaps they rely on another function? @@ -169,7 +170,7 @@ def test_ewise_float_power(x, y): test_ewise_float_power(jnp.array(1.0), jnp.array(2.0)) -test_ewise_float_power.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_float_power, "BufferizationPass") # Not sure why the following ops are not working @@ -194,7 +195,7 @@ def test_ewise_maximum(x, y): test_ewise_maximum(jnp.array(1.0), jnp.array(2.0)) -test_ewise_maximum.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_maximum, "BufferizationPass") # Only single function support # * jnp.fmax @@ -212,7 +213,7 @@ def test_ewise_minimum(x, y): test_ewise_minimum(jnp.array(1.0), jnp.array(2.0)) -test_ewise_minimum.print_stage("BufferizationPass") +print_compilation_stage(test_ewise_minimum, "BufferizationPass") # Only single function support # * jnp.fmin diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index f3afd75fc5..b7cc3e5f36 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -169,6 +169,8 @@ def test_unsupported_object(self): class FN: """Test object.""" + __name__ = "unknown" + def __call__(self, x): return x**2 diff --git a/frontend/test/pytest/test_c_template.py b/frontend/test/pytest/test_c_template.py index 6827f9eb0e..742125db56 100644 --- a/frontend/test/pytest/test_c_template.py +++ b/frontend/test/pytest/test_c_template.py @@ -15,13 +15,8 @@ """Unit tests for contents of c_template """ import numpy as np -import pennylane as qml -import pytest -from catalyst import qjit -from catalyst.debug import get_cmain from catalyst.utils.c_template import CType, CVariable -from catalyst.utils.exceptions import CompileError class TestCType: @@ -113,68 +108,3 @@ def test_get_strides_nonzero_rank(self): x = np.array([1]) # pylint: disable=protected-access assert CVariable._get_strides(x) == "1" - - -# pylint: disable=too-few-public-methods -class TestCProgramGeneration: - """Test C Program generation""" - - def test_program_generation(self): - """Test C Program generation""" - dev = qml.device("lightning.qubit", wires=2) - - @qjit - @qml.qnode(dev) - def f(x: float): - """Returns two states.""" - qml.RX(x, wires=1) - return qml.state(), qml.state() - - template = get_cmain(f, 4.0) - assert "main" in template - assert "struct result_t result_val;" in template - assert "buff_0 = 4.0" in template - assert "arg_0 = { &buff_0, &buff_0, 0 }" in template - assert "_catalyst_ciface_jit_f(&result_val, &arg_0);" in template - - def test_program_without_return_nor_arguments(self): - """Test program without return value nor arguments.""" - - @qjit - def f(): - """No-op function.""" - return None - - template = get_cmain(f) - assert "struct result_t result_val;" not in template - assert "buff_0" not in template - assert "arg_0" not in template - - -class TestCProgramGenerationErrors: - """Test errors raised from the c program generation feature.""" - - def test_raises_error_if_tracing(self): - """Test errors if c program generation requested during tracing.""" - - @qjit - def f(x: float): - """Identity function.""" - return x - - with pytest.raises(CompileError, match="C interface cannot be generated"): - - @qjit - def error_fn(x: float): - """Should raise an error as we try to generate the C template during tracing.""" - return get_cmain(f, x) - - def test_error_non_qjit_object(self): - """An error should be raised if the object supplied to the debug function is not a QJIT.""" - - def f(x: float): - """Identity function.""" - return x - - with pytest.raises(TypeError, match="First argument needs to be a 'QJIT' object"): - get_cmain(f, 0.5) diff --git a/frontend/test/pytest/test_compiler.py b/frontend/test/pytest/test_compiler.py index 769e05b8aa..01d1ab702b 100644 --- a/frontend/test/pytest/test_compiler.py +++ b/frontend/test/pytest/test_compiler.py @@ -160,13 +160,13 @@ def run_from_ir(self, *_args, **_kwargs): filename = str(pathlib.Path(output).absolute()) return filename, "", ["", ""] - @qjit(target="fake_binary") + @qjit(target="mlir") @qml.qnode(qml.device(backend, wires=1)) def cpp_exception_test(): return None cpp_exception_test.compiler = MockCompiler(cpp_exception_test.compiler.options) - compiled_function = cpp_exception_test.compile() + compiled_function, _ = cpp_exception_test.compile() with pytest.raises(RuntimeError, match="Hello world"): compiled_function() @@ -180,6 +180,31 @@ def test_linker_driver_invalid_file(self): class TestCompilerState: """Test states that the compiler can reach.""" + def test_invalid_target(self): + """Test that nothing happens in AOT compilation when an invalid target is provided.""" + + @qjit(target="hello") + def f(): + return 0 + + assert f.jaxpr is None + assert f.mlir is None + assert f.qir is None + assert f.compiled_function is None + + def test_callable_without_name(self): + """Test that a callable without __name__ property can be compiled, if it is otherwise + traceable.""" + + class NoNameClass: + def __call__(self, x): + return x + + f = qjit(NoNameClass()) + + assert f(3) == 3 + assert f.__name__ == "unknown" + def test_print_stages(self, backend): """Test that after compiling the intermediate files exist.""" diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index c03f358da8..4e1d79e148 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -51,7 +51,7 @@ def cond_fn(): def asline(text): return " ".join(map(lambda x: x.strip(), str(text).split("\n"))) - assert asline(expected) == asline(circuit._jaxpr) # pylint: disable=protected-access + assert asline(expected) == asline(circuit.jaxpr) class TestCond: diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index ec26ac3045..5506b33d53 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -19,7 +19,7 @@ from catalyst import debug, for_loop, qjit from catalyst.compiler import CompileOptions, Compiler -from catalyst.debug import compile_from_mlir +from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage from catalyst.utils.exceptions import CompileError from catalyst.utils.runtime import get_lib_path @@ -189,6 +189,34 @@ def func2(): assert out == "hello\ngoodbye\n" +class TestPrintStage: + """Test that compilation pipeline results can be printed.""" + + def test_hlo_lowering_stage(self, capsys): + """Test that the IR can be printed after the HLO lowering pipeline.""" + + @qjit(keep_intermediate=True) + def func(): + return 0 + + print_compilation_stage(func, "HLOLoweringPass") + + out, err = capsys.readouterr() + assert "@jit_func() -> tensor" in out + assert "stablehlo.constant" not in out + + func.workspace.cleanup() + + def test_invalid_object(self): + """Test the function on a non-QJIT object.""" + + def func(): + return 0 + + with pytest.raises(TypeError, match="needs to be a 'QJIT' object"): + print_compilation_stage(func, "HLOLoweringPass") + + class TestCompileFromIR: """Test the debug feature that compiles from a string representation of the IR.""" @@ -287,5 +315,78 @@ def test_parsing_errors(self): assert "Failed to parse module as LLVM source" in e.value.args[0] +class TestCProgramGeneration: + """Test C Program generation""" + + def test_program_generation(self): + """Test C Program generation""" + dev = qml.device("lightning.qubit", wires=2) + + @qjit + @qml.qnode(dev) + def f(x: float): + """Returns two states.""" + qml.RX(x, wires=1) + return qml.state(), qml.state() + + template = get_cmain(f, 4.0) + assert "main" in template + assert "struct result_t result_val;" in template + assert "buff_0 = 4.0" in template + assert "arg_0 = { &buff_0, &buff_0, 0 }" in template + assert "_catalyst_ciface_jit_f(&result_val, &arg_0);" in template + + def test_program_without_return_nor_arguments(self): + """Test program without return value nor arguments.""" + + @qjit + def f(): + """No-op function.""" + return None + + template = get_cmain(f) + assert "struct result_t result_val;" not in template + assert "buff_0" not in template + assert "arg_0" not in template + + def test_generation_with_promotion(self): + """Test that C program generation works on QJIT objects and args that require promotion.""" + + @qjit + def f(x: float): + """Identity function.""" + return x + + template = get_cmain(f, 1) + + assert "main" in template + assert "buff_0 = 1.0" in template # argument was automaatically promoted + + def test_raises_error_if_tracing(self): + """Test errors if c program generation requested during tracing.""" + + @qjit + def f(x: float): + """Identity function.""" + return x + + with pytest.raises(CompileError, match="C interface cannot be generated"): + + @qjit + def error_fn(x: float): + """Should raise an error as we try to generate the C template during tracing.""" + return get_cmain(f, x) + + def test_error_non_qjit_object(self): + """An error should be raised if the object supplied to the debug function is not a QJIT.""" + + def f(x: float): + """Identity function.""" + return x + + with pytest.raises(TypeError, match="First argument needs to be a 'QJIT' object"): + get_cmain(f, 0.5) + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index 7ca44e22a5..36c8e3c3e6 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -22,7 +22,7 @@ import pytest from catalyst import for_loop, measure, qjit -from catalyst.compilation_pipelines import JAX_QJIT +from catalyst.jit import JAX_QJIT class TestJAXJIT: diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py index d19ef6ffae..e1646357ad 100644 --- a/frontend/test/pytest/test_static_arguments.py +++ b/frontend/test/pytest/test_static_arguments.py @@ -25,7 +25,7 @@ class TestStaticArguments: """Test QJIT with static arguments.""" - @pytest.mark.parametrize("argnums", [(()), (None)]) + @pytest.mark.parametrize("argnums", [(), None]) def test_zero_static_argument(self, argnums): """Test QJIT with zero static argument.""" @@ -37,7 +37,7 @@ def f( assert f(1) == 1 - @pytest.mark.parametrize("argnums", [(-1), (100)]) + @pytest.mark.parametrize("argnums", [-1, 100]) def test_out_of_bounds_static_argument(self, argnums): """Test QJIT with invalid static argument index.""" diff --git a/frontend/test/pytest/test_tracing.py b/frontend/test/pytest/test_tracing.py new file mode 100644 index 0000000000..9b7b84a38f --- /dev/null +++ b/frontend/test/pytest/test_tracing.py @@ -0,0 +1,36 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Catalyst's tracing module.""" + +import jax +import pytest + +from catalyst.jax_tracer import lower_jaxpr_to_mlir + + +def test_jaxpr_lowering_without_dynshapes(): + """Test that the lowering function can be used without Catalyst's dynamic shape support.""" + + def f(): + return 0 + + jaxpr = jax.make_jaxpr(f)() + result, _ = lower_jaxpr_to_mlir(jaxpr, "test_fn") + + assert "@jit_test_fn() -> tensor" in str(result) + + +if __name__ == "__main__": + pytest.main(["-x", __file__])