Skip to content

Commit

Permalink
Add bitsize field to Cirq-FT Registers (#6286)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored Sep 20, 2023
1 parent 188bb94 commit 8e4e7d1
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 26 deletions.
4 changes: 3 additions & 1 deletion cirq-ft/cirq_ft/algos/swap_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:

@cached_property
def target_registers(self) -> Tuple[infra.Register, ...]:
return (infra.Register('target', (self.n_target_registers, self.target_bitsize)),)
return (
infra.Register('target', bitsize=self.target_bitsize, shape=self.n_target_registers),
)

@cached_property
def registers(self) -> infra.Registers:
Expand Down
6 changes: 3 additions & 3 deletions cirq-ft/cirq_ft/infra/gate_with_registers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"source": [
"## `Registers`\n",
"\n",
"`Register` objects have a name and a shape. `Registers` is an ordered collection of `Register` with some helpful methods."
"`Register` objects have a name, a bitsize and a shape. `Registers` is an ordered collection of `Register` with some helpful methods."
]
},
{
Expand All @@ -51,8 +51,8 @@
"source": [
"from cirq_ft import Register, Registers, infra\n",
"\n",
"control_reg = Register(name='control', shape=(2,))\n",
"target_reg = Register(name='target', shape=(3,))\n",
"control_reg = Register(name='control', bitsize=2)\n",
"target_reg = Register(name='target', bitsize=3)\n",
"control_reg, target_reg"
]
},
Expand Down
44 changes: 28 additions & 16 deletions cirq-ft/cirq_ft/infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ class Register:
"""

name: str
bitsize: int
shape: Tuple[int, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)

def all_idxs(self) -> Iterable[Tuple[int, ...]]:
Expand All @@ -45,15 +46,14 @@ def total_bits(self) -> int:
This is the product of each of the dimensions in `shape`.
"""
return int(np.product(self.shape))
return self.bitsize * int(np.product(self.shape))

def __repr__(self):
return f'cirq_ft.Register(name="{self.name}", shape={self.shape})'
return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})'


def total_bits(registers: Iterable[Register]) -> int:
"""Sum of `reg.total_bits()` for each register `reg` in input `registers`."""

return sum(reg.total_bits() for reg in registers)


Expand All @@ -65,7 +65,9 @@ def split_qubits(
qubit_regs = {}
base = 0
for reg in registers:
qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(reg.shape)
qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(
reg.shape + (reg.bitsize,)
)
base += reg.total_bits()
return qubit_regs

Expand All @@ -82,9 +84,10 @@ def merge_qubits(
raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs")
qubits = qubit_regs[reg.name]
qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits)
if qubits.shape != reg.shape:
full_shape = reg.shape + (reg.bitsize,)
if qubits.shape != full_shape:
raise ValueError(
f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}'
f'{reg.name} register must of shape {full_shape} but is of shape {qubits.shape}'
)
ret += qubits.flatten().tolist()
return ret
Expand All @@ -94,13 +97,16 @@ def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qi
"""Returns a dictionary of appropriately shaped named qubit registers for input `registers`."""

def _qubit_array(reg: Register):
qubits = np.empty(reg.shape, dtype=object)
qubits = np.empty(reg.shape + (reg.bitsize,), dtype=object)
for ii in reg.all_idxs():
qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]')
for j in range(reg.bitsize):
prefix = "" if not ii else f'[{", ".join(str(i) for i in ii)}]'
suffix = "" if reg.bitsize == 1 else f"[{j}]"
qubits[ii + (j,)] = cirq.NamedQubit(reg.name + prefix + suffix)
return qubits

def _qubits_for_reg(reg: Register):
if len(reg.shape) > 1:
if len(reg.shape) > 0:
return _qubit_array(reg)

return np.array(
Expand Down Expand Up @@ -130,8 +136,8 @@ def __repr__(self):
return f'cirq_ft.Registers({self._registers})'

@classmethod
def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers':
return cls(Register(name=k, shape=v) for k, v in registers.items())
def build(cls, **registers: int) -> 'Registers':
return cls(Register(name=k, bitsize=v) for k, v in registers.items())

@overload
def __getitem__(self, key: int) -> Register:
Expand Down Expand Up @@ -216,23 +222,29 @@ class SelectionRegister(Register):
>>> assert len(flat_indices) == N * M * L
"""

name: str
bitsize: int
iteration_length: int = attr.field()
shape: Tuple[int, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)

@iteration_length.default
def _default_iteration_length(self):
return 2 ** self.shape[0]
return 2**self.bitsize

@iteration_length.validator
def validate_iteration_length(self, attribute, value):
if len(self.shape) != 1:
if len(self.shape) != 0:
raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}')
if not (0 <= value <= 2 ** self.shape[0]):
raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]')
if not (0 <= value <= 2**self.bitsize):
raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]')

def __repr__(self) -> str:
return (
f'cirq_ft.SelectionRegister('
f'name="{self.name}", '
f'bitsize={self.bitsize}, '
f'shape={self.shape}, '
f'iteration_length={self.iteration_length})'
)
Expand Down
15 changes: 9 additions & 6 deletions cirq-ft/cirq_ft/infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@


def test_register():
r = cirq_ft.Register("my_reg", 5)
assert r.shape == (5,)
r = cirq_ft.Register("my_reg", 5, (1, 2))
assert r.bitsize == 5
assert r.shape == (1, 2)


def test_registers():
Expand Down Expand Up @@ -103,12 +104,12 @@ def test_selection_registers_consistent():
_ = cirq_ft.SelectionRegister('a', 3, 10)

with pytest.raises(ValueError, match="should be flat"):
_ = cirq_ft.SelectionRegister('a', (3, 5), 5)
_ = cirq_ft.SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5)

selection_reg = cirq_ft.Registers(
[
cirq_ft.SelectionRegister('n', shape=3, iteration_length=5),
cirq_ft.SelectionRegister('m', shape=4, iteration_length=12),
cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5),
cirq_ft.SelectionRegister('m', bitsize=4, iteration_length=12),
]
)
assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5)
Expand All @@ -122,7 +123,9 @@ def test_registers_getitem_raises():
with pytest.raises(IndexError, match="must be of the type"):
_ = g[2.5]

selection_reg = cirq_ft.Registers([cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)])
selection_reg = cirq_ft.Registers(
[cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5)]
)
with pytest.raises(IndexError, match='must be of the type'):
_ = selection_reg[2.5]

Expand Down

0 comments on commit 8e4e7d1

Please sign in to comment.