
[BUG] BasisState does not work with jax.jit #6006

Closed
1 task done
isaacdevlugt opened this issue Jul 17, 2024 · 1 comment
Labels
bug 🐛 Something isn't working

Comments

@isaacdevlugt
Contributor

isaacdevlugt commented Jul 17, 2024

Expected behavior

I expect `BasisState` to work under `jax.jit` just as `BasisStatePreparation` does, since `BasisState` decomposes to `BasisStatePreparation`.

Actual behavior

`BasisState` fails under `jax.jit` on `default.qubit`: tracing the QNode raises a `TracerBoolConversionError`.

Additional information

No response

Source code

import pennylane as qml
import jax
from jax import numpy as jnp

wires = [0, 1, 2]
dev = qml.device("default.qubit", wires=wires)

@jax.jit
@qml.qnode(dev)
def circuit_BasisState(n):
    qml.BasisState(n, wires)  # fails under jax.jit
    # qml.BasisStatePreparation(n, wires)  # works
    return qml.state()

n = jnp.array([0, 1, 1])
print(qml.draw(circuit_BasisState)(n))
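For reference, the state that `BasisState(n, wires)` should prepare for `n = [0, 1, 1]` is the computational basis state |011⟩. The sketch below builds that state vector using only `jnp` operations, so it traces cleanly under `jax.jit`. `basis_state_vector` is a hypothetical helper, not PennyLane's implementation, and it assumes PennyLane's convention that the first wire is the most significant bit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def basis_state_vector(bits):
    """Hypothetical helper: |b_0 b_1 ... b_{n-1}> as an array of shape (2,) * n.

    Every step is a jnp operation, so `bits` may be a traced array.
    """
    n = bits.shape[0]  # shapes are static under jit, so n is a Python int
    # Big-endian index: wire 0 is the most significant bit, e.g. [0, 1, 1] -> 3.
    weights = 2 ** jnp.arange(n - 1, -1, -1)
    index = jnp.sum(bits.astype(jnp.int32) * weights)
    # One-hot vector built with a traced index (no Python-level branching).
    state = jnp.zeros(2**n, dtype=jnp.complex64).at[index].set(1.0)
    return state.reshape((2,) * n)


psi = basis_state_vector(jnp.array([0, 1, 1]))
```

Because the bit values only ever flow through array arithmetic and an indexed update, nothing asks Python for a concrete `bool`, which is exactly what goes wrong in the traceback below.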

Tracebacks

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[64], line 19
     16 n = jnp.array([0, 1, 1])
     18 #print(qml.draw(circuit_BasisEmbedding)(n))
---> 19 print(qml.draw(circuit_BasisState)(n))

File ~/Documents/pennylane/pennylane/drawer/draw.py:304, in draw.<locals>.wrapper(*args, **kwargs)
    302 @wraps(qnode)
    303 def wrapper(*args, **kwargs):
--> 304     tape = qml.tape.make_qscript(qnode)(*args, **kwargs)
    306     if wire_order:
    307         _wire_order = wire_order

File ~/Documents/pennylane/pennylane/tape/qscript.py:1298, in make_qscript.<locals>.wrapper(*args, **kwargs)
   1296 def wrapper(*args, **kwargs):
   1297     with AnnotatedQueue() as q:
-> 1298         fn(*args, **kwargs)
   1300     return QuantumScript.from_queue(q, shots)

    [... skipping hidden 12 frame]

File ~/Documents/pennylane/pennylane/workflow/qnode.py:1164, in QNode.__call__(self, *args, **kwargs)
   1162 if qml.capture.enabled():
   1163     return qml.capture.qnode_call(self, *args, **kwargs)
-> 1164 return self._impl_call(*args, **kwargs)

File ~/Documents/pennylane/pennylane/workflow/qnode.py:1150, in QNode._impl_call(self, *args, **kwargs)
   1147 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1149 try:
-> 1150     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1151 finally:
   1152     if old_interface == "auto":

File ~/Documents/pennylane/pennylane/workflow/qnode.py:1103, in QNode._execution_component(self, args, kwargs, override_shots)
   1100 _prune_dynamic_transform(full_transform_program, inner_transform_program)
   1102 # pylint: disable=unexpected-keyword-arg
-> 1103 res = qml.execute(
   1104     (self._tape,),
   1105     device=self.device,
   1106     gradient_fn=self.gradient_fn,
   1107     interface=self.interface,
   1108     transform_program=full_transform_program,
   1109     inner_transform=inner_transform_program,
   1110     config=config,
   1111     gradient_kwargs=self.gradient_kwargs,
   1112     override_shots=override_shots,
   1113     **self.execute_kwargs,
   1114 )
   1115 res = res[0]
   1117 # convert result to the interface in case the qfunc has no parameters

File ~/Documents/pennylane/pennylane/workflow/execution.py:666, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
    664 # Exiting early if we do not need to deal with an interface boundary
    665 if no_interface_boundary_required:
--> 666     results = inner_execute(tapes)
    667     return post_processing(results)
    669 _grad_on_execution = False

File ~/Documents/pennylane/pennylane/workflow/execution.py:316, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    313     transformed_tapes = tuple(expand_fn(t) for t in transformed_tapes)
    315 if transformed_tapes:
--> 316     results = device_execution(transformed_tapes)
    317 else:
    318     results = ()

File ~/Documents/pennylane/pennylane/devices/modifiers/simulator_tracking.py:30, in _track_execute.<locals>.execute(self, circuits, execution_config)
     28 @wraps(untracked_execute)
     29 def execute(self, circuits, execution_config=DefaultExecutionConfig):
---> 30     results = untracked_execute(self, circuits, execution_config)
     31     if isinstance(circuits, QuantumScript):
     32         batch = (circuits,)

File ~/Documents/pennylane/pennylane/devices/modifiers/single_tape_support.py:32, in _make_execute.<locals>.execute(self, circuits, execution_config)
     30     is_single_circuit = True
     31     circuits = (circuits,)
---> 32 results = batch_execute(self, circuits, execution_config)
     33 return results[0] if is_single_circuit else results

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/devices/default_qubit.py:597, in DefaultQubit.execute(self, circuits, execution_config)
    594 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    596 if max_workers is None:
--> 597     return tuple(
    598         _simulate_wrapper(
    599             c,
    600             {
    601                 "rng": self._rng,
    602                 "debugger": self._debugger,
    603                 "interface": interface,
    604                 "state_cache": self._state_cache,
    605                 "prng_key": _key,
    606                 "mcm_method": execution_config.mcm_config.mcm_method,
    607                 "postselect_mode": execution_config.mcm_config.postselect_mode,
    608             },
    609         )
    610         for c, _key in zip(circuits, prng_keys)
    611     )
    613 vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
    614 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/Documents/pennylane/pennylane/devices/default_qubit.py:598, in <genexpr>(.0)
    594 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    596 if max_workers is None:
    597     return tuple(
--> 598         _simulate_wrapper(
    599             c,
    600             {
    601                 "rng": self._rng,
    602                 "debugger": self._debugger,
    603                 "interface": interface,
    604                 "state_cache": self._state_cache,
    605                 "prng_key": _key,
    606                 "mcm_method": execution_config.mcm_config.mcm_method,
    607                 "postselect_mode": execution_config.mcm_config.postselect_mode,
    608             },
    609         )
    610         for c, _key in zip(circuits, prng_keys)
    611     )
    613 vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
    614 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/Documents/pennylane/pennylane/devices/default_qubit.py:863, in _simulate_wrapper(circuit, kwargs)
    862 def _simulate_wrapper(circuit, kwargs):
--> 863     return simulate(circuit, **kwargs)

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/devices/qubit/simulate.py:354, in simulate(circuit, debugger, state_cache, **execution_kwargs)
    351     return tuple(results)
    353 ops_key, meas_key = jax_random_split(prng_key)
--> 354 state, is_state_batched = get_final_state(
    355     circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs
    356 )
    357 if state_cache is not None:
    358     state_cache[circuit.hash] = state

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/devices/qubit/simulate.py:165, in get_final_state(circuit, debugger, **execution_kwargs)
    162 if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase):
    163     prep = circuit[0]
--> 165 state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface])
    167 # initial state is batched only if the state preparation (if it exists) is batched
    168 is_state_batched = bool(prep and prep.batch_size is not None)

File ~/Documents/pennylane/pennylane/devices/qubit/initialize_state.py:46, in create_initial_state(wires, prep_operation, like)
     43     state[(0,) * num_wires] = 1
     44     return qml.math.asarray(state, like=like)
---> 46 return qml.math.asarray(prep_operation.state_vector(wire_order=list(wires)), like=like)

File ~/Documents/pennylane/pennylane/ops/qubit/state_preparation.py:104, in BasisState.state_vector(self, wire_order)
    102 """Returns a statevector of shape ``(2,) * num_wires``."""
    103 prep_vals = self.parameters[0]
--> 104 if any(i not in [0, 1] for i in prep_vals):
    105     raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
    107 if (num_wires := len(self.wires)) != len(prep_vals):

File ~/Documents/pennylane/pennylane/ops/qubit/state_preparation.py:104, in <genexpr>(.0)
    102 """Returns a statevector of shape ``(2,) * num_wires``."""
    103 prep_vals = self.parameters[0]
--> 104 if any(i not in [0, 1] for i in prep_vals):
    105     raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
    107 if (num_wires := len(self.wires)) != len(prep_vals):

    [... skipping hidden 1 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/core.py:1510, in concretization_function_error.<locals>.error(self, arg)
   1509 def error(self, arg):
-> 1510   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function circuit_BasisState at /var/folders/cn/h46l05vn2qd9c7ldxf0g905c0000gq/T/ipykernel_53736/2286483519.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
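The last PennyLane frame points at the check `any(i not in [0, 1] for i in prep_vals)` in `BasisState.state_vector`. The `in` operator compares each traced element with `0` and `1` and then asks Python for a plain `bool`, which JAX cannot supply while tracing. The sketch below reproduces this without PennyLane (the function names are illustrative only) and shows a trace-safe variant that keeps the check inside `jnp`, so the result stays a traced boolean array.

```python
import jax
import jax.numpy as jnp


def eager_check(bits):
    # Mirrors the Python-level validation in the traceback: `i not in [0, 1]`
    # calls bool() on `i == 0`, which fails if `i` is a tracer.
    if any(i not in [0, 1] for i in bits):
        raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
    return bits


@jax.jit
def jitted_eager_check(bits):
    # Raises TracerBoolConversionError (a ConcretizationTypeError) at trace time.
    return eager_check(bits)


@jax.jit
def traceable_is_binary(bits):
    # Trace-safe alternative: element-wise comparisons reduced with jnp.all,
    # returning a boolean array instead of forcing a Python bool.
    return jnp.all((bits == 0) | (bits == 1))
```

Outside `jit`, `eager_check` works because each element is concrete; inside `jit` only the `jnp` version survives tracing. Note that a traced boolean cannot drive a Python `raise`, so a check like this can only report a result, not stop execution.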

System information

Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.5-arm64-arm-64bit
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.38.0.dev0)
- default.gaussian (PennyLane-0.38.0.dev0)
- default.mixed (PennyLane-0.38.0.dev0)
- default.qubit (PennyLane-0.38.0.dev0)
- default.qubit.autograd (PennyLane-0.38.0.dev0)
- default.qubit.jax (PennyLane-0.38.0.dev0)
- default.qubit.legacy (PennyLane-0.38.0.dev0)
- default.qubit.tf (PennyLane-0.38.0.dev0)
- default.qubit.torch (PennyLane-0.38.0.dev0)
- default.qutrit (PennyLane-0.38.0.dev0)
- default.qutrit.mixed (PennyLane-0.38.0.dev0)
- default.tensor (PennyLane-0.38.0.dev0)
- null.qubit (PennyLane-0.38.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- oqc.cloud (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
isaacdevlugt added the bug 🐛 Something isn't working label Jul 17, 2024
KetpuntoG added a commit that referenced this issue Aug 21, 2024
This PR completes part of this story:
[[sc-68521](https://app.shortcut.com/xanaduai/story/68521)]

Goal: make `BasisEmbedding` an alias of `BasisState`, so that we no longer
have duplicate code doing the same thing.
In unifying the two, I had to modify some tests because:
- `BasisEmbedding` and `BasisState` raised errors (for example, for an
incorrect length) with different messages. Now the messages are always the
same. (Tests modified for this reason: `test_default_qubit_legacy.py`,
`test_default_qubit_tf.py`, `test_default_qubit_torch.py`,
`test_state_prep.py`, `test_all_singles_doubles.py` and `test_uccsd`.)

- `BasisEmbedding` raised its errors in `__init__`, while `BasisState`
raised them in `state_vector`. Now both raise them in `__init__`. As a
result, some tests constructed the operator with invalid arguments but never
hit an error, because `state_vector` was never called; these cases are now
detected. To correct this, I modified the tests `test_qscript.py` and
`test_state_prep.py`.

- `BasisState` no longer decomposes into `BasisStatePreparation`, since the
latter is going to be deprecated. This changes the number of gates after
expansion, so I modified some tests in `test_tape.py`.

This PR also solves:

- [issue 6008](#6008)
- [issue 6007](#6007)
- [issue 6006](#6006)

---------

Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com>
Co-authored-by: Utkarsh <utkarshazad98@gmail.com>
@DSGuala
Contributor

DSGuala commented Sep 6, 2024

Closing this issue as it was resolved in #6021

DSGuala closed this as completed Sep 6, 2024
mudit2812 pushed a commit that referenced this issue Sep 10, 2024