Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update handling statevector data types in Python #290

Merged
merged 16 commits into from
May 10, 2022
Merged
8 changes: 8 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@

### Improvements

* Device `lightning.qubit` now accepts a datatype for a statevector.
[(#290)](https://github.com/PennyLaneAI/pennylane-lightning/pull/290)

```python
dev1 = qml.device('lightning.qubit', wires=4, c_dtype=np.complex64) # for single precision
dev2 = qml.device('lightning.qubit', wires=4, c_dtype=np.complex128) # for double precision
```

* Split matrix operations, refactor dispatch mechanisms, and add a benchmark suite.
[(#274)](https://github.com/PennyLaneAI/pennylane-lightning/pull/274)

Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.24.0-dev6"
__version__ = "0.24.0-dev7"
153 changes: 45 additions & 108 deletions pennylane_lightning/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class LightningQubit(DefaultQubit):
the expectation values. Defaults to ``None`` if not specified. Setting
to ``None`` results in computing statistics like expectation values and
variances analytically.
c_dtype: Complex datatype of the statevector representation. Must be one of ``np.complex64`` or ``np.complex128``.
"""

name = "Lightning Qubit PennyLane plugin"
Expand All @@ -110,8 +111,16 @@ class LightningQubit(DefaultQubit):
_CPP_BINARY_AVAILABLE = True
operations = _remove_snapshot_from_operations(DefaultQubit.operations)

def __init__(self, wires, *, shots=None, batch_obs=False):
super().__init__(wires, shots=shots)
def __init__(self, wires, *, c_dtype=np.complex128, shots=None, batch_obs=False):
    """Initialize the Lightning qubit device.

    Args:
        wires: number of subsystems (or an iterable of wire labels).
        c_dtype: complex datatype for the statevector representation. Must be
            (equivalent to) ``np.complex64`` or ``np.complex128``.
        shots: number of circuit evaluations used to estimate statistics;
            ``None`` means analytic computation.
        batch_obs: whether to batch observables for adjoint differentiation.

    Raises:
        TypeError: if ``c_dtype`` is not a supported complex datatype.
    """
    # Normalize dtype-like inputs (e.g. "complex64", np.dtype("complex64"))
    # instead of relying on an identity check against the scalar type, so
    # equivalent dtype spellings are accepted.
    try:
        dtype = np.dtype(c_dtype)
    except TypeError:
        dtype = None
    if dtype == np.dtype(np.complex64):
        r_dtype = np.float32
        # Tracks single-precision mode for the C++ bindings (C64 vs C128).
        self.use_csingle = True
    elif dtype == np.dtype(np.complex128):
        r_dtype = np.float64
        self.use_csingle = False
    else:
        raise TypeError(f"Unsupported complex Type: {c_dtype}")
    super().__init__(wires, r_dtype=r_dtype, c_dtype=c_dtype, shots=shots)
    self._batch_obs = batch_obs

@classmethod
Expand Down Expand Up @@ -144,26 +153,20 @@ def apply(self, operations, rotations=None, **kwargs):
"applied on a {} device.".format(operation.name, self.short_name)
)

# Get the Type of self._state
# as the reference type
dtype = self._state.dtype

if operations:
self._pre_rotated_state = self.apply_lightning(self._state, operations, dtype=dtype)
self._pre_rotated_state = self.apply_lightning(self._state, operations)
else:
self._pre_rotated_state = self._state

if rotations:
if any(isinstance(r, QubitUnitary) for r in rotations):
super().apply(operations=[], rotations=rotations)
else:
self._state = self.apply_lightning(
np.copy(self._pre_rotated_state), rotations, dtype=dtype
)
self._state = self.apply_lightning(np.copy(self._pre_rotated_state), rotations)
else:
self._state = self._pre_rotated_state

def apply_lightning(self, state, operations, dtype=np.complex128):
def apply_lightning(self, state, operations):
"""Apply a list of operations to the state tensor.

Args:
Expand All @@ -177,14 +180,12 @@ def apply_lightning(self, state, operations, dtype=np.complex128):
"""
state_vector = np.ravel(state)

if dtype == np.complex64:
if self.use_csingle:
# use_csingle
sim = StateVectorC64(state_vector)
elif dtype == np.complex128:
else:
# self.C_DTYPE is np.complex128 by default
sim = StateVectorC128(state_vector)
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

# Skip over identity operations instead of performing
# matrix multiplication with the identity.
Expand Down Expand Up @@ -264,15 +265,6 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
UserWarning,
)

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
use_csingle = True
elif dtype == np.complex128:
use_csingle = False
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

if len(tape.trainable_params) == 0:
return np.array(0)

Expand All @@ -288,14 +280,13 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
self.execute(tape)
ket = np.ravel(self._pre_rotated_state)

if use_csingle:
if self.use_csingle:
adj = AdjointJacobianC64()
ket = ket.astype(np.complex64)
else:
adj = AdjointJacobianC128()

obs_serialized = _serialize_obs(tape, self.wire_map, use_csingle=use_csingle)
ops_serialized, use_sp = _serialize_ops(tape, self.wire_map, use_csingle=use_csingle)
obs_serialized = _serialize_obs(tape, self.wire_map, use_csingle=self.use_csingle)
ops_serialized, use_sp = _serialize_ops(tape, self.wire_map, use_csingle=self.use_csingle)

ops_serialized = adj.create_ops_list(*ops_serialized)

Expand All @@ -306,7 +297,7 @@ def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
trainable_params if not use_sp else [i - 1 for i in trainable_params[first_elem:]]
) # exclude first index if explicitly setting sv

state_vector = StateVectorC64(ket) if use_csingle else StateVectorC128(ket)
state_vector = StateVectorC64(ket) if self.use_csingle else StateVectorC128(ket)

# If requested batching over observables, chunk into OMP_NUM_THREADS sized chunks.
# This will allow use of Lightning with adjoint for large-qubit numbers AND large
Expand Down Expand Up @@ -368,14 +359,10 @@ def compute_vjp(self, dy, jac, num=None):
if math.allclose(dy, 0):
return math.convert_like(np.zeros([num_params]), dy)

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
if self.use_csingle:
VJP = VectorJacobianProductC64()
elif dtype == np.complex128:
VJP = VectorJacobianProductC128()
else:
raise TypeError(f"Unsupported complex Type: {dtype}")
VJP = VectorJacobianProductC128()

vjp_tensor = VJP.compute_vjp_from_jac(
math.reshape(jac, [-1]),
Expand Down Expand Up @@ -416,16 +403,7 @@ def vjp(self, tape, dy, starting_state=None, use_device_state=False):
if math.allclose(dy, 0):
return lambda _: math.convert_like(np.zeros([num_params]), dy)

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
use_csingle = True
elif dtype == np.complex128:
use_csingle = False
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

V = VectorJacobianProductC64() if use_csingle else VectorJacobianProductC128()
V = VectorJacobianProductC64() if self.use_csingle else VectorJacobianProductC128()

fn = V.vjp_fn(math.reshape(dy, [-1]), tape.num_params)

Expand All @@ -442,11 +420,10 @@ def processing_fn(tape):
self.execute(tape)
ket = np.ravel(self._pre_rotated_state)

if use_csingle:
ket = ket.astype(np.complex64)

obs_serialized = _serialize_obs(tape, self.wire_map, use_csingle=use_csingle)
ops_serialized, use_sp = _serialize_ops(tape, self.wire_map, use_csingle=use_csingle)
obs_serialized = _serialize_obs(tape, self.wire_map, use_csingle=self.use_csingle)
ops_serialized, use_sp = _serialize_ops(
tape, self.wire_map, use_csingle=self.use_csingle
)

ops_serialized = V.create_ops_list(*ops_serialized)

Expand All @@ -457,7 +434,7 @@ def processing_fn(tape):
trainable_params if not use_sp else [i - 1 for i in trainable_params[first_elem:]]
) # exclude first index if explicitly setting sv

state_vector = StateVectorC64(ket) if use_csingle else StateVectorC128(ket)
state_vector = StateVectorC64(ket) if self.use_csingle else StateVectorC128(ket)

return fn(state_vector, obs_serialized, ops_serialized, tp_shift)

Expand Down Expand Up @@ -543,21 +520,10 @@ def probability(self, wires=None, shot_range=None, bin_size=None):

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
use_csingle = True
elif dtype == np.complex128:
use_csingle = False
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

# Initialization of state
ket = np.ravel(self._state)

if use_csingle:
ket = ket.astype(np.complex64)

state_vector = StateVectorC64(ket) if use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if use_csingle else MeasuresC128(state_vector)
state_vector = StateVectorC64(ket) if self.use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if self.use_csingle else MeasuresC128(state_vector)

return M.probs(device_wires)

Expand All @@ -568,23 +534,11 @@ def generate_samples(self):
array[int]: array of samples in binary representation with shape ``(dev.shots, dev.num_wires)``
"""

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
use_csingle = True
elif dtype == np.complex128:
use_csingle = False
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

# Initialization of state
ket = np.ravel(self._state)

if use_csingle:
ket = ket.astype(np.complex64)

state_vector = StateVectorC64(ket) if use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if use_csingle else MeasuresC128(state_vector)
state_vector = StateVectorC64(ket) if self.use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if self.use_csingle else MeasuresC128(state_vector)

return M.generate_samples(len(self.wires), self.shots).astype(int)

Expand Down Expand Up @@ -617,23 +571,11 @@ def expval(self, observable, shot_range=None, bin_size=None):
samples = self.sample(observable, shot_range=shot_range, bin_size=bin_size)
return np.squeeze(np.mean(samples, axis=0))

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
use_csingle = True
elif dtype == np.complex128:
use_csingle = False
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

# Initialization of state
ket = np.ravel(self._pre_rotated_state)

if use_csingle:
ket = ket.astype(np.complex64)

state_vector = StateVectorC64(ket) if use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if use_csingle else MeasuresC128(state_vector)
state_vector = StateVectorC64(ket) if self.use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if self.use_csingle else MeasuresC128(state_vector)

# translate to wire labels used by device
observable_wires = self.map_wires(observable.wires)
Expand Down Expand Up @@ -667,23 +609,11 @@ def var(self, observable, shot_range=None, bin_size=None):
samples = self.sample(observable, shot_range=shot_range, bin_size=bin_size)
return np.squeeze(np.var(samples, axis=0))

# To support np.complex64 based on the type of self._state
dtype = self._state.dtype
if dtype == np.complex64:
use_csingle = True
elif dtype == np.complex128:
use_csingle = False
else:
raise TypeError(f"Unsupported complex Type: {dtype}")

# Initialization of state
ket = np.ravel(self._pre_rotated_state)

if use_csingle:
ket = ket.astype(np.complex64)

state_vector = StateVectorC64(ket) if use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if use_csingle else MeasuresC128(state_vector)
state_vector = StateVectorC64(ket) if self.use_csingle else StateVectorC128(ket)
M = MeasuresC64(state_vector) if self.use_csingle else MeasuresC128(state_vector)

# translate to wire labels used by device
observable_wires = self.map_wires(observable.wires)
Expand All @@ -702,12 +632,19 @@ class LightningQubit(DefaultQubit): # pragma: no cover
_CPP_BINARY_AVAILABLE = False
operations = _remove_snapshot_from_operations(DefaultQubit.operations)

def __init__(self, *args, **kwargs):
def __init__(self, wires, *, c_dtype=np.complex128, **kwargs):
    """Initialize the Python-fallback Lightning qubit device.

    Emits a warning that pre-compiled binaries are unavailable, validates the
    requested statevector precision, and delegates to ``DefaultQubit``.

    Args:
        wires: number of subsystems (or an iterable of wire labels).
        c_dtype: complex datatype for the statevector representation. Must be
            (equivalent to) ``np.complex64`` or ``np.complex128``.
        **kwargs: forwarded to ``DefaultQubit`` (e.g. ``shots``).

    Raises:
        TypeError: if ``c_dtype`` is not a supported complex datatype.
    """
    warn(
        "Pre-compiled binaries for lightning.qubit are not available. Falling back to "
        "using the Python-based default.qubit implementation. To manually compile from "
        "source, follow the instructions at "
        "https://pennylane-lightning.readthedocs.io/en/latest/installation.html.",
        UserWarning,
    )

    # Normalize dtype-like inputs (e.g. "complex64", np.dtype("complex64"))
    # instead of relying on an identity check against the scalar type, so
    # equivalent dtype spellings are accepted.
    try:
        dtype = np.dtype(c_dtype)
    except TypeError:
        dtype = None
    if dtype == np.dtype(np.complex64):
        r_dtype = np.float32
    elif dtype == np.dtype(np.complex128):
        r_dtype = np.float64
    else:
        raise TypeError(f"Unsupported complex Type: {c_dtype}")
    super().__init__(wires, r_dtype=r_dtype, c_dtype=c_dtype, **kwargs)
18 changes: 9 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ def n_subsystems(request):
return request.param


@pytest.fixture(scope="function")
def qubit_device_1_wire():
return LightningQubit(wires=1)
@pytest.fixture(scope="function", params=[np.complex64, np.complex128])
def qubit_device_1_wire(request):
    # Single-wire device, parametrized over both supported statevector
    # precisions (single and double) via the device's c_dtype argument.
    return LightningQubit(wires=1, c_dtype=request.param)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of passing a datatype directly to the statevector, we now set the datatype on the device. Test suites are updated accordingly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job on adding this change



@pytest.fixture(scope="function")
def qubit_device_2_wires():
return LightningQubit(wires=2)
@pytest.fixture(scope="function", params=[np.complex64, np.complex128])
def qubit_device_2_wires(request):
    # Two-wire device, parametrized over both supported statevector
    # precisions (single and double) via the device's c_dtype argument.
    return LightningQubit(wires=2, c_dtype=request.param)


@pytest.fixture(scope="function")
def qubit_device_3_wires():
return LightningQubit(wires=3)
@pytest.fixture(scope="function", params=[np.complex64, np.complex128])
def qubit_device_3_wires(request):
    # Three-wire device, parametrized over both supported statevector
    # precisions (single and double) via the device's c_dtype argument.
    return LightningQubit(wires=3, c_dtype=request.param)
Loading