Skip to content

Commit

Permalink
Merge branch 'main' into frontend-refactor-2
Browse files Browse the repository at this point in the history
  • Loading branch information
dime10 committed Feb 23, 2024
2 parents 3805b29 + 47fd3fc commit 5917cb4
Show file tree
Hide file tree
Showing 25 changed files with 627 additions and 442 deletions.
1 change: 1 addition & 0 deletions .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ jax=0.4.23
mhlo=4611968a5f6818e6bdfb82217b9e836e0400bba9
llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c
enzyme=1beb98b51442d50652eaa3ffb9574f4720d611f1
pennylane=95129a0d6365b48cb4acfa828ceb6a8532e47ef5
1 change: 1 addition & 0 deletions .github/workflows/build-wheel-linux-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -390,4 +390,5 @@ jobs:
run: |
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest -n auto
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto
1 change: 1 addition & 0 deletions .github/workflows/build-wheel-macos-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -400,4 +400,5 @@ jobs:
run: |
python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest -n auto
python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests
python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto
1 change: 1 addition & 0 deletions .github/workflows/build-wheel-macos-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,5 @@ jobs:
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest -n auto
# TODO: Uncomment after fixing https://github.com/PennyLaneAI/pennylane-lightning/issues/552
# python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto
2 changes: 2 additions & 0 deletions .github/workflows/check-pl-compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ jobs:
constants:
name: "Set build matrix"
uses: ./.github/workflows/constants.yaml
with:
use_release_tag: ${{ inputs.catalyst == 'stable' }}

