From 96fef3bedba5a525b3f2eafd474ea7ba26bb78d9 Mon Sep 17 00:00:00 2001 From: David Ittah Date: Wed, 21 Feb 2024 17:32:53 -0500 Subject: [PATCH] Structural refactor of the compilation_pipelines module (#529) This is part 1 of a QJIT refactor which reorganizes some of the classes and methods in the `compilation_pipelines.py` module. The second part will refactor the `QJIT` class itself. Benefits: - a new `tracing` sub-module is created that will over time contain all functionality pertaining to program capture from Python - reusable functions are grouped together by their purpose, such as signature handling and debugging - classes and modules are slimmed down and more focused - improved docstrings Linting is not done yet, and I will also add changelog. [sc-57014] --------- Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> --- doc/dev/debugging.rst | 5 +- frontend/catalyst/ag_primitives.py | 2 +- frontend/catalyst/compilation_pipelines.py | 572 ++---------------- frontend/catalyst/compiled_functions.py | 329 ++++++++++ frontend/catalyst/debug.py | 95 ++- frontend/catalyst/jax_primitives.py | 2 +- frontend/catalyst/jax_tracer.py | 6 +- frontend/catalyst/pennylane_extensions.py | 6 +- .../catalyst/{utils => tracing}/contexts.py | 4 +- frontend/catalyst/tracing/type_signatures.py | 236 ++++++++ frontend/test/pytest/test_c_template.py | 17 +- frontend/test/pytest/test_compiler.py | 72 --- frontend/test/pytest/test_contexts.py | 2 +- frontend/test/pytest/test_debug.py | 103 ++++ frontend/test/pytest/test_jit_behaviour.py | 45 +- setup.py | 4 +- 16 files changed, 862 insertions(+), 638 deletions(-) create mode 100644 frontend/catalyst/compiled_functions.py rename frontend/catalyst/{utils => tracing}/contexts.py (98%) create mode 100644 frontend/catalyst/tracing/type_signatures.py diff --git a/doc/dev/debugging.rst b/doc/dev/debugging.rst index e1d7b63e0d..3111cfa114 100644 --- a/doc/dev/debugging.rst +++ b/doc/dev/debugging.rst @@ -35,9 +35,9 @@ 
Below is an example of how to obtain a C program that can be linked against the def identity(x): return x - print(identity.get_cmain(1.0)) + print(debug.get_cmain(identity, 1.0)) -Using the ``QJIT.get_cmain`` function, the following string is returned to the user: +Using the ``debug.get_cmain`` function, the following string is returned to the user: .. code-block:: C @@ -298,4 +298,3 @@ And finally some real LLVMIR adhering to the QIR specification: The LLVMIR code is compiled to an object file using the LLVM static compiler and linked to the runtime libraries. The generated shared object is stored by the caching mechanism in Catalyst for future calls. - diff --git a/frontend/catalyst/ag_primitives.py b/frontend/catalyst/ag_primitives.py index 8bfb6b4a71..7e2636484f 100644 --- a/frontend/catalyst/ag_primitives.py +++ b/frontend/catalyst/ag_primitives.py @@ -44,7 +44,7 @@ import catalyst from catalyst.ag_utils import AutoGraphError -from catalyst.utils.contexts import EvaluationContext +from catalyst.tracing.contexts import EvaluationContext from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray from catalyst.utils.patching import Patcher diff --git a/frontend/catalyst/compilation_pipelines.py b/frontend/catalyst/compilation_pipelines.py index 8d2049064e..66a5e54bf9 100644 --- a/frontend/catalyst/compilation_pipelines.py +++ b/frontend/catalyst/compilation_pipelines.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 Xanadu Quantum Technologies Inc. +# Copyright 2022-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,48 +11,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ """This module contains classes and decorators for just-in-time and ahead-of-time compiling of hybrid quantum-classical functions using Catalyst. """ # pylint: disable=too-many-lines -import ctypes import functools import inspect import pathlib -import typing import warnings from copy import deepcopy -from enum import Enum import jax import jax.numpy as jnp -import numpy as np import pennylane as qml -from jax._src.interpreters.partial_eval import infer_lambda_input_type -from jax._src.pjit import _flat_axes_specs from jax.interpreters.mlir import ir from jax.tree_util import tree_flatten, tree_unflatten -from mlir_quantum.runtime import ( - as_ctype, - get_ranked_memref_descriptor, - make_nd_memref_descriptor, - make_zero_d_memref_descriptor, -) import catalyst from catalyst.ag_utils import run_autograph +from catalyst.compiled_functions import CompiledFunction from catalyst.compiler import CompileOptions, Compiler from catalyst.jax_tracer import trace_to_mlir from catalyst.pennylane_extensions import QFunc -from catalyst.utils import wrapper # pylint: disable=no-name-in-module -from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type -from catalyst.utils.contexts import EvaluationContext +from catalyst.tracing.contexts import EvaluationContext +from catalyst.tracing.type_signatures import ( + TypeCompatibility, + filter_static_args, + get_abstract_signature, + get_type_annotations, + merge_static_args, + promote_arguments, + typecheck_signatures, +) +from catalyst.utils.c_template import mlir_type_to_numpy_type from catalyst.utils.exceptions import CompileError from catalyst.utils.filesystem import WorkspaceManager from catalyst.utils.gen_mlir import inject_functions -from catalyst.utils.jax_extras import get_aval2, get_implicit_and_explicit_flat_args from catalyst.utils.patching import Patcher # Required for JAX tracer objects as PennyLane wires. 
@@ -64,429 +61,6 @@ jax.config.update("jax_enable_x64", True) -def are_params_annotated(f: typing.Callable): - """Return true if all parameters are typed-annotated.""" - signature = inspect.signature(f) - parameters = signature.parameters - return all(p.annotation is not inspect.Parameter.empty for p in parameters.values()) - - -def get_type_annotations(func: typing.Callable): - """Get all type annotations if all parameters are typed-annotated.""" - params_are_annotated = are_params_annotated(func) - if params_are_annotated: - return getattr(func, "__annotations__", {}).values() - - return None - - -class SharedObjectManager: - """Shared object manager. - - Manages the life time of the shared object. When is it loaded, when to close it. - - Args: - shared_object_file: path to shared object containing compiled function - func_name: name of compiled function - """ - - def __init__(self, shared_object_file, func_name): - self.shared_object = None - self.function = None - self.setup = None - self.teardown = None - self.mem_transfer = None - self.open(shared_object_file, func_name) - - def open(self, shared_object_file, func_name): - """Open the sharead object and load symbols.""" - self.shared_object = ctypes.CDLL(shared_object_file) - self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols(func_name) - - def close(self): - """Close the shared object""" - self.function = None - self.setup = None - self.teardown = None - self.mem_transfer = None - dlclose = ctypes.CDLL(None).dlclose - dlclose.argtypes = [ctypes.c_void_p] - # pylint: disable=protected-access - dlclose(self.shared_object._handle) - - def load_symbols(self, func_name): - """Load symbols necessary for for execution of the compiled function. 
- - Args: - func_name: name of compiled function to be executed - - Returns: - function: function handle - setup: handle to the setup function, which initializes the device - teardown: handle to the teardown function, which tears down the device - mem_transfer: memory transfer shared object - """ - - setup = self.shared_object.setup - setup.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)] - setup.restypes = ctypes.c_int - - teardown = self.shared_object.teardown - teardown.argtypes = None - teardown.restypes = None - - # We are calling the c-interface - function = self.shared_object["_catalyst_pyface_" + func_name] - # Guaranteed from _mlir_ciface specification - function.restypes = None - # Not needed, computed from the arguments. - # function.argyptes - - mem_transfer = self.shared_object["_mlir_memory_transfer"] - - return function, setup, teardown, mem_transfer - - def __enter__(self): - params_to_setup = [b"jitted-function"] - argc = len(params_to_setup) - array_of_char_ptrs = (ctypes.c_char_p * len(params_to_setup))() - array_of_char_ptrs[:] = params_to_setup - self.setup(ctypes.c_int(argc), array_of_char_ptrs) - return self - - def __exit__(self, _type, _value, _traceback): - self.teardown() - - -class TypeCompatibility(Enum): - """Enum class for state machine. - - The state represent the transition between states. - """ - - UNKNOWN = 0 - CAN_SKIP_PROMOTION = 1 - NEEDS_PROMOTION = 2 - NEEDS_COMPILATION = 3 - - -class CompiledFunction: - """CompiledFunction, represents a Compiled Function. 
- - Args: - shared_object_file: path to shared object containing compiled function - func_name: name of compiled function - restype: list of MLIR tensor types representing the result of the compiled function - """ - - def __init__( - self, - shared_object_file, - func_name, - restype, - compile_options, - ): - self.shared_object_file = shared_object_file - self.shared_object = SharedObjectManager(shared_object_file, func_name) - self.return_type_c_abi = None - self.func_name = func_name - self.restype = restype - self.compile_options = compile_options - - @staticmethod - def typecheck(abstracted_axes, compiled_signature, runtime_signature): - """Whether arguments can be promoted. - - Args: - compiled_signature: user supplied signature, obtain from either an annotation or a - previously compiled implementation of the compiled function - runtime_signature: runtime signature - - Returns: - bool. - """ - compiled_data, compiled_shape = tree_flatten(compiled_signature) - runtime_data, runtime_shape = tree_flatten(runtime_signature) - with Patcher( - # pylint: disable=protected-access - (jax._src.interpreters.partial_eval, "get_aval", get_aval2), - ): - axes_specs_compile = _flat_axes_specs(abstracted_axes, *compiled_signature, {}) - axes_specs_runtime = _flat_axes_specs(abstracted_axes, *runtime_signature, {}) - in_type_compiled = infer_lambda_input_type(axes_specs_compile, compiled_data) - in_type_runtime = infer_lambda_input_type(axes_specs_runtime, runtime_data) - - if in_type_compiled == in_type_runtime: - return TypeCompatibility.CAN_SKIP_PROMOTION - - if compiled_shape != runtime_shape: - return TypeCompatibility.NEEDS_COMPILATION - - best_case = TypeCompatibility.CAN_SKIP_PROMOTION - for c_param, r_param in zip(compiled_data, runtime_data): - if c_param.dtype != r_param.dtype: - best_case = TypeCompatibility.NEEDS_PROMOTION - - if c_param.shape != r_param.shape: - return TypeCompatibility.NEEDS_COMPILATION - - promote_to = jax.numpy.promote_types(r_param.dtype, 
c_param.dtype) - if c_param.dtype != promote_to: - return TypeCompatibility.NEEDS_COMPILATION - - return best_case - - @staticmethod - def promote_arguments(compiled_signature, *args): - """Promote arguments from the type specified in args to the type specified by - compiled_signature. - - Args: - compiled_signature: user supplied signature, obtain from either an annotation or a - previously compiled implementation of the compiled function - *args: actual arguments to the function - - Returns: - promoted_args: Arguments after promotion. - """ - compiled_data, compiled_shape = tree_flatten(compiled_signature) - runtime_data, runtime_shape = tree_flatten(args) - assert ( - compiled_shape == runtime_shape - ), "Compiled function incompatible runtime arguments' shape" - - promoted_args = [] - for c_param, r_param in zip(compiled_data, runtime_data): - assert isinstance(c_param, jax.core.ShapedArray) - r_param = jax.numpy.asarray(r_param) - arg_dtype = r_param.dtype - promote_to = jax.numpy.promote_types(arg_dtype, c_param.dtype) - promoted_arg = jax.numpy.asarray(r_param, dtype=promote_to) - promoted_args.append(promoted_arg) - return tree_unflatten(compiled_shape, promoted_args) - - @staticmethod - def get_runtime_signature(*args): - """Get signature from arguments. - - Args: - *args: arguments to the compiled function - - Returns: - a list of JAX shaped arrays - """ - args_data, args_shape = tree_flatten(args) - - try: - r_sig = [] - for arg in args_data: - r_sig.append(jax.api_util.shaped_abstractify(arg)) - # Unflatten JAX abstracted args to preserve the shape - return tree_unflatten(args_shape, r_sig) - except Exception as exc: - arg_type = type(arg) - raise TypeError(f"Unsupported argument type: {arg_type}") from exc - - @staticmethod - def _exec(shared_object, has_return, numpy_dict, *args): - """Execute the compiled function with arguments ``*args``. 
- - Args: - lib: Shared object - has_return: whether the function returns a value or not - numpy_dict: dictionary of numpy arrays of buffers from the runtime - *args: arguments to the function - - Returns: - retval: the value computed by the function or None if the function has no return value - """ - - with shared_object as lib: - result_desc = type(args[0].contents) if has_return else None - retval = wrapper.wrap(lib.function, args, result_desc, lib.mem_transfer, numpy_dict) - - return retval - - @staticmethod - def get_ranked_memref_descriptor_from_mlir_tensor_type(mlir_tensor_type): - """Convert an MLIR tensor type to a memref descriptor. - - Args: - mlir_tensor_type: an MLIR tensor type - Returns: - a memref descriptor with empty data - """ - assert mlir_tensor_type - assert mlir_tensor_type is not tuple - shape = ir.RankedTensorType(mlir_tensor_type).shape - mlir_element_type = ir.RankedTensorType(mlir_tensor_type).element_type - numpy_element_type = mlir_type_to_numpy_type(mlir_element_type) - ctp = as_ctype(numpy_element_type) - if shape: - memref_descriptor = make_nd_memref_descriptor(len(shape), ctp)() - else: - memref_descriptor = make_zero_d_memref_descriptor(ctp)() - - return memref_descriptor - - @staticmethod - def get_etypes(mlir_tensor_type): - """Get element type for an MLIR tensor type.""" - mlir_element_type = ir.RankedTensorType(mlir_tensor_type).element_type - return mlir_type_to_numpy_type(mlir_element_type) - - @staticmethod - def get_sizes(mlir_tensor_type): - """Get element type size for an MLIR tensor type.""" - mlir_element_type = ir.RankedTensorType(mlir_tensor_type).element_type - numpy_type = mlir_type_to_numpy_type(mlir_element_type) - dtype = np.dtype(numpy_type) - return dtype.itemsize - - @staticmethod - def get_ranks(mlir_tensor_type): - """Get rank for an MLIR tensor type.""" - shape = ir.RankedTensorType(mlir_tensor_type).shape - return len(shape) if shape else 0 - - def getCompiledReturnValueType(self, mlir_tensor_types): - 
"""Compute the type for the return value and memoize it - - This type does not need to be recomputed as it is generated once per compiled function. - Args: - mlir_tensor_types: a list of MLIR tensor types which match the expected return type - Returns: - a pointer to a CompiledFunctionReturnValue, which corresponds to a structure in which - fields match the expected return types - """ - - if self.return_type_c_abi is not None: - return self.return_type_c_abi - - error_msg = """This function must be called with a non-zero length list as an argument.""" - assert mlir_tensor_types, error_msg - _get_rmd = CompiledFunction.get_ranked_memref_descriptor_from_mlir_tensor_type - return_fields_types = [_get_rmd(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types] - ranks = [ - CompiledFunction.get_ranks(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types - ] - - etypes = [ - CompiledFunction.get_etypes(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types - ] - - sizes = [ - CompiledFunction.get_sizes(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types - ] - - class CompiledFunctionReturnValue(ctypes.Structure): - """Programmatically create a structure which holds tensors of varying base types.""" - - _fields_ = [("f" + str(i), type(t)) for i, t in enumerate(return_fields_types)] - _ranks_ = ranks - _etypes_ = etypes - _sizes_ = sizes - - return_value = CompiledFunctionReturnValue() - return_value_pointer = ctypes.pointer(return_value) - self.return_type_c_abi = return_value_pointer - return self.return_type_c_abi - - def restype_to_memref_descs(self, mlir_tensor_types): - """Converts the return type to a compatible type for the expected ABI. 
- - Args: - mlir_tensor_types: a list of MLIR tensor types which match the expected return type - Returns: - a pointer to a CompiledFunctionReturnValue, which corresponds to a structure in which - fields match the expected return types - """ - return self.getCompiledReturnValueType(mlir_tensor_types) - - def args_to_memref_descs(self, restype, args): - """Convert ``args`` to memref descriptors. - - Besides converting the arguments to memrefs, it also prepares the return value. To respect - the ABI, the return value is changed to a pointer and passed as the first parameter. - - Args: - restype: the type of restype is a ``CompiledFunctionReturnValue`` - args: the JAX arrays to be used as arguments to the function - - Returns: - c_abi_args: a list of memref descriptor pointers to return values and parameters - numpy_arg_buffer: A list to the return values. It must be kept around until the function - finishes execution as the memref descriptors will point to memory locations inside - numpy arrays. 
- - """ - numpy_arg_buffer = [] - return_value_pointer = ctypes.POINTER(ctypes.c_int)() # This is the null pointer - - if restype: - return_value_pointer = self.restype_to_memref_descs(restype) - - c_abi_args = [] - - args_data, args_shape = tree_flatten(args) - - for arg in args_data: - numpy_arg = np.asarray(arg) - numpy_arg_buffer.append(numpy_arg) - c_abi_ptr = ctypes.pointer(get_ranked_memref_descriptor(numpy_arg)) - c_abi_args.append(c_abi_ptr) - - args = tree_unflatten(args_shape, c_abi_args) - - class CompiledFunctionArgValue(ctypes.Structure): - """Programmatically create a structure which holds tensors of varying base types.""" - - _fields_ = [("f" + str(i), type(t)) for i, t in enumerate(c_abi_args)] - - def __init__(self, c_abi_args): - for ft_tuple, c_abi_arg in zip(CompiledFunctionArgValue._fields_, c_abi_args): - f = ft_tuple[0] - setattr(self, f, c_abi_arg) - - arg_value_pointer = ctypes.POINTER(ctypes.c_int)() - - if len(args) > 0: - arg_value = CompiledFunctionArgValue(c_abi_args) - arg_value_pointer = ctypes.pointer(arg_value) - - c_abi_args = [return_value_pointer] + [arg_value_pointer] - return c_abi_args, numpy_arg_buffer - - def get_cmain(self, *args): - """Get a string representing a C program that can be linked against the shared object.""" - _, buffer = self.args_to_memref_descs(self.restype, args) - - return get_template(self.func_name, self.restype, *buffer) - - def __call__(self, *args, **kwargs): - static_argnums = self.compile_options.static_argnums - dynamic_args = [args[idx] for idx in range(len(args)) if idx not in static_argnums] - - if self.compile_options.abstracted_axes is not None: - abstracted_axes = self.compile_options.abstracted_axes - dynamic_args = get_implicit_and_explicit_flat_args( - abstracted_axes, *dynamic_args, **kwargs - ) - - abi_args, _buffer = self.args_to_memref_descs(self.restype, dynamic_args) - - numpy_dict = {nparr.ctypes.data: nparr for nparr in _buffer} - - result = CompiledFunction._exec( - 
self.shared_object, - self.restype, - numpy_dict, - *abi_args, - ) - - return result - - # pylint: disable=too-many-instance-attributes class QJIT: """Class representing a just-in-time compiled hybrid quantum-classical function. @@ -504,7 +78,6 @@ class QJIT: def __init__(self, fn, compile_options): self.compile_options = compile_options self.compiler = Compiler(compile_options) - self.compiling_from_textual_ir = isinstance(fn, str) self.original_function = fn self.user_function = fn self.jaxed_function = None @@ -533,19 +106,13 @@ def __init__(self, fn, compile_options): ) self.preferred_workspace_dir = preferred_workspace_dir - # If we are compiling from textual ir, just use this as the name of the function. - name = "compiled_function" - if not self.compiling_from_textual_ir: - # pylint: disable=no-member - # Guaranteed to exist after functools.update_wrapper AND not compiling from textual IR - name = self.__name__ - self.function_name = name + # pylint: disable=no-member + # Guaranteed to exist after functools.update_wrapper + self.function_name = self.__name__ - self.workspace = WorkspaceManager.get_or_create_workspace(name, preferred_workspace_dir) - - if self.compiling_from_textual_ir: - EvaluationContext.check_is_not_tracing("Cannot compile from IR in tracing context.") - return + self.workspace = WorkspaceManager.get_or_create_workspace( + self.function_name, preferred_workspace_dir + ) parameter_types = get_type_annotations(self.user_function) @@ -603,25 +170,6 @@ def get_static_args_hash(self, *args): ) return static_args_hash - def merge_sig_into_args(self, args_list, *sig): - """Merge runtime signature back to argument list. - - Args: - args: arguments to the compiled function. - *sig: runtime signature. - Returns: - a list of dynamic arguments. - """ - static_argnums = self.compile_options.static_argnums - # Combine dynamic_args (in args) and sig (and keep the original order). 
- new_sig = sig - if static_argnums: - new_sig = args_list - dynamic_indices = [idx for idx in range(len(args_list)) if idx not in static_argnums] - for i, idx in enumerate(dynamic_indices): - new_sig[idx] = sig[i] - return new_sig - def get_mlir(self, *args): """Trace :func:`~.user_function` @@ -632,14 +180,14 @@ def get_mlir(self, *args): an MLIR module """ static_argnums = self.compile_options.static_argnums - dynamic_args = [args[idx] for idx in range(len(args)) if idx not in static_argnums] - self.c_sig = CompiledFunction.get_runtime_signature(*dynamic_args) + dynamic_args = filter_static_args(args, static_argnums) + self.c_sig = get_abstract_signature(dynamic_args) with Patcher( (qml.QNode, "__call__", QFunc.__call__), ): func = self.user_function - sig = self.merge_sig_into_args(list(args), *self.c_sig) + sig = merge_static_args(self.c_sig, args, static_argnums) abstracted_axes = self.compile_options.abstracted_axes mlir_module, ctx, jaxpr, _, self.out_tree = trace_to_mlir( func, static_argnums, abstracted_axes, *sig @@ -660,38 +208,26 @@ def compile(self): if self.compiled_function and self.compiled_function.shared_object: self.compiled_function.shared_object.close() - if self.compiling_from_textual_ir: - # Module name can be anything. - module_name = "catalyst_module" - shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir( - self.user_function, module_name, self.workspace - ) - qfunc_name = inferred_func_data[0] - # Parse back the return types given as a semicolon-separated string - with ir.Context(): - restype = [ir.RankedTensorType.parse(rt) for rt in inferred_func_data[1].split(",")] - else: + # WARNING: assumption is that the first function + # is the entry point to the compiled program. 
+ entry_point_func = self.mlir_module.body.operations[0] + restype = entry_point_func.type.results + + for res in restype: + baseType = ir.RankedTensorType(res).element_type # This will make a check before sending it to the compiler that the return type # is actually available in most systems. f16 needs a special symbol and linking # will fail if it is not available. - # - # WARNING: assumption is that the first function - # is the entry point to the compiled program. - entry_point_func = self.mlir_module.body.operations[0] - restype = entry_point_func.type.results - - for res in restype: - baseType = ir.RankedTensorType(res).element_type - mlir_type_to_numpy_type(baseType) - - # The function name out of MLIR has quotes around it, which we need to remove. - # The MLIR function name is actually a derived type from string which has no - # `replace` method, so we need to get a regular Python string out of it. - qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "") - - shared_object, llvm_ir, inferred_func_data = self.compiler.run( - self.mlir_module, self.workspace - ) + mlir_type_to_numpy_type(baseType) + + # The function name out of MLIR has quotes around it, which we need to remove. + # The MLIR function name is actually a derived type from string which has no + # `replace` method, so we need to get a regular Python string out of it. 
+ qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "") + + shared_object, llvm_ir, _inferred_func_data = self.compiler.run( + self.mlir_module, self.workspace + ) self._llvmir = llvm_ir options = self.compile_options @@ -717,8 +253,8 @@ def _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args: arguments that may have been promoted """ static_argnums = self.compile_options.static_argnums - dynamic_args = [args[idx] for idx in range(len(args)) if idx not in static_argnums] - r_sig = CompiledFunction.get_runtime_signature(*dynamic_args) + dynamic_args = filter_static_args(args, static_argnums) + r_sig = get_abstract_signature(dynamic_args) has_been_compiled = self.compiled_function is not None if static_argnums: @@ -742,38 +278,22 @@ def _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, next_action = TypeCompatibility.NEEDS_COMPILATION else: abstracted_axes = self.compile_options.abstracted_axes - next_action = CompiledFunction.typecheck(abstracted_axes, self.c_sig, r_sig) + next_action = typecheck_signatures(self.c_sig, r_sig, abstracted_axes) if next_action == TypeCompatibility.NEEDS_PROMOTION: - args = CompiledFunction.promote_arguments(self.c_sig, *dynamic_args) + args = promote_arguments(self.c_sig, dynamic_args) elif next_action == TypeCompatibility.NEEDS_COMPILATION: if self.user_typed: msg = "Provided arguments did not match declared signature, recompiling..." warnings.warn(msg, UserWarning) - if not self.compiling_from_textual_ir: - sig = self.merge_sig_into_args(list(args), *r_sig) - self.mlir_module = self.get_mlir(*sig) + sig = merge_static_args(r_sig, args, static_argnums) + self.mlir_module = self.get_mlir(*sig) function = self.compile() else: assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION return function, args - def get_cmain(self, *args): - """Return the C interface template for current arguments. - - Args: - *args: Arguments to be used in the template. 
- Returns: - str: A C program that can be compiled with the current shared object. - """ - msg = "C interface cannot be generated from tracing context." - EvaluationContext.check_is_not_tracing(msg) - function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible( - self.compiled_function, *args - ) - return function.get_cmain(*args) - def __call__(self, *args, **kwargs): static_argnums = self.compile_options.static_argnums diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py new file mode 100644 index 0000000000..b6143129ad --- /dev/null +++ b/frontend/catalyst/compiled_functions.py @@ -0,0 +1,329 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains classes to manage compiled functions and their underlying resources.""" + +import ctypes + +import numpy as np +from jax.interpreters import mlir +from jax.tree_util import tree_flatten, tree_unflatten +from mlir_quantum.runtime import ( + as_ctype, + get_ranked_memref_descriptor, + make_nd_memref_descriptor, + make_zero_d_memref_descriptor, +) + +from catalyst.tracing.type_signatures import filter_static_args +from catalyst.utils import wrapper # pylint: disable=no-name-in-module +from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type +from catalyst.utils.jax_extras import get_implicit_and_explicit_flat_args + + +class SharedObjectManager: + """Shared object manager. 
+ + Manages the lifetime of the shared object: when it is loaded and when it is closed. + + Args: + shared_object_file (str): path to shared object containing compiled function + func_name (str): name of compiled function + """ + + def __init__(self, shared_object_file, func_name): + self.shared_object = None + self.function = None + self.setup = None + self.teardown = None + self.mem_transfer = None + self.open(shared_object_file, func_name) + + def open(self, shared_object_file, func_name): + """Open the shared object and load symbols.""" + self.shared_object = ctypes.CDLL(shared_object_file) + self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols(func_name) + + def close(self): + """Close the shared object""" + self.function = None + self.setup = None + self.teardown = None + self.mem_transfer = None + dlclose = ctypes.CDLL(None).dlclose + dlclose.argtypes = [ctypes.c_void_p] + # pylint: disable=protected-access + dlclose(self.shared_object._handle) + + def load_symbols(self, func_name): + """Load symbols necessary for execution of the compiled function. + + Args: + func_name: name of compiled function to be executed + + Returns: + function: function handle + setup: handle to the setup function, which initializes the device + teardown: handle to the teardown function, which tears down the device + mem_transfer: memory transfer shared object + """ + + setup = self.shared_object.setup + setup.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)] + setup.restype = ctypes.c_int + + teardown = self.shared_object.teardown + teardown.argtypes = None + teardown.restype = None + + # We are calling the c-interface + function = self.shared_object["_catalyst_pyface_" + func_name] + # Guaranteed from _mlir_ciface specification + function.restype = None + # Not needed, computed from the arguments.
+ # function.argtypes + + mem_transfer = self.shared_object["_mlir_memory_transfer"] + + return function, setup, teardown, mem_transfer + + def __enter__(self): + params_to_setup = [b"jitted-function"] + argc = len(params_to_setup) + array_of_char_ptrs = (ctypes.c_char_p * len(params_to_setup))() + array_of_char_ptrs[:] = params_to_setup + self.setup(ctypes.c_int(argc), array_of_char_ptrs) + return self + + def __exit__(self, _type, _value, _traceback): + self.teardown() + + +class CompiledFunction: + """Manages the compilation result of a user program. Holds a reference to the binary object and + performs necessary processing to invoke the compiled program. + + Args: + shared_object_file (str): path to shared object containing compiled function + func_name (str): name of compiled function + restype (Iterable): MLIR tensor types representing the result of the compiled function + compile_options (CompileOptions): compilation options used + """ + + def __init__(self, shared_object_file, func_name, restype, compile_options): + self.shared_object_file = shared_object_file + self.shared_object = SharedObjectManager(shared_object_file, func_name) + self.compile_options = compile_options + self.return_type_c_abi = None + self.func_name = func_name + self.restype = restype + + @staticmethod + def _exec(shared_object, has_return, numpy_dict, *args): + """Execute the compiled function with arguments ``*args``.
+ + Args: + shared_object: shared object manager, entered as a context manager for the call + has_return: whether the function returns a value or not + numpy_dict: dictionary of numpy arrays of buffers from the runtime + *args: arguments to the function + + Returns: + retval: the value computed by the function or None if the function has no return value + """ + + with shared_object as lib: + result_desc = type(args[0].contents) if has_return else None + retval = wrapper.wrap(lib.function, args, result_desc, lib.mem_transfer, numpy_dict) + + return retval + + @staticmethod + def get_ranked_memref_descriptor_from_mlir_tensor_type(mlir_tensor_type): + """Convert an MLIR tensor type to a memref descriptor. + + Args: + mlir_tensor_type: an MLIR tensor type + Returns: + a memref descriptor with empty data + """ + assert mlir_tensor_type + assert mlir_tensor_type is not tuple + shape = mlir.ir.RankedTensorType(mlir_tensor_type).shape + mlir_element_type = mlir.ir.RankedTensorType(mlir_tensor_type).element_type + numpy_element_type = mlir_type_to_numpy_type(mlir_element_type) + ctp = as_ctype(numpy_element_type) + if shape: + memref_descriptor = make_nd_memref_descriptor(len(shape), ctp)() + else: + memref_descriptor = make_zero_d_memref_descriptor(ctp)() + + return memref_descriptor + + @staticmethod + def get_etypes(mlir_tensor_type): + """Get element type for an MLIR tensor type.""" + mlir_element_type = mlir.ir.RankedTensorType(mlir_tensor_type).element_type + return mlir_type_to_numpy_type(mlir_element_type) + + @staticmethod + def get_sizes(mlir_tensor_type): + """Get element type size for an MLIR tensor type.""" + mlir_element_type = mlir.ir.RankedTensorType(mlir_tensor_type).element_type + numpy_type = mlir_type_to_numpy_type(mlir_element_type) + dtype = np.dtype(numpy_type) + return dtype.itemsize + + @staticmethod + def get_ranks(mlir_tensor_type): + """Get rank for an MLIR tensor type.""" + shape = mlir.ir.RankedTensorType(mlir_tensor_type).shape + return len(shape) if shape else 0 + + def getCompiledReturnValueType(self,
mlir_tensor_types): + """Compute the type for the return value and memoize it + + This type does not need to be recomputed as it is generated once per compiled function. + Args: + mlir_tensor_types: a list of MLIR tensor types which match the expected return type + Returns: + a pointer to a CompiledFunctionReturnValue, which corresponds to a structure in which + fields match the expected return types + """ + + if self.return_type_c_abi is not None: + return self.return_type_c_abi + + error_msg = """This function must be called with a non-zero length list as an argument.""" + assert mlir_tensor_types, error_msg + _get_rmd = CompiledFunction.get_ranked_memref_descriptor_from_mlir_tensor_type + return_fields_types = [_get_rmd(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types] + ranks = [ + CompiledFunction.get_ranks(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types + ] + + etypes = [ + CompiledFunction.get_etypes(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types + ] + + sizes = [ + CompiledFunction.get_sizes(mlir_tensor_type) for mlir_tensor_type in mlir_tensor_types + ] + + class CompiledFunctionReturnValue(ctypes.Structure): + """Programmatically create a structure which holds tensors of varying base types.""" + + _fields_ = [("f" + str(i), type(t)) for i, t in enumerate(return_fields_types)] + _ranks_ = ranks + _etypes_ = etypes + _sizes_ = sizes + + return_value = CompiledFunctionReturnValue() + return_value_pointer = ctypes.pointer(return_value) + self.return_type_c_abi = return_value_pointer + return self.return_type_c_abi + + def restype_to_memref_descs(self, mlir_tensor_types): + """Converts the return type to a compatible type for the expected ABI. 
+ + Args: + mlir_tensor_types: a list of MLIR tensor types which match the expected return type + Returns: + a pointer to a CompiledFunctionReturnValue, which corresponds to a structure in which + fields match the expected return types + """ + return self.getCompiledReturnValueType(mlir_tensor_types) + + def args_to_memref_descs(self, restype, args): + """Convert ``args`` to memref descriptors. + + Besides converting the arguments to memrefs, it also prepares the return value. To respect + the ABI, the return value is changed to a pointer and passed as the first parameter. + + Args: + restype: the type of restype is a ``CompiledFunctionReturnValue`` + args: the JAX arrays to be used as arguments to the function + + Returns: + c_abi_args: a list of memref descriptor pointers to return values and parameters + numpy_arg_buffer: A list to the return values. It must be kept around until the function + finishes execution as the memref descriptors will point to memory locations inside + numpy arrays. 
+ + """ + numpy_arg_buffer = [] + return_value_pointer = ctypes.POINTER(ctypes.c_int)() # This is the null pointer + + if restype: + return_value_pointer = self.restype_to_memref_descs(restype) + + c_abi_args = [] + + args_data, args_shape = tree_flatten(args) + + for arg in args_data: + numpy_arg = np.asarray(arg) + numpy_arg_buffer.append(numpy_arg) + c_abi_ptr = ctypes.pointer(get_ranked_memref_descriptor(numpy_arg)) + c_abi_args.append(c_abi_ptr) + + args = tree_unflatten(args_shape, c_abi_args) + + class CompiledFunctionArgValue(ctypes.Structure): + """Programmatically create a structure which holds tensors of varying base types.""" + + _fields_ = [("f" + str(i), type(t)) for i, t in enumerate(c_abi_args)] + + def __init__(self, c_abi_args): + for ft_tuple, c_abi_arg in zip(CompiledFunctionArgValue._fields_, c_abi_args): + f = ft_tuple[0] + setattr(self, f, c_abi_arg) + + arg_value_pointer = ctypes.POINTER(ctypes.c_int)() + + if len(args) > 0: + arg_value = CompiledFunctionArgValue(c_abi_args) + arg_value_pointer = ctypes.pointer(arg_value) + + c_abi_args = [return_value_pointer] + [arg_value_pointer] + return c_abi_args, numpy_arg_buffer + + def get_cmain(self, *args): + """Get a string representing a C program that can be linked against the shared object.""" + _, buffer = self.args_to_memref_descs(self.restype, args) + + return get_template(self.func_name, self.restype, *buffer) + + def __call__(self, *args, **kwargs): + static_argnums = self.compile_options.static_argnums + dynamic_args = filter_static_args(args, static_argnums) + + if self.compile_options.abstracted_axes is not None: + abstracted_axes = self.compile_options.abstracted_axes + dynamic_args = get_implicit_and_explicit_flat_args( + abstracted_axes, *dynamic_args, **kwargs + ) + + abi_args, _buffer = self.args_to_memref_descs(self.restype, dynamic_args) + + numpy_dict = {nparr.ctypes.data: nparr for nparr in _buffer} + + result = CompiledFunction._exec( + self.shared_object, + self.restype, + 
numpy_dict, + *abi_args, + ) + + return result diff --git a/frontend/catalyst/debug.py b/frontend/catalyst/debug.py index 54025886fb..5622b5bcc0 100644 --- a/frontend/catalyst/debug.py +++ b/frontend/catalyst/debug.py @@ -1,4 +1,4 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. +# Copyright 2023-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Catalyst's debug module contains functions useful for user program debugging, such as -runtime printing. -""" +"""Catalyst's debug module contains functions useful for user program debugging.""" import builtins +import os import jax +from jax.interpreters import mlir +import catalyst +from catalyst.compiled_functions import CompiledFunction +from catalyst.compiler import Compiler from catalyst.jax_primitives import print_p -from catalyst.utils.contexts import EvaluationContext +from catalyst.tracing.contexts import EvaluationContext +from catalyst.utils.filesystem import WorkspaceManager # pylint: disable=redefined-builtin @@ -68,3 +72,84 @@ def func(x: float): else: # Dispatch to Python print outside a qjit context. builtins.print(x) + + +def get_cmain(fn, *args): + """Return a C program that calls a jitted function with the provided arguments. + + Args: + fn (QJIT): a qjit-decorated function + *args: argument values to use in the C program when invoking ``fn`` + + Returns: + str: A C program that can be compiled and linked with the current shared object. 
+ """ + EvaluationContext.check_is_not_tracing("C interface cannot be generated from tracing context.") + + if not isinstance(fn, catalyst.QJIT): + raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.") + + # TODO: will be removed in part 2 of the refactor + # pylint: disable=protected-access + complied_function, args = fn._ensure_real_arguments_and_formal_parameters_are_compatible( + fn.compiled_function, *args + ) + + return complied_function.get_cmain(*args) + + +# pylint: disable=line-too-long +def compile_from_mlir(ir, compiler=None, compile_options=None): + """Compile a Catalyst function to binary code from the provided MLIR. + + Args: + ir (str): the MLIR to compile in string form + compile_options: options to use during compilation + + Returns: + CompiledFunction: A callable that manages the compiled shared library and its invocation. + + **Example** + + The main entry point of the program is required to start with ``catalyst.entry_point``, and + the program is required to contain ``setup`` and ``teardown`` functions. + + .. code-block:: python + + ir = r\""" + module @workflow { + func.func public @catalyst.entry_point(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} { + return %arg0 : tensor + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } + } + \""" + + compiled_function = compile_from_mlir(ir) + + >>> compiled_function(0.1) + [0.1] + """ + EvaluationContext.check_is_not_tracing("Cannot compile from IR in tracing context.") + + if compiler is None: + compiler = Compiler(compile_options) + + module_name = "debug_module" + workspace_dir = os.getcwd() if compiler.options.keep_intermediate else None + workspace = WorkspaceManager.get_or_create_workspace("debug_workspace", workspace_dir) + shared_object, _llvm_ir, func_data = compiler.run_from_ir(ir, module_name, workspace) + + # Parse inferred function data, like name and return types. 
+ qfunc_name = func_data[0] + with mlir.ir.Context(): + result_types = [mlir.ir.RankedTensorType.parse(rt) for rt in func_data[1].split(",")] + + return CompiledFunction(shared_object, qfunc_name, result_types, compiler.options) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 6d3f950d38..dfb02cc385 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -1544,7 +1544,7 @@ def _scalar_abstractify(t): # pylint: disable=protected-access if t in {int, float, complex, bool} or isinstance(t, jax._src.numpy.lax_numpy._ScalarMeta): return core.ShapedArray([], dtype=t, weak_type=True) - raise TypeError(f"Cannot convert given type {t} to scalar ShapedArray.") + raise TypeError(f"Argument type {t} is not a valid JAX type.") # pylint: disable=protected-access diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 5858e1ef76..9e03a4420a 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -52,7 +52,11 @@ tensorobs_p, var_p, ) -from catalyst.utils.contexts import EvaluationContext, EvaluationMode, JaxTracingContext +from catalyst.tracing.contexts import ( + EvaluationContext, + EvaluationMode, + JaxTracingContext, +) from catalyst.utils.exceptions import CompileError from catalyst.utils.jax_extras import ( ClosedJaxpr, diff --git a/frontend/catalyst/pennylane_extensions.py b/frontend/catalyst/pennylane_extensions.py index 2fe6d115ec..f3dab44469 100644 --- a/frontend/catalyst/pennylane_extensions.py +++ b/frontend/catalyst/pennylane_extensions.py @@ -72,7 +72,11 @@ trace_quantum_tape, unify_result_types, ) -from catalyst.utils.contexts import EvaluationContext, EvaluationMode, JaxTracingContext +from catalyst.tracing.contexts import ( + EvaluationContext, + EvaluationMode, + JaxTracingContext, +) from catalyst.utils.exceptions import CompileError, DifferentiableCompileError from catalyst.utils.jax_extras import ( ClosedJaxpr, diff --git 
a/frontend/catalyst/utils/contexts.py b/frontend/catalyst/tracing/contexts.py similarity index 98% rename from frontend/catalyst/utils/contexts.py rename to frontend/catalyst/tracing/contexts.py index 79f57dc891..7f388a82fe 100644 --- a/frontend/catalyst/utils/contexts.py +++ b/frontend/catalyst/tracing/contexts.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 Xanadu Quantum Technologies Inc. +# Copyright 2022-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. """ -Tracing module. +This module provides context classes to manage and query Catalyst's and JAX's tracing state. """ from contextlib import contextmanager diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py new file mode 100644 index 0000000000..7bdb9a66c7 --- /dev/null +++ b/frontend/catalyst/tracing/type_signatures.py @@ -0,0 +1,236 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions for obtaining and manipulating function signatures and +arguments in the context of tracing. 
+""" + +import enum +import inspect +from typing import Callable + +import jax +from jax._src.interpreters.partial_eval import infer_lambda_input_type +from jax._src.pjit import _flat_axes_specs +from jax.api_util import shaped_abstractify +from jax.tree_util import tree_flatten, tree_unflatten + +from catalyst.utils.jax_extras import get_aval2 +from catalyst.utils.patching import Patcher + + +def params_are_annotated(fn: Callable): + """Return true if all parameters are typed-annotated, or no parameters are present.""" + assert isinstance(fn, Callable) + signature = inspect.signature(fn) + parameters = signature.parameters + return all(p.annotation is not inspect.Parameter.empty for p in parameters.values()) + + +def get_type_annotations(fn: Callable): + """Get type annotations if all parameters are annotated.""" + assert isinstance(fn, Callable) + if fn is not None and params_are_annotated(fn): + return tuple(getattr(fn, "__annotations__", {}).values()) + + return None + + +def get_abstract_signature(args): + """Get abstract values from real arguments, preserving PyTrees. + + Args: + args (Iterable): arguments to convert + + Returns: + Iterable: ShapedArrays for the provided values + """ + flat_args, treedef = tree_flatten(args) + + abstract_args = [shaped_abstractify(arg) for arg in flat_args] + + return tree_unflatten(treedef, abstract_args) + + +def filter_static_args(args, static_argnums): + """Remove static values from arguments using the provided index list. + + Args: + args (Iterable): arguments to a compiled function + static_argnums (Iterable[int]): indices to filter + + Returns: + Tuple: dynamic arguments + """ + return tuple(args[idx] for idx in range(len(args)) if idx not in static_argnums) + + +# TODO: remove pragma in part 2 +def split_static_args(args, static_argnums): # pragma: nocover + """Split arguments into static and dynamic values using the provided index list. 
+ + Args: + args (Iterable): arguments to a compiled function + static_argnums (Iterable[int]): indices to split on + + Returns: + Tuple: dynamic arguments + Tuple: static arguments + """ + dynamic_args, static_args = [], [] + for i, arg in enumerate(args): + if i in static_argnums: + static_args.append(arg) + else: + dynamic_args.append(arg) + return tuple(dynamic_args), tuple(static_args) + + +def merge_static_args(signature, args, static_argnums): + """Merge static arguments back into an abstract signature, retaining the original ordering. + + Args: + signature (Iterable[ShapedArray]): abstract values of the dynamic arguments + args (Iterable): original argument list to draw static values from + static_argnums (Iterable[int]): indices to merge on + + Returns: + Tuple[ShapedArray | Any]: a mixture of ShapedArrays and static argument values + """ + if not static_argnums: + return signature + + merged_sig = list(args) # mutable copy + + dynamic_indices = [idx for idx in range(len(args)) if idx not in static_argnums] + for i, idx in enumerate(dynamic_indices): + merged_sig[idx] = signature[i] + + return tuple(merged_sig) + + +# TODO: remove pragma in part 2 +def get_decomposed_signature(args, static_argnums): # pragma: nocover + """Decompose function arguments into dynamic and static arguments, where the dynamic arguments + are further processed into abstract values and PyTree metadata. All values returned by this + function are hashable. 
+ + Args: + args (Iterable): arguments to a compiled function + static_argnums (Iterable[int]): indices to split on + + Returns: + Tuple[ShapedArray]: dynamic argument shape and dtype information + PyTreeDef: dynamic argument PyTree metadata + Tuple[Any]: static argument values + """ + dynamic_args, static_args = split_static_args(args, static_argnums) + flat_dynamic_args, treedef = tree_flatten(dynamic_args) + flat_signature = get_abstract_signature(flat_dynamic_args) + + return flat_signature, treedef, static_args + + +class TypeCompatibility(enum.Enum): + """Enum class to indicate result of type compatibility analysis between two signatures.""" + + UNKNOWN = 0 + CAN_SKIP_PROMOTION = 1 + NEEDS_PROMOTION = 2 + NEEDS_COMPILATION = 3 + + +def typecheck_signatures(compiled_signature, runtime_signature, abstracted_axes=None): + """Determine whether a signature is compatible with another, possibly via promotion and + considering dynamic axes, and return either of three states: + + - fully compatible (skip promotion) + - conditionally compatible (requires promotion) + - incompatible (requires re-compilation) + + Args: + compiled_signature (Iterable[ShapedArray]): base signature to compare against, typically the + signature of a previously compiled function or user specified type hints + runtime_signature (Iterable[ShapedArray]): signature to examine, usually from runtime + arguments + abstracted_axes + + Returns: + TypeCompatibility + """ + if compiled_signature is None: + return TypeCompatibility.NEEDS_COMPILATION + + flat_compiled_sig, compiled_treedef = tree_flatten(compiled_signature) + flat_runtime_sig, runtime_treedef = tree_flatten(runtime_signature) + + if compiled_treedef != runtime_treedef: + return TypeCompatibility.NEEDS_COMPILATION + + # We first check signature equality considering dynamic axes, allowing the shape of an array + # to be different if it was compiled with a dynamical shape. 
+ # TODO: unify this with the promotion checks, allowing the dtype to change for a dynamic axis + with Patcher( + # pylint: disable=protected-access + (jax._src.interpreters.partial_eval, "get_aval", get_aval2), + ): + # TODO: do away with private jax functions + axes_specs_compile = _flat_axes_specs(abstracted_axes, *compiled_signature, {}) + axes_specs_runtime = _flat_axes_specs(abstracted_axes, *runtime_signature, {}) + in_type_compiled = infer_lambda_input_type(axes_specs_compile, flat_compiled_sig) + in_type_runtime = infer_lambda_input_type(axes_specs_runtime, flat_runtime_sig) + + if in_type_compiled == in_type_runtime: + return TypeCompatibility.CAN_SKIP_PROMOTION + + action = TypeCompatibility.CAN_SKIP_PROMOTION + for c_param, r_param in zip(flat_compiled_sig, flat_runtime_sig): + if c_param.dtype != r_param.dtype: + action = TypeCompatibility.NEEDS_PROMOTION + + if c_param.shape != r_param.shape: + return TypeCompatibility.NEEDS_COMPILATION + + promote_to = jax.numpy.promote_types(r_param.dtype, c_param.dtype) + if c_param.dtype != promote_to: + return TypeCompatibility.NEEDS_COMPILATION + + return action + + +def promote_arguments(target_signature, args): + """Promote arguments to the provided target signature, preserving PyTrees. + + Args: + target_signature (Iterable): target signature to promote arguments to + args (Iterable): arguments to promote, must have matching PyTrees with target signature + + Returns: + Iterable: arguments promoted to target signature + """ + flat_target_sig, target_treedef = tree_flatten(target_signature) + flat_args, treedef = tree_flatten(args) + assert target_treedef == treedef, "Argument PyTrees did not match target signature." 
+ + promoted_args = [] + for c_param, r_param in zip(flat_target_sig, flat_args): + assert isinstance(c_param, jax.core.ShapedArray) + r_param = jax.numpy.asarray(r_param) + arg_dtype = r_param.dtype + promote_to = jax.numpy.promote_types(arg_dtype, c_param.dtype) + promoted_arg = jax.numpy.asarray(r_param, dtype=promote_to) + promoted_args.append(promoted_arg) + + return tree_unflatten(treedef, promoted_args) diff --git a/frontend/test/pytest/test_c_template.py b/frontend/test/pytest/test_c_template.py index e920316f98..6827f9eb0e 100644 --- a/frontend/test/pytest/test_c_template.py +++ b/frontend/test/pytest/test_c_template.py @@ -19,6 +19,7 @@ import pytest from catalyst import qjit +from catalyst.debug import get_cmain from catalyst.utils.c_template import CType, CVariable from catalyst.utils.exceptions import CompileError @@ -129,7 +130,7 @@ def f(x: float): qml.RX(x, wires=1) return qml.state(), qml.state() - template = f.get_cmain(4.0) + template = get_cmain(f, 4.0) assert "main" in template assert "struct result_t result_val;" in template assert "buff_0 = 4.0" in template @@ -144,7 +145,7 @@ def f(): """No-op function.""" return None - template = f.get_cmain() + template = get_cmain(f) assert "struct result_t result_val;" not in template assert "buff_0" not in template assert "arg_0" not in template @@ -166,4 +167,14 @@ def f(x: float): @qjit def error_fn(x: float): """Should raise an error as we try to generate the C template during tracing.""" - return f.get_cmain(x) + return get_cmain(f, x) + + def test_error_non_qjit_object(self): + """An error should be raised if the object supplied to the debug function is not a QJIT.""" + + def f(x: float): + """Identity function.""" + return x + + with pytest.raises(TypeError, match="First argument needs to be a 'QJIT' object"): + get_cmain(f, 0.5) diff --git a/frontend/test/pytest/test_compiler.py b/frontend/test/pytest/test_compiler.py index bd05d13ed3..769e05b8aa 100644 --- a/frontend/test/pytest/test_compiler.py 
+++ b/frontend/test/pytest/test_compiler.py @@ -30,12 +30,9 @@ import pytest from catalyst import qjit -from catalyst.compilation_pipelines import WorkspaceManager from catalyst.compiler import DEFAULT_PIPELINES, CompileOptions, Compiler, LinkerDriver -from catalyst.jax_tracer import trace_to_mlir from catalyst.utils.exceptions import CompileError from catalyst.utils.filesystem import Directory -from catalyst.utils.runtime import get_lib_path # pylint: disable=missing-function-docstring @@ -263,75 +260,6 @@ def test_compiler_driver_with_flags(self): assert observed_outfilename == expected_outfilename assert os.path.exists(observed_outfilename) - def test_compiler_from_textual_ir(self): - """Test the textual IR compilation.""" - full_path = get_lib_path("runtime", "RUNTIME_LIB_DIR") - extension = ".so" if platform.system() == "Linux" else ".dylib" - - # pylint: disable=line-too-long - ir = ( - r""" -module @workflow { - func.func public @catalyst.entry_point(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} { - %0 = call @workflow(%arg0) : (tensor) -> tensor - return %0 : tensor - } - func.func private @workflow(%arg0: tensor) -> tensor attributes {diff_method = "finite-diff", llvm.linkage = #llvm.linkage, qnode} { - quantum.device [""" - + r'"' - + full_path - + r"""/librtd_lightning""" - + extension - + """", "LightningSimulator", "{'shots': 0}"] - %0 = stablehlo.constant dense<4> : tensor - %1 = quantum.alloc( 4) : !quantum.reg - %2 = stablehlo.constant dense<0> : tensor - %extracted = tensor.extract %2[] : tensor - %3 = quantum.extract %1[%extracted] : !quantum.reg -> !quantum.bit - %4 = quantum.custom "PauliX"() %3 : !quantum.bit - %5 = stablehlo.constant dense<1> : tensor - %extracted_0 = tensor.extract %5[] : tensor - %6 = quantum.extract %1[%extracted_0] : !quantum.reg -> !quantum.bit - %extracted_1 = tensor.extract %arg0[] : tensor - %7 = quantum.custom "RX"(%extracted_1) %6 : !quantum.bit - %8 = quantum.namedobs %4[ PauliZ] : !quantum.obs - %9 = 
quantum.expval %8 : f64 - %from_elements = tensor.from_elements %9 : tensor - quantum.dealloc %1 : !quantum.reg - quantum.device_release - return %from_elements : tensor - } - func.func @setup() { - quantum.init - return - } - func.func @teardown() { - quantum.finalize - return - } -} -""" - ) - compiled_function = qjit(ir) - assert compiled_function(0.1) == -1 - - def test_parsing_errors(self): - """Test parsing error handling.""" - - ir = r""" -module @workflow { - func.func public @catalyst.entry_point(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} { - %c = stablehlo.constant dense<4.0> : tensor - return %c : tensor // Invalid type - } -} -""" - with pytest.raises(CompileError) as e: - qjit(ir)(0.1) - - assert "Failed to parse module as MLIR source" in e.value.args[0] - assert "Failed to parse module as LLVM source" in e.value.args[0] - def test_pipeline_error(self): """Test pipeline error handling.""" diff --git a/frontend/test/pytest/test_contexts.py b/frontend/test/pytest/test_contexts.py index 517d24d6c9..43bd1b9691 100644 --- a/frontend/test/pytest/test_contexts.py +++ b/frontend/test/pytest/test_contexts.py @@ -18,7 +18,7 @@ import pytest from catalyst import cond, measure, qjit, while_loop -from catalyst.utils.contexts import EvaluationContext, EvaluationMode +from catalyst.tracing.contexts import EvaluationContext, EvaluationMode class TestEvaluationModes: diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 01e7df1661..ec26ac3045 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import platform import re import jax.numpy as jnp @@ -17,6 +18,10 @@ import pytest from catalyst import debug, for_loop, qjit +from catalyst.compiler import CompileOptions, Compiler +from catalyst.debug import compile_from_mlir +from catalyst.utils.exceptions import CompileError +from catalyst.utils.runtime import get_lib_path class TestDebugPrint: @@ -184,5 +189,103 @@ def func2(): assert out == "hello\ngoodbye\n" +class TestCompileFromIR: + """Test the debug feature that compiles from a string representation of the IR.""" + + def test_compiler_from_textual_ir(self): + """Test the textual IR compilation.""" + full_path = get_lib_path("runtime", "RUNTIME_LIB_DIR") + extension = ".so" if platform.system() == "Linux" else ".dylib" + + # pylint: disable=line-too-long + ir = ( + r""" +module @workflow { + func.func public @catalyst.entry_point(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} { + %0 = call @workflow(%arg0) : (tensor) -> tensor + return %0 : tensor + } + func.func private @workflow(%arg0: tensor) -> tensor attributes {diff_method = "finite-diff", llvm.linkage = #llvm.linkage, qnode} { + quantum.device [""" + + r'"' + + full_path + + r"""/librtd_lightning""" + + extension + + """", "LightningSimulator", "{'shots': 0}"] + %0 = stablehlo.constant dense<4> : tensor + %1 = quantum.alloc( 4) : !quantum.reg + %2 = stablehlo.constant dense<0> : tensor + %extracted = tensor.extract %2[] : tensor + %3 = quantum.extract %1[%extracted] : !quantum.reg -> !quantum.bit + %4 = quantum.custom "PauliX"() %3 : !quantum.bit + %5 = stablehlo.constant dense<1> : tensor + %extracted_0 = tensor.extract %5[] : tensor + %6 = quantum.extract %1[%extracted_0] : !quantum.reg -> !quantum.bit + %extracted_1 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%extracted_1) %6 : !quantum.bit + %8 = quantum.namedobs %4[ PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %from_elements = tensor.from_elements %9 : tensor + quantum.dealloc %1 : !quantum.reg + 
quantum.device_release + return %from_elements : tensor + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } +} +""" + ) + compiled_function = compile_from_mlir(ir) + assert compiled_function(0.1) == [-1] + + def test_compile_from_ir_with_compiler(self): + """Supply a custom compiler instance to the textual compilation function.""" + + options = CompileOptions(static_argnums=[1]) + compiler = Compiler(options) + + ir = r""" +module @workflow { + func.func public @catalyst.entry_point(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} { + return %arg0 : tensor + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } +} +""" + + compiled_function = compile_from_mlir(ir, compiler=compiler) + assert compiled_function(0.1, 0.2) == [0.1] # allow call with one extra argument + + def test_parsing_errors(self): + """Test parsing error handling.""" + + ir = r""" +module @workflow { + func.func public @catalyst.entry_point(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} { + %c = stablehlo.constant dense<4.0> : tensor + return %c : tensor // Invalid type + } +} +""" + with pytest.raises(CompileError) as e: + compile_from_mlir(ir)(0.1) + + assert "Failed to parse module as MLIR source" in e.value.args[0] + assert "Failed to parse module as LLVM source" in e.value.args[0] + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index adb34dedaa..998724ac17 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -24,8 +24,12 @@ from numpy import pi from catalyst import for_loop, grad, measure, qjit -from catalyst.compilation_pipelines import CompiledFunction, TypeCompatibility from catalyst.jax_primitives import _scalar_abstractify +from catalyst.tracing.type_signatures import ( + 
TypeCompatibility, + get_abstract_signature, + typecheck_signatures, +) def f_aot_builder(backend, wires=1, shots=1000): @@ -496,34 +500,40 @@ def test_shots_in_callsite_in_sample(self, backend): class TestPromotionRules: """Class to test different promotion rules.""" + def test_against_none_target(self): + """Test type check result when the target is None.""" + + retval = typecheck_signatures(None, [1]) + assert TypeCompatibility.NEEDS_COMPILATION == retval + def test_incompatible_compiled_vs_runtime_different_lengths(self): """Test incompatible compiled vs runtime.""" - retval = CompiledFunction.typecheck(None, [], [1]) + retval = typecheck_signatures([], [1]) assert TypeCompatibility.NEEDS_COMPILATION == retval def test_incompatible_compiled_vs_runtime_different_types(self): """Test incompatible compiled vs runtime with different types.""" - retval = CompiledFunction.typecheck(None, jnp.array([1]), jnp.array([complex(1, 2)])) + retval = typecheck_signatures(jnp.array([1]), jnp.array([complex(1, 2)])) assert TypeCompatibility.NEEDS_COMPILATION == retval def test_incompatible_compiled_vs_runtime_different_shapes(self): """Test incompatible compiled vs runtime with different shapes.""" - retval = CompiledFunction.typecheck(None, jnp.array([1, 2]), jnp.array([1])) + retval = typecheck_signatures(jnp.array([1, 2]), jnp.array([1])) assert TypeCompatibility.NEEDS_COMPILATION == retval def test_can_skip_promotion(self): """Test skipping promotion""" - retval = CompiledFunction.typecheck(None, jnp.array([1]), jnp.array([1])) + retval = typecheck_signatures(jnp.array([1]), jnp.array([1])) assert TypeCompatibility.CAN_SKIP_PROMOTION == retval def test_needs_promotion(self): """Test promotion""" - retval = CompiledFunction.typecheck(None, jnp.array([1.0]), jnp.array([1])) + retval = typecheck_signatures(jnp.array([1.0]), jnp.array([1])) assert TypeCompatibility.NEEDS_PROMOTION == retval @@ -533,33 +543,33 @@ class TestPromotionRulesDictionary: def 
test_trivial_no_promotion(self): """Test trivial for the same dictionary as input.""" one = jnp.array(1.0) - retval = CompiledFunction.typecheck(None, {"key1": one}, {"key1": one}) + retval = typecheck_signatures({"key1": one}, {"key1": one}) assert TypeCompatibility.CAN_SKIP_PROMOTION == retval def test_trivial_no_promotion_different_values(self): """Test trivial for the same dictionary with different values.""" one = jnp.array(1.0) two = jnp.array(2.0) - retval = CompiledFunction.typecheck(None, {"key1": one}, {"key1": two}) + retval = typecheck_signatures({"key1": one}, {"key1": two}) assert TypeCompatibility.CAN_SKIP_PROMOTION == retval def test_trivial_promotion_different_values(self): """Test promotion where keys have different values.""" one = jnp.array(1.0) one_int = jnp.array(1) - retval = CompiledFunction.typecheck(None, {"key1": one}, {"key1": one_int}) + retval = typecheck_signatures({"key1": one}, {"key1": one_int}) assert TypeCompatibility.NEEDS_PROMOTION == retval def test_recompilation_superset_keys(self): """Recompile if the structure is different superset case.""" one = jnp.array(1.0) - retval = CompiledFunction.typecheck(None, {"key1": one}, {"key2": one, "key1": one}) + retval = typecheck_signatures({"key1": one}, {"key2": one, "key1": one}) assert TypeCompatibility.NEEDS_COMPILATION == retval def test_recompilation_subset_keys(self): """Recompile if the structure is different subset case.""" one = jnp.array(1.0) - retval = CompiledFunction.typecheck(None, {"key2": one, "key1": one}, {"key1": one}) + retval = typecheck_signatures({"key2": one, "key1": one}, {"key1": one}) assert TypeCompatibility.NEEDS_COMPILATION == retval @@ -568,21 +578,18 @@ def test_incompatible_argument(self): """Test incompatible argument.""" string = "hello world" - with pytest.raises(TypeError) as err: - CompiledFunction.get_runtime_signature([string]) - assert "Unsupported argument type:" in str(err.value) + with pytest.raises(TypeError, match=" is not a valid JAX 
type"): + get_abstract_signature([string]) def test_incompatible_type_reachable_from_user_code(self): """Raise error message for incompatible types""" - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError, match=" is not a valid JAX type"): @qjit def f(x: str): return - assert "Unsupported argument type:" in str(err.value) - def test_incompatible_abstractify(self): """Check error message. @@ -590,11 +597,9 @@ def test_incompatible_abstractify(self): This is because the incompatible argument above would reach it. """ - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError, match=" is not a valid JAX type"): _scalar_abstractify(str) - assert "Cannot convert given type" in str(err.value) - class TestClassicalCompilation: @pytest.mark.parametrize("a,b", [(1, 1)]) diff --git a/setup.py b/setup.py index 30ea8ebf6e..dfc558a9d9 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ entry_points = { "pennylane.plugins": "cudaq = catalystcuda:CudaQDevice", "pennylane.compilers": [ - "context = catalyst.utils.contexts:EvaluationContext", + "context = catalyst.tracing.contexts:EvaluationContext", "ops = catalyst:pennylane_extensions", "qjit = catalyst:qjit", ], @@ -62,7 +62,7 @@ entry_points = { "pennylane.plugins": "cudaq = catalystcuda:CudaQDevice", "pennylane.compilers": [ - "catalyst.context = catalyst.utils.contexts:EvaluationContext", + "catalyst.context = catalyst.tracing.contexts:EvaluationContext", "catalyst.ops = catalyst:pennylane_extensions", "catalyst.qjit = catalyst:qjit", ],