Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow tensor batching in numerical representations of Operators #2535

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@

<h3>New features since last release</h3>

* Many parametrized operations now allow arguments with a batch dimension
[(#2535)](https://github.com/PennyLaneAI/pennylane/pull/2535)

This feature is not usable as a stand-alone but a technical requirement
for future performance improvements.
Previously unsupported batched parameters are allowed for example in
standard rotation gates. The batch dimension is the last dimension
of operator matrices, eigenvalues etc. Note that the batched parameter
has to be passed as an `array` but not as a python `list` or `tuple`.

```pycon
>>> op = qml.RX(np.array([0.1, 0.2, 0.3], requires_grad=True), 0)
>>> np.round(op.matrix(), 4)
tensor([[[0.9988+0.j , 0.995 +0.j , 0.9888+0.j ],
[0. -0.05j , 0. -0.0998j, 0. -0.1494j]],

[[0. -0.05j , 0. -0.0998j, 0. -0.1494j],
[0.9988+0.j , 0.995 +0.j , 0.9888+0.j ]]], requires_grad=True)
>>> op.matrix().shape
(2, 2, 3)

* Boolean mask indexing of the parameter-shift Hessian
[(#2538)](https://github.com/PennyLaneAI/pennylane/pull/2538)

Expand Down
35 changes: 27 additions & 8 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,29 +190,39 @@ def expand_matrix(base_matrix, wires, wire_order):
# TODO[Maria]: In future we should consider making ``utils.expand`` differentiable and calling it here.
wire_order = Wires(wire_order)
n = len(wires)
interface = qml.math._multi_dispatch(base_matrix) # pylint: disable=protected-access
shape = qml.math.shape(base_matrix)
batch_dim = shape[-1] if len(shape) == 3 else None
interface = qml.math.get_interface(base_matrix) # pylint: disable=protected-access

# operator's wire positions relative to wire ordering
op_wire_pos = wire_order.indices(wires)

identity = qml.math.reshape(
qml.math.eye(2 ** len(wire_order), like=interface), [2] * len(wire_order) * 2
qml.math.eye(2 ** len(wire_order), like=interface), [2] * (len(wire_order) * 2)
)
axes = (list(range(n, 2 * n)), op_wire_pos)

# reshape op.matrix()
op_matrix_interface = qml.math.convert_like(base_matrix, identity)
mat_op_reshaped = qml.math.reshape(op_matrix_interface, [2] * n * 2)
shape = [2] * (n * 2) + [batch_dim] if batch_dim else [2] * (n * 2)
mat_op_reshaped = qml.math.reshape(op_matrix_interface, shape)
mat_tensordot = qml.math.tensordot(
mat_op_reshaped, qml.math.cast_like(identity, mat_op_reshaped), axes
)
if batch_dim:
mat_tensordot = qml.math.moveaxis(mat_tensordot, n, -1)

unused_idxs = [idx for idx in range(len(wire_order)) if idx not in op_wire_pos]
# permute matrix axes to match wire ordering
perm = op_wire_pos + unused_idxs
mat = qml.math.moveaxis(mat_tensordot, wire_order.indices(wire_order), perm)
sources = wire_order.indices(wire_order)
if batch_dim:
perm = perm + [-1]
sources = sources + [-1]

mat = qml.math.reshape(mat, (2 ** len(wire_order), 2 ** len(wire_order)))
mat = qml.math.moveaxis(mat_tensordot, sources, perm)
shape = [2 ** len(wire_order)] * 2 + [batch_dim] if batch_dim else [2 ** len(wire_order)] * 2
mat = qml.math.reshape(mat, shape)

return mat

Expand Down Expand Up @@ -688,7 +698,14 @@ def eigvals(self):
# By default, compute the eigenvalues from the matrix representation.
# This will raise a NotImplementedError if the matrix is undefined.
try:
return qml.math.linalg.eigvals(self.matrix())
mat = self.matrix()
if len(qml.math.shape(mat)) == 3:
# linalg.eigvals expects the last two dimensions to be the square dimension
# so that we have to transpose before and after the calculation.
return qml.math.transpose(
qml.math.linalg.eigvals(qml.math.transpose(mat, (2, 0, 1))), (1, 0)
)
return qml.math.linalg.eigvals(mat)
except MatrixUndefinedError as e:
raise EigvalsUndefinedError from e

Expand Down Expand Up @@ -804,7 +821,9 @@ def label(self, decimals=None, base_label=None, cache=None):

if len(qml.math.shape(params[0])) != 0:
# assume that if the first parameter is matrix-valued, there is only a single parameter
# this holds true for all current operations and templates
# this holds true for all current operations and templates unless tensor-batching
# is used
# TODO[dwierichs]: Implement a proper label for tensor-batched operators
if (
cache is None
or not isinstance(cache.get("matrices", None), list)
Expand Down Expand Up @@ -1404,7 +1423,7 @@ def matrix(self, wire_order=None):
canonical_matrix = self.compute_matrix(*self.parameters, **self.hyperparameters)

if self.inverse:
canonical_matrix = qml.math.conj(qml.math.T(canonical_matrix))
canonical_matrix = qml.math.conj(qml.math.moveaxis(canonical_matrix, 0, 1))

if wire_order is None or self.wires == Wires(wire_order):
return canonical_matrix
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/functions/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _matrix(tape, wire_order=None):

for op in tape.operations:
U = matrix(op, wire_order=wire_order)
unitary_matrix = qml.math.dot(U, unitary_matrix)
unitary_matrix = qml.math.tensordot(U, unitary_matrix, axes=[[1], [0]])
unitary_matrix = qml.math.moveaxis(unitary_matrix, 1, -1)

return unitary_matrix
25 changes: 25 additions & 0 deletions pennylane/ops/qubit/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,28 @@ def __contains__(self, obj):
representation using ``np.linalg.eigvals``, which fails for some tensor types that the matrix
may be cast in on backpropagation devices.
"""

supports_tensorbatching = Attribute(
[
"QubitUnitary",
"DiagonalQubitUnitary",
"RX",
"RY",
"RZ",
"PhaseShift",
"ControlledPhaseShift",
"Rot",
"MultiRZ",
"PauliRot",
"CRX",
"CRY",
"CRZ",
"CRot",
"U1",
"U2",
"U3",
"IsingXX",
"IsingYY",
"IsingZZ",
]
)
50 changes: 38 additions & 12 deletions pennylane/ops/qubit/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,27 @@ def __init__(self, *params, wires, do_queue=True):
# of wires fits the dimensions of the matrix
if not isinstance(self, ControlledQubitUnitary):
U = params[0]
U_shape = qml.math.shape(U)

dim = 2 ** len(wires)

if qml.math.shape(U) != (dim, dim):
if not (len(U_shape) in {2, 3} and U_shape[:2] == (dim, dim)):
raise ValueError(
f"Input unitary must be of shape {(dim, dim)} to act on {len(wires)} wires."
f"Input unitary must be of shape {(dim, dim)} or ({dim, dim}, batch_dim) "
f"to act on {len(wires)} wires."
)

# Check for unitarity; due to variable precision across the different ML frameworks,
# here we issue a warning to check the operation, instead of raising an error outright.
if not qml.math.is_abstract(U) and not qml.math.allclose(
qml.math.dot(U, qml.math.T(qml.math.conj(U))),
qml.math.eye(qml.math.shape(U)[0]),
atol=1e-6,
# TODO[dwierichs]: Implement unitarity check also for tensor-batched arguments U
if not (
qml.math.is_abstract(U)
or len(U_shape) == 3
or qml.math.allclose(
qml.math.dot(U, qml.math.T(qml.math.conj(U))),
qml.math.eye(dim),
atol=1e-6,
)
):
warnings.warn(
f"Operator {U}\n may not be unitary."
Expand Down Expand Up @@ -142,16 +149,25 @@ def compute_decomposition(U, wires):
"""
# Decomposes arbitrary single-qubit unitaries as Rot gates (RZ - RY - RZ format),
# or a single RZ for diagonal matrices.
if qml.math.shape(U) == (2, 2):
shape = qml.math.shape(U)
if shape == (2, 2):
return qml.transforms.decompositions.zyz_decomposition(U, Wires(wires)[0])

if qml.math.shape(U) == (4, 4):
if shape == (4, 4):
return qml.transforms.two_qubit_decomposition(U, Wires(wires))

# TODO[dwierichs]: Implement decomposition of tensor-batched unitary
if len(shape) == 3:
raise DecompositionUndefinedError(
"The decomposition of QubitUnitary does not support tensor-batching."
)

return super(QubitUnitary, QubitUnitary).compute_decomposition(U, wires=wires)

def adjoint(self):
return QubitUnitary(qml.math.T(qml.math.conj(self.matrix())), wires=self.wires)
U = self.matrix()
axis = (1, 0) if len(qml.math.shape(U)) == 2 else (1, 0, 2)
return QubitUnitary(qml.math.transpose(qml.math.conj(U), axis), wires=self.wires)

def pow(self, z):
if isinstance(z, int):
Expand Down Expand Up @@ -237,6 +253,10 @@ def __init__(
"The control wires must be different from the wires specified to apply the unitary on."
)

# TODO[dwierichs]: Implement tensor-batching
if len(qml.math.shape(params[0])) == 3:
raise NotImplementedError("ControlledQubitUnitary does not support tensor-batching.")

self._hyperparameters = {
"u_wires": wires,
"control_wires": control_wires,
Expand Down Expand Up @@ -389,6 +409,11 @@ def compute_matrix(D): # pylint: disable=arguments-differ
if not qml.math.allclose(D * qml.math.conj(D), qml.math.ones_like(D)):
raise ValueError("Operator must be unitary.")

if len(qml.math.shape(D)) == 2:
return qml.math.transpose(
qml.math.stack([qml.math.diag(_D) for _D in qml.math.T(D)]), (1, 2, 0)
)

return qml.math.diag(D)

@staticmethod
Expand Down Expand Up @@ -419,8 +444,9 @@ def compute_eigvals(D): # pylint: disable=arguments-differ
"""
D = qml.math.asarray(D)

if not qml.math.is_abstract(D) and not qml.math.allclose(
D * qml.math.conj(D), qml.math.ones_like(D)
if not (
qml.math.is_abstract(D)
or qml.math.allclose(D * qml.math.conj(D), qml.math.ones_like(D))
):
raise ValueError("Operator must be unitary.")

Expand Down Expand Up @@ -450,7 +476,7 @@ def compute_decomposition(D, wires):
[QubitUnitary(array([[1, 0], [0, 1]]), wires=[0])]

"""
return [QubitUnitary(qml.math.diag(D), wires=wires)]
return [QubitUnitary(DiagonalQubitUnitary.compute_matrix(D), wires=wires)]

def adjoint(self):
return DiagonalQubitUnitary(qml.math.conj(self.parameters[0]), wires=self.wires)
Expand Down
Loading