Rework QJIT class into distinct compilation stages (#531)
This is part 2 of a refactor started in #529. 

The QJIT class is reworked into 5 distinct compilation stages:
- pre-compilation (such as AutoGraph)
- capture (JAXPR generation)
- IR generation (MLIR generation)
- compilation (LLVM and binary code generation; cannot be split up
  since this happens in the compiler driver)
- execution

The class is also streamlined by using a new compilation cache to handle
previously compiled functions and signature lookups.
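The keying idea behind the cache can be sketched as follows (a hypothetical, heavily simplified stand-in — `structure_of` and `TinyCompilationCache` are illustrative names, not Catalyst's actual API): compiled functions are stored under the container structure of the dynamic arguments plus the static argument values.

```python
# Illustrative sketch of caching compiled functions by argument structure and
# static argument values; names and logic here are simplified stand-ins.

def structure_of(args):
    """Return a hashable description of the container structure of `args`."""
    if isinstance(args, (list, tuple)):
        return (type(args).__name__, tuple(structure_of(a) for a in args))
    return "leaf"

class TinyCompilationCache:
    def __init__(self):
        self._store = {}

    def lookup(self, args, static_args=()):
        """Return a previously compiled function for this structure, if any."""
        return self._store.get((structure_of(args), static_args))

    def insert(self, args, static_args, compiled_fn):
        self._store[(structure_of(args), static_args)] = compiled_fn

cache = TinyCompilationCache()
cache.insert([1], (), "compiled-for-list")
cache.insert((1,), (), "compiled-for-tuple")
hit = cache.lookup([2])    # same list structure: reuse, no recompilation
miss = cache.lookup("x")   # unseen structure: would trigger compilation
```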

One point of contention might be the results produced by the split of
the `trace_to_mlir` function, which have been simplified and need to be
double-checked against #520. EDIT:
c71c322
should address this concern.

[sc-57014]

closes #268 
closes #520
dime10 authored Feb 23, 2024
1 parent 47fd3fc commit 44506f1
Showing 23 changed files with 682 additions and 395 deletions.
61 changes: 61 additions & 0 deletions doc/changelog.md
@@ -51,6 +51,44 @@

<h3>Improvements</h3>

* Catalyst will now remember previously compiled functions when the PyTree metadata of arguments
  changes, in addition to already remembering compiled functions when static arguments change.
  [(#531)](https://github.com/PennyLaneAI/catalyst/pull/531)

The following example will no longer trigger a third compilation:
```py
@qjit
def func(x):
print("compiling")
return x
```
```pycon
>>> func([1,]); # list
compiling
>>> func((2,)); # tuple
compiling
>>> func([3,]); # list

```

Note however that in order to keep overheads low, changing the argument *type* or *shape* (in a
promotion-incompatible way) may overwrite a previously stored function (with identical PyTree
metadata and static argument values):
```py
@qjit
def func(x):
print("compiling")
return x
```
```pycon
>>> func(jnp.array(1)); # scalar
compiling
>>> func(jnp.array([2.])); # 1-D array
compiling
>>> func(jnp.array(3)); # scalar
compiling
```
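What "promotion incompatible" means can be illustrated with NumPy's type promotion rules, on which JAX's are modelled (this is an analogy only, not the check Catalyst runs internally):

```python
import numpy as np

# An int32 argument can be promoted to a float64 signature, so a function
# compiled for float64 inputs could be reused for int32 inputs.
promoted = np.promote_types(np.int32, np.float64)  # float64

# A float64 argument cannot be demoted to an int32 signature, which is why
# the scalar/array alternation above keeps overwriting the stored function.
incompatible = np.promote_types(np.float64, np.int32) != np.int32
```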

* Keep the structure of the function return when taking the derivatives, JVP and VJP (pytrees support).
[(#500)](https://github.com/PennyLaneAI/catalyst/pull/500)
[(#501)](https://github.com/PennyLaneAI/catalyst/pull/501)
@@ -143,6 +181,24 @@

<h3>Breaking changes</h3>

* The Catalyst Python frontend has been partially refactored. The impact on user-facing
functionality is minimal, but the location of certain classes and methods used by the package
may have changed.
[(#529)](https://github.com/PennyLaneAI/catalyst/pull/529)
[(#531)](https://github.com/PennyLaneAI/catalyst/pull/531)

The following changes have been made:
* Some debug methods and features on the QJIT class have been turned into free functions and moved
to the `catalyst.debug` module, which will now appear in the public documentation. This includes
compiling a program from IR, obtaining a C program that invokes a compiled function, and
printing fine-grained MLIR compilation stages.
* The `compilation_pipelines.py` module has been renamed to `jit.py`, and certain functionality
has been moved out (see following items).
* A new module `compiled_functions.py` now manages low-level access to compiled functions.
* A new module `tracing/type_signatures.py` handles functionality related to managing arguments
and type signatures during the tracing process.
* The `contexts.py` module has been moved from `utils` to the new `tracing` sub-module.

* `QCtrl` is overridden and never used.
[(#522)](https://github.com/PennyLaneAI/catalyst/pull/522)

@@ -247,6 +303,11 @@

<h3>Bug fixes</h3>

* Catalyst will no longer print a warning that recompilation is triggered when a `@qjit` decorated
function with no arguments is invoked without having been compiled first, for example via the use
of `target="mlir"`.
[(#531)](https://github.com/PennyLaneAI/catalyst/pull/531)

* Only set `JAX_DYNAMIC_SHAPES` configuration option during `trace_to_mlir()`.
[(#526)](https://github.com/PennyLaneAI/catalyst/pull/526)

1 change: 0 additions & 1 deletion doc/dev/sharp_bits.rst
@@ -310,7 +310,6 @@ array(0.16996714)
However, deviating from this will result in recompilation and a warning message:

>>> circuit(jnp.array([1.4, 1.4, 0.3, 0.1]))
catalyst/compilation_pipelines.py:592:
UserWarning: Provided arguments did not match declared signature, recompiling...
Tracing occurring
array(0.16996714)
2 changes: 1 addition & 1 deletion frontend/README.rst
@@ -38,7 +38,7 @@ The ``catalyst`` Python package is a mixed Python package which relies on some C

and the following modules:

- `compilation_pipelines.py <https://github.com/PennyLaneAI/catalyst/tree/main/frontend/compilation_pipelines.py>`_:
- `jit.py <https://github.com/PennyLaneAI/catalyst/tree/main/frontend/jit.py>`_:
This module contains classes and decorators for just-in-time and ahead-of-time compilation of
hybrid quantum-classical functions using Catalyst.

3 changes: 2 additions & 1 deletion frontend/catalyst/__init__.py
@@ -67,7 +67,8 @@

from catalyst import debug
from catalyst.ag_utils import AutoGraphError, autograph_source
from catalyst.compilation_pipelines import QJIT, CompileOptions, qjit
from catalyst.compiler import CompileOptions
from catalyst.jit import QJIT, qjit
from catalyst.pennylane_extensions import (
adjoint,
cond,
167 changes: 148 additions & 19 deletions frontend/catalyst/compiled_functions.py
@@ -15,10 +15,12 @@
"""This module contains classes to manage compiled functions and their underlying resources."""

import ctypes
from dataclasses import dataclass
from typing import Tuple

import numpy as np
from jax.interpreters import mlir
from jax.tree_util import tree_flatten, tree_unflatten
from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from mlir_quantum.runtime import (
as_ctype,
get_ranked_memref_descriptor,
@@ -27,9 +29,15 @@
)

from catalyst.jax_extras import get_implicit_and_explicit_flat_args
from catalyst.tracing.type_signatures import filter_static_args
from catalyst.tracing.type_signatures import (
TypeCompatibility,
filter_static_args,
get_decomposed_signature,
typecheck_signatures,
)
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type
from catalyst.utils.filesystem import Directory


class SharedObjectManager:
@@ -43,17 +51,19 @@ class SharedObjectManager:
"""

def __init__(self, shared_object_file, func_name):
self.shared_object_file = shared_object_file
self.shared_object = None
self.func_name = func_name
self.function = None
self.setup = None
self.teardown = None
self.mem_transfer = None
self.open(shared_object_file, func_name)
self.open()

def open(self, shared_object_file, func_name):
def open(self):
"""Open the shared object and load symbols."""
self.shared_object = ctypes.CDLL(shared_object_file)
self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols(func_name)
self.shared_object = ctypes.CDLL(self.shared_object_file)
self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols()

def close(self):
"""Close the shared object"""
@@ -66,17 +76,14 @@ def close(self):
# pylint: disable=protected-access
dlclose(self.shared_object._handle)

def load_symbols(self, func_name):
def load_symbols(self):
"""Load symbols necessary for execution of the compiled function.
Args:
func_name: name of compiled function to be executed
Returns:
function: function handle
setup: handle to the setup function, which initializes the device
teardown: handle to the teardown function, which tears down the device
mem_transfer: memory transfer shared object
CFuncPtr: handle to the main function of the program
CFuncPtr: handle to the setup function, which initializes the device
CFuncPtr: handle to the teardown function, which tears down the device
CFuncPtr: handle to the memory transfer function for program results
"""

setup = self.shared_object.setup
@@ -88,7 +95,7 @@ def load_symbols(self, func_name):
teardown.restypes = None

# We are calling the c-interface
function = self.shared_object["_catalyst_pyface_" + func_name]
function = self.shared_object["_catalyst_pyface_" + self.func_name]
# Guaranteed from _mlir_ciface specification
function.restypes = None
# Not needed, computed from the arguments.
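The symbol loading in this hunk is standard `ctypes` usage; a minimal self-contained analogue (loading the C math library rather than a Catalyst shared object — library name resolution is platform-dependent, so this is a sketch, not guaranteed on every system):

```python
import ctypes
import ctypes.util

# Open a shared library and declare one symbol's C signature, mirroring how
# SharedObjectManager sets restype/argtypes on each loaded symbol. Falls back
# to the process-global namespace if find_library cannot locate libm.
libm = ctypes.CDLL(ctypes.util.find_library("m") or None)
libm.cos.restype = ctypes.c_double
libm.cos.argtypes = [ctypes.c_double]
result = libm.cos(0.0)
```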
@@ -122,7 +129,6 @@ class CompiledFunction:
"""

def __init__(self, shared_object_file, func_name, restype, compile_options):
self.shared_object_file = shared_object_file
self.shared_object = SharedObjectManager(shared_object_file, func_name)
self.compile_options = compile_options
self.return_type_c_abi = None
@@ -140,7 +146,7 @@ def _exec(shared_object, has_return, numpy_dict, *args):
*args: arguments to the function
Returns:
retval: the value computed by the function or None if the function has no return value
the return values computed by the function or None if the function has no results
"""

with shared_object as lib:
@@ -256,8 +262,8 @@ def args_to_memref_descs(self, restype, args):
args: the JAX arrays to be used as arguments to the function
Returns:
c_abi_args: a list of memref descriptor pointers to return values and parameters
numpy_arg_buffer: A list to the return values. It must be kept around until the function
List: a list of memref descriptor pointers to return values and parameters
List: A list to the return values. It must be kept around until the function
finishes execution as the memref descriptors will point to memory locations inside
numpy arrays.
@@ -327,3 +333,126 @@
)

return result


@dataclass
class CacheKey:
"""A key by which to identify entries in the compiled function cache.
The cache only remembers one compiled function for each possible combination of:
- dynamic argument PyTree metadata
- static argument values
"""

treedef: PyTreeDef
static_args: Tuple

def __hash__(self) -> int:
if not hasattr(self, "_hash"):
# pylint: disable=attribute-defined-outside-init
self._hash = hash((self.treedef, self.static_args))
return self._hash


@dataclass
class CacheEntry:
"""An entry in the compiled function cache.
For each compiled function, the cache stores the dynamic argument signature, the output PyTree
definition, as well as the workspace in which compilation takes place.
"""

compiled_fn: CompiledFunction
signature: Tuple
out_treedef: PyTreeDef
workspace: Directory


class CompilationCache:
"""Class to manage CompiledFunction instances and retrieve previously compiled versions of
a function.
A full match requires the following properties to match:
- dynamic argument signature (the shape and dtype of flattened arrays)
- dynamic argument PyTree definitions
- static argument values
In order to allow some flexibility in the type of arguments provided by the user, a match is
also produced if the dtype of dynamic arguments can be promoted to the dtype of an existing
signature via JAX type promotion rules. Additional leniency is provided in the shape of
dynamic arguments in accordance with the abstracted axis specification.
To avoid excess promotion checks, only one function version is stored at a time for a given
combination of PyTreeDefs and static arguments.
"""

def __init__(self, static_argnums, abstracted_axes):
self.static_argnums = static_argnums
self.abstracted_axes = abstracted_axes
self.cache = {}

def get_function_status_and_key(self, args):
"""Check if the provided arguments match an existing function in the cache. The cache
status of the function is returned as a compilation action:
- no match: requires compilation
- partial match: requires argument promotion
- full match: skip promotion
Args:
args: arguments to match to existing functions
Returns:
TypeCompatibility
CacheKey | None
"""
if not self.cache:
return TypeCompatibility.NEEDS_COMPILATION, None

flat_runtime_sig, treedef, static_args = get_decomposed_signature(args, self.static_argnums)
key = CacheKey(treedef, static_args)
if key not in self.cache:
return TypeCompatibility.NEEDS_COMPILATION, None

entry = self.cache[key]
runtime_signature = tree_unflatten(treedef, flat_runtime_sig)
action = typecheck_signatures(entry.signature, runtime_signature, self.abstracted_axes)
return action, key

def lookup(self, args):
"""Get a function (if present) that matches the provided argument signature. Also computes
whether promotion is necessary.
Args:
args (Iterable): the arguments to match to existing functions
Returns:
CacheEntry | None: the matched cache entry
bool: whether the matched entry requires argument promotion
"""
action, key = self.get_function_status_and_key(args)

if action == TypeCompatibility.NEEDS_COMPILATION:
return None, None
elif action == TypeCompatibility.NEEDS_PROMOTION:
return self.cache[key], True
else:
assert action == TypeCompatibility.CAN_SKIP_PROMOTION
return self.cache[key], False

def insert(self, fn, args, out_treedef, workspace):
"""Inserts the provided function into the cache.
Args:
fn (CompiledFunction): compilation result
args (Iterable): arguments to determine cache key and additional metadata from
out_treedef (PyTreeDef): the output shape of the function
workspace (Directory): directory where compilation artifacts are stored
"""
assert isinstance(fn, CompiledFunction)

flat_signature, treedef, static_args = get_decomposed_signature(args, self.static_argnums)
signature = tree_unflatten(treedef, flat_signature)

key = CacheKey(treedef, static_args)
entry = CacheEntry(fn, signature, out_treedef, workspace)
self.cache[key] = entry
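`CacheKey` above memoizes its hash on first use. The same lazy-hash pattern in isolation (a hypothetical class, not Catalyst code — note that a `__hash__` defined in the body of a `dataclass` is preserved by the decorator rather than replaced):

```python
from dataclasses import dataclass

@dataclass
class Key:
    name: str
    version: int

    def __hash__(self):
        # Compute the hash once on first use and memoize it on the instance.
        if not hasattr(self, "_hash"):
            self._hash = hash((self.name, self.version))
        return self._hash

k = Key("circuit", 1)
first = hash(k)
second = hash(k)  # served from the memoized value
```

Keys with equal fields hash (and compare) equal, so dictionary lookups behave as expected while the per-instance hash is computed only once.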
7 changes: 0 additions & 7 deletions frontend/catalyst/compiler.py
@@ -524,10 +524,3 @@ def get_output_of(self, pipeline) -> Optional[str]:
return None

return self.last_compiler_output.get_pipeline_output(pipeline)

def print(self, pipeline):
"""Print the output IR of pass.
Args:
pipeline (str): name of pass class
"""
print(self.get_output_of(pipeline)) # pragma: no cover
5 changes: 2 additions & 3 deletions frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
@@ -840,12 +840,11 @@ def cudaq_backend_info(device):
# We could also pass abstract arguments here in *args
# the same way we do so in Catalyst.
# But I think that is redundant now given make_jaxpr2
_, jaxpr, _, out_tree = trace_to_jaxpr(func, static_args, abs_axes, *args)
jaxpr, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {})

# TODO(@erick-xanadu):
# What about static_args?
# We could return _out_type2 as well
return jaxpr, out_tree
return jaxpr, out_treedef


def interpret(fun):