Skip to content

Commit

Permalink
Daily rc sync to main (#1091)
Browse files Browse the repository at this point in the history
Automatic sync from the release candidate to main during a feature
freeze.

---------

Co-authored-by: Raul Torres <138264735+rauletorresc@users.noreply.github.com>
Co-authored-by: paul0403 <79805239+paul0403@users.noreply.github.com>
Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com>
Co-authored-by: Ahmed Darwish <exclass9.24@gmail.com>
Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
Co-authored-by: ringo-but-quantum <github-ringo-but-quantum@xanadu.ai>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: Romain Moyard <rmoyard@gmail.com>
Co-authored-by: GitHub Actions Bot <>
  • Loading branch information
9 people authored Sep 3, 2024
1 parent 39e6baf commit afa4759
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 9 deletions.
20 changes: 20 additions & 0 deletions doc/releases/changelog-0.8.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,25 @@
* Bug fixed when parameter annotations return strings.
[(#1078)](https://github.com/PennyLaneAI/catalyst/pull/1078)

* In certain cases, `jax.scipy.linalg.expm`
[may return incorrect numerical results](https://github.com/PennyLaneAI/catalyst/issues/1071)
when used within a qjit-compiled function. A warning will now be raised
when `jax.scipy.linalg.expm` is used to inform of this issue.

In the meantime, we strongly recommend using the
[catalyst.accelerate](https://docs.pennylane.ai/projects/catalyst/en/latest/code/api/catalyst.accelerate.html) function
within qjit-compiled functions to call `jax.scipy.linalg.expm` directly.

```python
@qjit
def f(A):
B = catalyst.accelerate(jax.scipy.linalg.expm)(A)
return B
```

Note that this change does not fix the aforementioned numerical errors; it only raises a warning.
[(#1082)](https://github.com/PennyLaneAI/catalyst/pull/1082)

<h3>Documentation</h3>

* A page has been added to the documentation, listing devices that are
Expand Down Expand Up @@ -686,6 +705,7 @@ This release contains contributions from (in alphabetical order):
Joey Carter,
Alessandro Cosentino,
Lillian M. A. Frederiksen,
David Ittah,
Josh Izaac,
Christina Lee,
Kunwar Maheep Singh,
Expand Down
17 changes: 9 additions & 8 deletions frontend/catalyst/api_extensions/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

from catalyst.jax_extras import transient_jax_config
from catalyst.jax_primitives import python_callback_p
from catalyst.tracing.contexts import EvaluationContext, GradContext
from catalyst.tracing.contexts import AccelerateContext, EvaluationContext, GradContext
from catalyst.utils.exceptions import DifferentiableCompileError
from catalyst.utils.jnp_to_memref import (
get_ranked_memref_descriptor,
Expand Down Expand Up @@ -347,13 +347,14 @@ def accelerate_impl(users_func=None, *, dev=None):

@functools.wraps(users_func, assigned=WRAPPER_ASSIGNMENTS)
def total(context, *args, **kwargs):
    # Wrapper around the user's function that is handed to jax.jit (see the
    # comment at the jax.jit call site below). Entering AccelerateContext marks
    # execution as being inside a catalyst.accelerate callback, so the
    # jax.scipy.linalg warning patches stay silent for calls made through it.
    with AccelerateContext():
        nonlocal users_func
        if is_partial:
            # The user's callable is a functools.partial: its flattened leaves
            # arrive via `context`; rebuild the callable from its pytree shape.
            _, shape = tree_flatten(users_func)
            users_func = tree_unflatten(shape, context)
            return users_func(*args, **kwargs)
        else:
            return users_func(*args, **kwargs)

with transient_jax_config({"jax_dynamic_shapes": False}):
# jax.jit will wrap total and total wraps the user_function
Expand Down
55 changes: 55 additions & 0 deletions frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2022-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains warnings for using jax.scipy.linalg functions inside qjit.
Due to improperly linked lapack symbols, occasionally these functions give wrong
numerical results when used in a qjit context.
For now, we warn users to wrap all of these calls with a catalyst.accelerate() callback.
This patch should be removed when we have proper linkage to lapack.
See:
https://app.shortcut.com/xanaduai/story/70899/find-a-system-to-automatically-create-a-custom-call-library-from-the-one-in-jax
https://github.com/PennyLaneAI/catalyst/issues/753
https://github.com/PennyLaneAI/catalyst/issues/1071
"""

import warnings

import jax

from catalyst.tracing.contexts import AccelerateContext


class JaxLinalgWarner:
    """Callable wrapper that warns when a jax.scipy.linalg function is used directly.

    Instances of this class are patched in place of the corresponding
    ``jax.scipy.linalg`` function while a qjit-compiled program is traced.
    Calling the instance emits a ``UserWarning`` about potentially wrong
    numerical results — unless the call happens inside a
    ``catalyst.accelerate`` callback — and then forwards to the wrapped
    function unchanged.
    """

    def __init__(self, fn):
        """Store the original jax.scipy.linalg function to wrap.

        Args:
            fn: the jax.scipy.linalg function (e.g. ``expm``, ``lu``) to wrap.
        """
        self.fn = fn

    def __call__(self, *args, **kwargs):
        """Forward to the wrapped function, warning on direct (non-accelerated) use."""
        # Calls routed through catalyst.accelerate run as a callback and are not
        # affected by the improperly linked lapack symbols, so no warning there.
        if not AccelerateContext.am_inside_accelerate():
            warnings.warn(
                f"""
    jax.scipy.linalg.{self.fn.__name__} occasionally gives wrong numerical results
    when used within a qjit-compiled function.
    See https://github.com/PennyLaneAI/catalyst/issues/1071.
    In the meantime, we recommend catalyst.accelerate to call
    the underlying {self.fn.__name__} function directly:
        @qjit
        def f(A):
            return catalyst.accelerate(jax.scipy.linalg.{self.fn.__name__})(A)
    See https://docs.pennylane.ai/projects/catalyst/en/latest/code/api/catalyst.accelerate.html
    """,
                UserWarning,  # explicit; this is also warnings.warn's default category
            )
        return self.fn(*args, **kwargs)
10 changes: 10 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from catalyst.compiled_functions import CompilationCache, CompiledFunction
from catalyst.compiler import CompileOptions, Compiler
from catalyst.debug.instruments import instrument
from catalyst.jax_extras.jax_scipy_linalg_warnings import JaxLinalgWarner
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import _inject_transform_named_sequence
Expand Down Expand Up @@ -589,6 +590,15 @@ def closure(qnode, *args, **kwargs):

with Patcher(
(qml.QNode, "__call__", closure),
# !!! TODO: fix jax.scipy numerical failures with properly fetched lapack calls
# As of now, we raise a warning prompting the user to use a callback with catalyst.accelerate()
# https://app.shortcut.com/xanaduai/story/70899/find-a-system-to-automatically-create-a-custom-call-library-from-the-one-in-jax
# https://github.com/PennyLaneAI/catalyst/issues/753
# https://github.com/PennyLaneAI/catalyst/issues/1071
(jax.scipy.linalg, "expm", JaxLinalgWarner(jax.scipy.linalg.expm)),
(jax.scipy.linalg, "lu", JaxLinalgWarner(jax.scipy.linalg.lu)),
(jax.scipy.linalg, "lu_factor", JaxLinalgWarner(jax.scipy.linalg.lu_factor)),
(jax.scipy.linalg, "lu_solve", JaxLinalgWarner(jax.scipy.linalg.lu_solve)),
):
# TODO: improve PyTree handling

Expand Down
14 changes: 14 additions & 0 deletions frontend/catalyst/tracing/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@ def _peek():
return GradContext._grad_stack


class AccelerateContext:
    """Context manager tracking whether tracing is inside a ``catalyst.accelerate`` callback.

    The jax.scipy.linalg warning patches query :meth:`am_inside_accelerate` to
    suppress the numerical-accuracy warning for calls the user already routed
    through ``catalyst.accelerate``.

    A depth counter (rather than a plain boolean) is used so that nested
    ``with AccelerateContext():`` blocks behave correctly: exiting an inner
    context no longer clears the flag while an outer context is still active.
    """

    # Number of currently entered AccelerateContexts; > 0 means "inside accelerate".
    _depth: int = 0

    def __enter__(self):
        AccelerateContext._depth += 1
        return self

    def __exit__(self, _exc_type, _exc, _exc_tb):
        AccelerateContext._depth -= 1

    @staticmethod
    def am_inside_accelerate() -> bool:
        """Return True iff execution is currently inside at least one accelerate context."""
        return AccelerateContext._depth > 0


class EvaluationMode(Enum):
"""Enumerate the evaluation modes supported by Catalyst:
INTERPRETATION - native Python execution of a Catalyst program
Expand Down
126 changes: 125 additions & 1 deletion frontend/test/pytest/test_jax_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

"""Test that numerical jax functions produce correct results when compiled with qml.qjit"""

import warnings

import numpy as np
import pennylane as qml
import pytest
from jax import numpy as jnp
from jax import scipy as jsp

from catalyst import qjit
from catalyst import accelerate, qjit


class TestExpmNumerical:
Expand All @@ -32,10 +34,15 @@ class TestExpmNumerical:
jnp.array([[0.1, 0.2], [5.3, 1.2]]),
jnp.array([[1, 2], [3, 4]]),
jnp.array([[1.0, -1.0j], [1.0j, -1.0]]),
jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [3.0, 2.0, 1.0]]),
],
)
def test_expm_numerical(self, inp):
"""Test basic numerical correctness for jax.scipy.linalg.expm for float, int, complex"""
if np.array_equiv(inp, jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [3.0, 2.0, 1.0]])):
# this particular matrix currently produces incorrect numerical results and needs proper lapack calls to be fixed.
# https://github.com/PennyLaneAI/catalyst/issues/1071
pytest.xfail("Waiting for proper lapack calls")

@qjit
def f(x):
Expand Down Expand Up @@ -71,6 +78,123 @@ def circuit_rot():
assert np.allclose(res, expected)


class TestExpmWarnings:
    """Test that jax.scipy.linalg.expm raises a warning when it is not used in an
    accelerate callback.

    TODO: remove this test class (and the warning patch it exercises) once proper
    lapack calls are available.
    """
    # NOTE: the original had a second bare triple-quoted string after the
    # docstring; it was a dead statement and has been merged into the docstring.

    def test_expm_warnings(self):
        """expm used directly inside qjit must emit the numerical-accuracy warning."""

        @qjit
        def f(x):
            expm = jsp.linalg.expm
            return expm(x)

        with pytest.warns(
            UserWarning,
            match="jax.scipy.linalg.expm occasionally gives wrong numerical results",
        ):
            f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))

    def test_accelerated_expm_no_warnings(self, recwarn):
        """expm wrapped in catalyst.accelerate must not warn and must match plain jax."""

        @qjit
        def f(x):
            expm = accelerate(jsp.linalg.expm)
            return expm(x)

        observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        expected = jsp.linalg.expm(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        assert len(recwarn) == 0
        assert np.allclose(observed, expected)


class TestLUWarnings:
    """Test that jax.scipy.linalg.lu, lu_factor and lu_solve raise a warning when
    not used in an accelerate callback.

    TODO: remove this test class (and the warning patch it exercises) once proper
    lapack calls are available.
    """
    # NOTE: the original had a second bare triple-quoted string after the
    # docstring; it was a dead statement and has been merged into the docstring.

    def test_lu_warnings(self):
        """lu used directly inside qjit must emit the numerical-accuracy warning."""

        @qjit
        def f(x):
            lu = jsp.linalg.lu
            return lu(x)

        with pytest.warns(
            UserWarning,
            match="jax.scipy.linalg.lu occasionally gives wrong numerical results",
        ):
            f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))

    def test_accelerated_lu_no_warnings(self, recwarn):
        """lu wrapped in catalyst.accelerate must not warn and must match plain jax."""

        @qjit
        def f(x):
            lu = accelerate(jsp.linalg.lu)
            return lu(x)

        observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        expected = jsp.linalg.lu(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        assert len(recwarn) == 0
        assert np.allclose(observed, expected)

    def test_lu_factor_warnings(self):
        """lu_factor used directly inside qjit must emit the numerical-accuracy warning."""

        @qjit
        def f(x):
            luf = jsp.linalg.lu_factor
            return luf(x)

        with pytest.warns(
            UserWarning,
            match="jax.scipy.linalg.lu_factor occasionally gives wrong numerical results",
        ):
            f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))

    def test_accelerated_lu_factor_no_warnings(self, recwarn):
        """lu_factor wrapped in catalyst.accelerate must not warn and must match plain jax."""

        @qjit
        def f(x):
            luf = accelerate(jsp.linalg.lu_factor)
            return luf(x)

        observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        expected = jsp.linalg.lu_factor(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        assert len(recwarn) == 0
        assert np.allclose(observed[0], expected[0])
        assert np.allclose(observed[1], expected[1])

    def test_lu_solve_warnings(self):
        """lu_solve used directly inside qjit must emit the numerical-accuracy warning."""

        @qjit
        def f(x):
            lus = jsp.linalg.lu_solve
            b = jnp.array([3.0, 4.0])
            # Since this is a lu_solve unit test, route lu_factor through
            # accelerate so only the lu_solve call can trigger the warning.
            B = accelerate(jsp.linalg.lu_factor)(x)
            return lus(B, b)

        with pytest.warns(
            UserWarning,
            match="jax.scipy.linalg.lu_solve occasionally gives wrong numerical results",
        ):
            f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))

    def test_accelerated_lu_solve_no_warnings(self, recwarn):
        """lu_solve wrapped in catalyst.accelerate must not warn and must match plain jax."""

        @qjit
        def f(x):
            lus = accelerate(jsp.linalg.lu_solve)
            b = jnp.array([3.0, 4.0])
            B = accelerate(jsp.linalg.lu_factor)(x)
            return lus(B, b)

        def truth(x):
            lus = jsp.linalg.lu_solve
            b = jnp.array([3.0, 4.0])
            B = jsp.linalg.lu_factor(x)
            return lus(B, b)

        observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        expected = truth(jnp.array([[0.1, 0.2], [5.3, 1.2]]))
        assert len(recwarn) == 0
        assert np.allclose(observed, expected)


class TestArgsortNumerical:
"""Test jax.numpy.argsort sort arrays correctly when being qjit compiled"""

Expand Down

0 comments on commit afa4759

Please sign in to comment.