check-config:
name: Build Configuration
Expand Down
9 changes: 9 additions & 0 deletions .github/workflows/constants.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ on:
required: false
default: false
type: boolean
use_release_tag:
required: false
default: false
type: boolean
outputs:
llvm_version:
description: "LLVM version"
Expand Down Expand Up @@ -45,6 +49,11 @@ jobs:
steps:
- name: Checkout Catalyst repo
uses: actions/checkout@v3
with:
fetch-depth: 0
- if: ${{ inputs.use_release_tag }}
run: |
git checkout $(git tag | sort -V | tail -1)
- name: LLVM version
id: llvm_version
Expand Down
10 changes: 10 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@
f(2, MyClass(5)) # no re-compilation
```

* Catalyst now supports executing tapes in CUDA-Quantum simulators.
[(#477)](https://github.com/PennyLaneAI/catalyst/pull/477)
[(#536)](https://github.com/PennyLaneAI/catalyst/pull/536)

It has added the following devices:
* softwareq.qpp
* nvidia.statevec (with support for multi-gpu)
* nvidia.tensornet (with support for matrix product state)

<h3>Improvements</h3>

* Catalyst will now remember previously compiled functions when the PyTree metadata of arguments
Expand Down Expand Up @@ -304,6 +313,7 @@

* Handle run time exception in async qnodes.
[(#447)](https://github.com/PennyLaneAI/catalyst/pull/447)
[(#510)](https://github.com/PennyLaneAI/catalyst/pull/510)

This is done by:
* changeing `llvm.call` to `llvm.invoke`
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@

import catalyst
from catalyst.ag_utils import AutoGraphError
from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.tracing.contexts import EvaluationContext
from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.utils.patching import Patcher

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/compiled_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
make_zero_d_memref_descriptor,
)

from catalyst.jax_extras import get_implicit_and_explicit_flat_args
from catalyst.tracing.type_signatures import (
TypeCompatibility,
filter_static_args,
Expand All @@ -37,7 +38,6 @@
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type
from catalyst.utils.filesystem import Directory
from catalyst.utils.jax_extras import get_implicit_and_explicit_flat_args


class SharedObjectManager:
Expand Down
16 changes: 12 additions & 4 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class LinkerDriver:
_default_fallback_compilers = ["clang", "gcc", "c99", "c89", "cc"]

@staticmethod
def get_default_flags():
def get_default_flags(options):
"""Re-compute the path where the libraries exist.
The use case for this is if someone is in a python jupyter notebook and
Expand Down Expand Up @@ -312,6 +312,14 @@ def get_default_flags():
elif platform.system() == "Darwin": # pragma: nocover
system_flags += ["-Wl,-arch_errors_fatal"]

# The exception handling mechanism requires linking against
# __gxx_personality_v0 which is either on -lstdc++ in
# or -lc++. We choose based on the operating system.
if options.async_qnodes and platform.system() == "Linux": # pragma: nocover
system_flags += ["-lstdc++"]
elif options.async_qnodes and platform.system() == "Darwin": # pragma: nocover
system_flags += ["-lc++"]

default_flags = [
"-shared",
"-rdynamic",
Expand Down Expand Up @@ -395,12 +403,12 @@ def run(infile, outfile=None, flags=None, fallback_compilers=None, options=None)
"""
if outfile is None:
outfile = LinkerDriver.get_output_filename(infile)
if options is None:
options = CompileOptions()
if flags is None:
flags = LinkerDriver.get_default_flags()
flags = LinkerDriver.get_default_flags(options)
if fallback_compilers is None:
fallback_compilers = LinkerDriver._default_fallback_compilers
if options is None:
options = CompileOptions()
for compiler in LinkerDriver._available_compilers(fallback_compilers):
success = LinkerDriver._attempt_link(compiler, flags, infile, outfile, options)
if success:
Expand Down
35 changes: 24 additions & 11 deletions frontend/catalyst/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def wrap_fn(fn):
class BaseCudaInstructionSet(qml.QubitDevice):
"""Base instruction set for CUDA-Quantum devices"""

# TODO: Once 0.35 is released, remove -dev suffix.
pennylane_requires = "0.35.0-dev"
pennylane_requires = ">=0.34"
version = "0.1.0"
author = "Xanadu, Inc."

Expand All @@ -68,14 +67,12 @@ class BaseCudaInstructionSet(qml.QubitDevice):
"RY",
"RZ",
"SWAP",
# "CSWAP", This is a bug in cuda-quantum. CSWAP is not exposed.
"CSWAP",
]
observables = []
config = Path(__file__).parent / "cuda_quantum.toml"

def __init__(self, shots=None, wires=None, mps=False, multi_gpu=False):
self.mps = mps
self.multi_gpu = multi_gpu
def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)

def apply(self, operations, **kwargs):
Expand All @@ -88,25 +85,41 @@ def apply(self, operations, **kwargs):
class SoftwareQQPP(BaseCudaInstructionSet):
"""Concrete device class for qpp-cpu"""

name = "SoftwareQ q++ simulator"
short_name = "softwareq.qpp"

@property
def name(self):
"""Target name"""
return "qpp-cpu"


class NvidiaCuStateVec(BaseCudaInstructionSet):
"""Concrete device class for CuStateVec"""

name = "CuStateVec"
short_name = "nvidia.custatevec"

def __init__(self, shots=None, wires=None, multi_gpu=False): # pragma: no cover
super().__init__(wires=wires, shots=shots, multi_gpu=multi_gpu)
self.multi_gpu = multi_gpu
super().__init__(wires=wires, shots=shots)

@property
def name(self): # pragma: no cover
"""Target name"""
option = "-mgpu" if self.multi_gpu else ""
return f"nvidia{option}"


class NvidiaCuTensorNet(BaseCudaInstructionSet):
"""Concrete device class for CuTensorNet"""

name = "CuTensorNet"
short_name = "nvidia.cutensornet"

def __init__(self, shots=None, wires=None, mps=False): # pragma: no cover
super().__init__(wires=wires, shots=shots, mps=mps)
self.mps = mps
super().__init__(wires=wires, shots=shots)

@property
def name(self): # pragma: no cover
"""Target name"""
option = "-mps" if self.mps else ""
return f"tensornet{option}"
20 changes: 6 additions & 14 deletions frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,11 @@ def change_device_to_cuda_device(ctx):

device_name = qdevice_eqn.params.get("rtd_name")

# TODO(@erick-xanadu) as more devices become available
# map the names here.
# TODO(@erick-xanadu) why does the device instruction lists the whole
# name instead of a short name?
target_map = {"SoftwareQ q++ simulator": "qpp-cpu"}
target = target_map.get(device_name, device_name)

if not target or not cudaq.has_target(target):
msg = f"Unavailable target {target}." # pragma: no cover
if not cudaq.has_target(device_name):
msg = f"Unavailable target {device_name}." # pragma: no cover
raise ValueError(msg)

cudaq_target = cudaq.get_target(target)
cudaq_target = cudaq.get_target(device_name)
cudaq.set_target(cudaq_target)

# cudaq_make_kernel returns a multiple values depending on the arguments.
Expand Down Expand Up @@ -411,7 +404,7 @@ def change_instruction(ctx, eqn):
"RY": "ry",
"RZ": "rz",
"SWAP": "swap",
# "CSWAP": "cswap", Bug in CUDA quantum. CSWAP is not exposed.
"CSWAP": "cswap",
# Other instructions that are missing:
# ch
# sdg
Expand Down Expand Up @@ -847,12 +840,11 @@ def cudaq_backend_info(device):
# We could also pass abstract arguments here in *args
# the same way we do so in Catalyst.
# But I think that is redundant now given make_jaxpr2
_, jaxpr, _, out_tree = trace_to_jaxpr(func, static_args, abs_axes, *args)
jaxpr, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {})

# TODO(@erick-xanadu):
# What about static_args?
# We could return _out_type2 as well
return jaxpr, out_tree
return jaxpr, out_treedef


def interpret(fun):
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/cuda/cuda_quantum.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ native = [
"RY",
"RZ",
"SWAP",
# "CSWAP", # Not exposed in CUDA quantum.
"CSWAP",
]

# Operators that should be decomposed according to the algorithm used
Expand Down
16 changes: 12 additions & 4 deletions frontend/catalyst/cuda/primitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def cudaq_getstate(kernel):
@cudaq_getstate_p.def_impl
def cudaq_getstate_primitive_impl(kernel):
"""Concrete implementation."""
return cudaq.get_state(kernel)
return jax.numpy.array(cudaq.get_state(kernel))


@cudaq_getstate_p.def_abstract_eval
Expand Down Expand Up @@ -394,9 +394,17 @@ def cudaq_sample_impl(kernel, *args, shots_count=1000):
So, let's perform a little conversion here.
"""
a_dict = cudaq.sample(kernel, *args, shots_count=shots_count)
lls = [[k] * v for k, v in a_dict.items()]
# Weirdly enough Catalyst returns this transposed.
return jax.numpy.atleast_2d(jax.numpy.array([int(l) for ls in lls for l in ls])).T
aggregate = []
for bitstring, count in a_dict.items():
# It is technically a bit array
# So we should use int(bit)
# But in Catalyst, these are floats.
# So we use floats.
bitarray = [float(bit) for bit in bitstring]
for _ in range(count):
aggregate.append(bitarray)

return jax.numpy.array(aggregate)


@cudaq_sample_p.def_abstract_eval
Expand Down
56 changes: 56 additions & 0 deletions frontend/catalyst/jax_extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Catalyst additions to the Jax library """

from catalyst.jax_extras.lowering import custom_lower_jaxpr_to_module, jaxpr_to_mlir
from catalyst.jax_extras.patches import (
_gather_shape_rule_dynamic,
_no_clean_up_dead_vars,
get_aval2,
)
from catalyst.jax_extras.tracing import (
ClosedJaxpr,
DynshapedJaxpr,
DynamicJaxprTrace,
DynamicJaxprTracer,
Jaxpr,
PyTreeDef,
PyTreeRegistry,
ShapedArray,
ShapeDtypeStruct,
_abstractify,
_extract_implicit_args,
_initial_style_jaxpr,
_input_type_to_tracers,
convert_constvars_jaxpr,
convert_element_type,
deduce_avals,
eval_jaxpr,
get_implicit_and_explicit_flat_args,
infer_lambda_input_type,
initial_style_jaxprs_with_common_consts1,
initial_style_jaxprs_with_common_consts2,
make_jaxpr2,
make_jaxpr_effects,
new_dynamic_main2,
new_inner_tracer,
sort_eqns,
transient_jax_config,
tree_flatten,
tree_structure,
tree_unflatten,
treedef_is_leaf,
unzip2,
wrap_init,
)
Loading

0 comments on commit 5917cb4

Please sign in to comment.