Skip to content

Commit

Permalink
[Frontend] Fix Python fallback errors with AutoGraph in certain insta…
Browse files Browse the repository at this point in the history
…nces (#352)

Improve robustness of AutoGraph conversion. 

List of changes:
- fix bug when extracting the variable name in question for certain
error messages
- add input/output checking for while loops & add tests
- setup fix for leftover JAX primitive when fallback is triggered, by
adding access to the current JAXPR frame and quantum queuing context
- remove for/while loop primitives from the JAXPR when a fallback to
Python has occurred
 
Enables to following to succeed for example:
```py
@qjit(autograph=True)
def f():
    l = jnp.array([1, 2])
    for _ in range(2):
        l = jnp.kron(l, l)  # used to fail the lowering step of the leftover primitive, due to differently sized argument/return types
    return l
```

[sc-49491]
  • Loading branch information
dime10 authored Nov 8, 2023
1 parent 2737a84 commit cdff829
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 41 deletions.
15 changes: 15 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,21 @@
* Include the "Catalyst" utility dialect in our MLIR C-API.
[(#345)](https://github.com/PennyLaneAI/catalyst/pull/345)

* Fix an issue with the AutoGraph conversion system that would prevent the fallback to Python from
working correctly in certain instances.
[(#352)](https://github.com/PennyLaneAI/catalyst/pull/352)

The following type of code is now supported:

```python
@qjit(autograph=True)
def f():
l = jnp.array([1, 2])
for _ in range(2):
l = jnp.kron(l, l)
return l
```

<h3>Breaking changes</h3>

* The axis ordering for `catalyst.jacobian` is updated to match `jax.jacobian`. Assume we have
Expand Down
119 changes: 92 additions & 27 deletions frontend/catalyst/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# as well as various utility objects.
import pennylane as qml
import tensorflow.python.autograph.impl.api as tf_autograph_api
from pennylane.queuing import AnnotatedQueue
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core.converter import STANDARD_OPTIONS as STD
from tensorflow.python.autograph.core.converter import ConversionOptions
Expand All @@ -43,6 +44,7 @@

import catalyst
from catalyst.ag_utils import AutoGraphError
from catalyst.utils.contexts import EvaluationContext
from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.utils.patching import Patcher

Expand All @@ -64,6 +66,42 @@
]


def get_program_length(reference_tracers):
"""Get the current number of instructions of the quantum and classical program."""
# pylint: disable=unnecessary-dunder-call

num_jaxpr_eqns, num_tape_ops = 0, 0

if EvaluationContext.is_tracing(): # pragma: no branch
jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers)
num_jaxpr_eqns = len(jaxpr_frame.eqns)

if EvaluationContext.is_quantum_tracing():
quantum_queue = EvaluationContext.find_quantum_queue()
# Using the the class methods directly allows this to work for both
# QuantumTape & AnnotatedQueue instances.
num_tape_ops = AnnotatedQueue.__len__(quantum_queue)

return num_jaxpr_eqns, num_tape_ops


def reset_program_to_length(reference_tracers, num_jaxpr_eqns, num_tape_ops):
"""Reset the quantum and classical program back to a given length."""
# pylint: disable=unnecessary-dunder-call

if EvaluationContext.is_tracing(): # pragma: no branch
jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers)
while len(jaxpr_frame.eqns) > num_jaxpr_eqns:
jaxpr_frame.eqns.pop()

if EvaluationContext.is_quantum_tracing():
quantum_queue = EvaluationContext.find_quantum_queue()
# Using the the class methods directly allows this to work for both
# QuantumTape & AnnotatedQueue instances.
while AnnotatedQueue.__len__(quantum_queue) > num_tape_ops:
AnnotatedQueue.popitem(quantum_queue)


def assert_results(results, var_names):
"""Assert that none of the results are undefined, i.e. have no value."""

Expand Down Expand Up @@ -111,7 +149,7 @@ def functional_cond():
set_state(results)


def assert_for_loop_inputs(inputs, iterate_names):
def assert_iteration_inputs(inputs, symbol_names):
"""All loop carried values, variables that are updated each iteration or accessed after the
loop terminates, need to be initialized prior to entering the loop.
Expand All @@ -138,7 +176,7 @@ def assert_for_loop_inputs(inputs, iterate_names):
jax.api_util.shaped_abstractify(inp)
except TypeError as e:
raise AutoGraphError(
f"The variable '{iterate_names[i]}' was initialized with type {type(inp)}, "
f"The variable '{symbol_names[i]}' was initialized with type {type(inp)}, "
"which is not compatible with JAX. Typically, this is the case for non-numeric "
"values.\n"
"You may still use such a variable as a constant inside a loop, but it cannot "
Expand All @@ -147,7 +185,7 @@ def assert_for_loop_inputs(inputs, iterate_names):
) from e


def assert_for_loop_results(inputs, outputs, iterate_names):
def assert_iteration_results(inputs, outputs, symbol_names):
"""The results of a for loop should have the identical type as the inputs, since they are
"passed" as inputs to the next iteration. A mismatch here may indicate that a loop carried
variable was initialized with wrong type.
Expand All @@ -157,20 +195,29 @@ def assert_for_loop_results(inputs, outputs, iterate_names):
inp_t, out_t = jax.api_util.shaped_abstractify(inp), jax.api_util.shaped_abstractify(out)
if inp_t.dtype != out_t.dtype or inp_t.shape != out_t.shape:
raise AutoGraphError(
f"The variable '{iterate_names[i]}' was initialized with the wrong type. "
f"The variable '{symbol_names[i]}' was initialized with the wrong type, or you may "
f"be trying to change its type from one iteration to the next. "
f"Expected: {out_t}, Got: {inp_t}"
)


def _call_catalyst_for(
start, stop, step, body_fn, get_state, set_state, opts, enum_start=None, array_iterable=None
start,
stop,
step,
body_fn,
get_state,
set_state,
symbol_names,
enum_start=None,
array_iterable=None,
):
"""Dispatch to a Catalyst implementation of for loops."""

# Ensure iteration arguments are properly initialized. We cannot process uninitialized
# loop carried values as we need their type information for tracing.
init_iter_args = get_state()
assert_for_loop_inputs(init_iter_args, opts["iterate_names"])
assert_iteration_inputs(init_iter_args, symbol_names)

@catalyst.for_loop(start, stop, step)
def functional_for(i, *iter_args):
Expand All @@ -194,7 +241,7 @@ def functional_for(i, *iter_args):
return get_state()

final_iter_args = functional_for(*init_iter_args)
assert_for_loop_results(init_iter_args, final_iter_args, opts["iterate_names"])
assert_iteration_results(init_iter_args, final_iter_args, symbol_names)
return final_iter_args


Expand All @@ -213,8 +260,8 @@ def for_stmt(
body_fn: Callable[[int], None],
get_state: Callable[[], Tuple],
set_state: Callable[[Tuple], None],
_symbol_names: Tuple[str],
opts: dict,
symbol_names: Tuple[str],
_opts: dict,
):
"""An implementation of the AutoGraph 'for .. in ..' statement. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives."""
Expand All @@ -240,6 +287,7 @@ def for_stmt(
# to succeed, for example because they forgot to use a list instead of an array
fallback = False
init_state = get_state()
assert len(init_state) == len(symbol_names)

if isinstance(iteration_target, CRange):
start, stop, step = iteration_target.get_raw_range()
Expand Down Expand Up @@ -275,21 +323,34 @@ def for_stmt(

# Attempt to trace the Catalyst for loop.
if not fallback:
reference_tracers = (start, stop, step, *init_state)
num_instructions = get_program_length(reference_tracers)

try:
set_state(init_state)
results = _call_catalyst_for(
start, stop, step, body_fn, get_state, set_state, opts, enum_start, iteration_array
start,
stop,
step,
body_fn,
get_state,
set_state,
symbol_names,
enum_start,
iteration_array,
)

except Exception as e: # pylint: disable=broad-exception-caught
if catalyst.autograph_strict_conversion:
raise e

fallback = True
reset_program_to_length(reference_tracers, *num_instructions)

# pylint: disable=import-outside-toplevel
import inspect
import textwrap

fallback = True

for_loop_info = get_source_code_info(inspect.stack()[1])

if not catalyst.autograph_ignore_fallbacks:
Expand Down Expand Up @@ -319,28 +380,31 @@ def for_stmt(
set_state(results)


def _call_catalyst_while(loop_test, loop_body, get_state, set_state, _nonlocals, _symbol_names):
def _call_catalyst_while(loop_test, loop_body, get_state, set_state, symbol_names):
"""Dispatch to a Catalyst implementation of while loops."""

def _test(state):
init_iter_args = get_state()
assert_iteration_inputs(init_iter_args, symbol_names)

def test(state):
old = get_state()
set_state(state)
res = loop_test()
set_state(old)
return res

@catalyst.while_loop(_test)
def _functional_while(iter_args):
@catalyst.while_loop(test)
def functional_while(iter_args):
set_state(iter_args)
loop_body()
return get_state()

iter_inits = get_state()
iter_results = _functional_while(iter_inits)
return iter_results
final_iter_args = functional_while(init_iter_args)
assert_iteration_results(init_iter_args, final_iter_args, symbol_names)
return final_iter_args


def _call_python_while(loop_test, loop_body, get_state, _set_state, _nonlocals, _symbol_names):
def _call_python_while(loop_test, loop_body, get_state, _set_state):
"""Fallback to a Python implementation of while loops."""

while loop_test():
Expand All @@ -349,28 +413,29 @@ def _call_python_while(loop_test, loop_body, get_state, _set_state, _nonlocals,
return get_state()


def while_stmt(loop_test, loop_body, get_state, set_state, nonlocals, symbol_names):
def while_stmt(loop_test, loop_body, get_state, set_state, symbol_names, _opts):
"""An implementation of the AutoGraph 'while ..' statement. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives."""

fallback = False
init_state = get_state()

reference_tracers = init_state
num_instructions = get_program_length(reference_tracers)

try:
results = _call_catalyst_while(
loop_test, loop_body, get_state, set_state, nonlocals, symbol_names
)
results = _call_catalyst_while(loop_test, loop_body, get_state, set_state, symbol_names)

except Exception as e: # pylint: disable=broad-exception-caught
if catalyst.autograph_strict_conversion:
raise e

fallback = True
reset_program_to_length(reference_tracers, *num_instructions)

if fallback:
set_state(init_state)
results = _call_python_while(
loop_test, loop_body, get_state, set_state, nonlocals, symbol_names
)
results = _call_python_while(loop_test, loop_body, get_state, set_state)

set_state(results)

Expand Down
31 changes: 31 additions & 0 deletions frontend/catalyst/utils/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
extend_jaxpr_stack,
)
from jax._src.source_info_util import reset_name_stack
from jax.core import find_top_trace
from pennylane.queuing import QueuingManager

from catalyst.utils.exceptions import CompileError
from catalyst.utils.jax_extras import new_dynamic_main2
Expand Down Expand Up @@ -170,6 +172,13 @@ def is_tracing(cls):
EvaluationMode.QUANTUM_COMPILATION,
]

@classmethod
def is_quantum_tracing(cls):
"""Returns true or false depending on whether the execution is currently being
traced.
"""
return cls.get_mode() == EvaluationMode.QUANTUM_COMPILATION

@classmethod
def check_modes(cls, modes, msg):
"""Asserts if the execution mode is not among the expected ``modes``.
Expand Down Expand Up @@ -213,3 +222,25 @@ def check_is_not_tracing(cls, msg):
"""
if cls.is_tracing():
raise CompileError(msg)

@classmethod
def find_jaxpr_frame(cls, *args):
"""Obtain the current JAXPR frame, in which primitives are being inserted.
Raises: CompileError
"""
cls.check_is_tracing("No JAXPR frames exist outside a tracing context.")
return find_top_trace(args).frame

@classmethod
def find_quantum_queue(cls):
"""Obtain the current quantum queuing context, in which operations are being inserted.
Raises: CompileError
"""
cls.check_is_quantum_tracing("No quantum queueing context found.")

queuing_context = QueuingManager.active_context()
assert queuing_context is not None

return queuing_context
2 changes: 1 addition & 1 deletion frontend/catalyst/utils/jax_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def pad_jaxpr_constvars(i, jaxpr):


def deduce_avals(f: Callable, args, kwargs):
"""Wrapes the callable ``f`` into a WrappedFun JAX container. Calculate input abstract values
"""Wraps the callable ``f`` into a WrappedFun JAX container. Calculate input abstract values
and output_tree promise. The promise must be called after the resulting wrapped function is
evaluated."""
flat_args, in_tree = tree_flatten((args, kwargs))
Expand Down
Loading

0 comments on commit cdff829

Please sign in to comment.