Optimize PhaseShift, T, S gates (#5876)

### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
      If you've fixed a bug or added code that should be tested, add a test to
      the test directory!

- [x] All new functions and code must be clearly commented and documented.
      If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file, summarizing
      the change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
      We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
`PauliZ` has a fast `apply_operation` implementation in `DefaultQubit` that exploits the sparsity of the operator. Several other operations share the same non-zero structure (they are diagonal in the computational basis) and can be accelerated in the same way. One candidate is `PhaseShift`, which is used heavily in `iterative_qpe`. A minimal sketch of the underlying slicing trick is shown below.
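
For intuition, here is a minimal NumPy sketch of the slicing trick (an illustration, not the PennyLane code): a diagonal gate such as `PhaseShift(phi)` only rescales the amplitudes where the target wire is in state `|1>`, so it can be applied without a full matrix contraction.
```python
import numpy as np

def apply_diagonal_phase(state, wire, phi):
    """Apply diag(1, exp(1j * phi)) to ``wire`` of a state shaped (2,) * n_wires.

    Sketch of the fast path: instead of contracting the full 2x2 matrix into
    the state, only the slice where the target wire is |1> is rescaled; the
    |0> slice is left untouched.
    """
    state = state.astype(complex)              # astype copies, so the input is not modified
    sl_1 = (slice(None),) * wire + (1,)        # select index 1 on the target axis
    state[sl_1] *= np.exp(1j * phi)
    return state

# 3-qubit example: PhaseShift(pi / 2) on wire 1.
state = np.ones((2, 2, 2)) / np.sqrt(8)
new_state = apply_diagonal_phase(state, wire=1, phi=np.pi / 2)
```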

**Description of the Change:**
Port the fast `apply_operation` implementation of `PauliZ` to `PhaseShift`, `S` and `T`.

**Benefits:**
Faster execution. For example, the simple circuit
```python
import pennylane as qml

nwires = 24
dev = qml.device("default.qubit", shots=1)

@qml.qnode(dev)
def circuit(iters):
    for i in range(iters):
        qml.PhaseShift(0.1234, i % nwires)
    return qml.sample(wires=[0])

circuit(100)
```
takes 0m13.178s on `master` and 0m9.146s on `optim_apply_operations`. We
observe the same speed-up for `S` and `T`.
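
The `0mXX.XXXs` figures are wall-clock times in the format of the shell's `time` builtin; a self-contained in-process alternative (the `perf_counter` harness below is an illustration, not part of the PR) would be:
```python
import time

import pennylane as qml

nwires = 24
dev = qml.device("default.qubit", shots=1)

@qml.qnode(dev)
def circuit(iters):
    for i in range(iters):
        qml.PhaseShift(0.1234, i % nwires)
    return qml.sample(wires=[0])

start = time.perf_counter()    # wall-clock time of a single run, comparable across branches
circuit(100)
print(f"circuit(100) took {time.perf_counter() - start:.3f} s")
```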

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-67827]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
vincentmr and mudit2812 committed Jul 9, 2024
1 parent 2947317 commit e63362d
Showing 3 changed files with 80 additions and 6 deletions.
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
@@ -12,6 +12,9 @@

<h3>Improvements 🛠</h3>

* Port the fast `apply_operation` implementation of `PauliZ` to `PhaseShift`, `S` and `T`.
[(#5876)](https://github.com/PennyLaneAI/pennylane/pull/5876)

* `qml.UCCSD` now accepts an additional optional argument, `n_repeats`, which defines the number of
times the UCCSD template is repeated. This can improve the accuracy of the template by reducing
the Trotter error but would result in deeper circuits.
@@ -31,8 +34,8 @@
<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Yushao Chen,
Christina Lee,
William Maxwell,
Vincent Michaud-Rioux,
Erik Schultheis.
69 changes: 69 additions & 0 deletions pennylane/devices/qubit/apply_operation.py
@@ -407,6 +407,75 @@ def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None
return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply PhaseShift to state."""

n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow":
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

axis = op.wires[0] + is_state_batched

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

params = math.cast(op.parameters[0], dtype=complex)
state0 = state[sl_0]
state1 = state[sl_1]
if op.batch_size is not None and len(params) > 1:
interface = math.get_interface(state)
if interface == "torch":
params = math.array(params, like=interface)
if is_state_batched:
params = math.reshape(params, (-1,) + (1,) * (n_dim - 2))
else:
axis = axis + 1
params = math.reshape(params, (-1,) + (1,) * (n_dim - 1))
state0 = math.expand_dims(state0, 0) + math.zeros_like(params)
state1 = math.expand_dims(state1, 0)
state1 = math.multiply(math.cast(state1, dtype=complex), math.exp(1.0j * params))
state = math.stack([state0, state1], axis=axis)
if not is_state_batched and op.batch_size == 1:
state = math.stack([state], axis=0)
return state


@apply_operation.register
def apply_T(op: qml.T, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply T to state."""

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow":
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

state1 = math.multiply(math.cast(state[sl_1], dtype=complex), math.exp(0.25j * np.pi))
return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_S(op: qml.S, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply S to state."""

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow":
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

state1 = math.multiply(math.cast(state[sl_1], dtype=complex), 1j)
return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_cnot(op: qml.CNOT, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply cnot gate to state."""
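The new kernels are registered on the `apply_operation` single-dispatch, so existing callers pick them up automatically. A quick consistency check one could run (a sketch; it assumes the public `pennylane.devices.qubit.apply_operation` entry point and is not part of this PR) compares the sliced kernels against the dense matrix product:
```python
import numpy as np
import pennylane as qml
from pennylane.devices.qubit import apply_operation

# Random normalized 3-qubit state, shaped (2, 2, 2) as the qubit device expects.
state = np.random.default_rng(42).normal(size=(2, 2, 2)) + 0j
state /= np.linalg.norm(state)

for op in (qml.PhaseShift(0.1234, wires=1), qml.S(1), qml.T(1)):
    fast = apply_operation(op, state)                          # dispatches to the new sliced kernels
    dense = np.einsum("ba,iaj->ibj", qml.matrix(op), state)    # reference: dense 2x2 matrix on wire 1
    assert np.allclose(fast, dense)
```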
12 changes: 7 additions & 5 deletions tests/devices/qubit/test_apply_operation.py
@@ -778,7 +778,7 @@ class TestBroadcasting:  # pylint: disable=too-few-public-methods
@pytest.mark.parametrize("op", broadcasted_ops)
def test_broadcasted_op(self, op, method, ml_framework):
"""Tests that batched operations are applied correctly to an unbatched state."""
state = np.ones((2, 2, 2)) / np.sqrt(8)
state = np.ones((2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework))
missing_wires = 3 - len(op.wires)
@@ -796,7 +796,7 @@ def test_broadcasted_op(self, op, method, ml_framework):
@pytest.mark.parametrize("op", unbroadcasted_ops)
def test_broadcasted_state(self, op, method, ml_framework):
"""Tests that unbatched operations are applied correctly to a batched state."""
state = np.ones((3, 2, 2, 2)) / np.sqrt(8)
state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True)
missing_wires = 3 - len(op.wires)
@@ -813,7 +813,7 @@ def test_broadcasted_op_broadcasted_state(self, op, method, ml_framework):
if method is apply_operation_tensordot:
pytest.skip("Tensordot doesn't support batched operator and batched state.")

state = np.ones((3, 2, 2, 2)) / np.sqrt(8)
state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True)
missing_wires = 3 - len(op.wires)
@@ -1226,13 +1226,15 @@ def test_with_torch(self, batch_dim):
class TestLargeTFCornerCases:
"""Test large corner cases for tensorflow."""

@pytest.mark.parametrize("op", (qml.PauliZ(8), qml.CNOT((5, 6))))
@pytest.mark.parametrize(
"op", (qml.PauliZ(8), qml.PhaseShift(1.0, 8), qml.S(8), qml.T(8), qml.CNOT((5, 6)))
)
def test_tf_large_state(self, op):
"""Tests that custom kernels that use slicing fall back to a different method when
the state has a large number of wires."""
import tensorflow as tf

state = np.zeros([2] * 10)
state = np.zeros([2] * 10, dtype=complex)
state = tf.Variable(state)
new_state = apply_operation(op, state)

