diff --git a/cirq-ft/cirq_ft/algos/swap_network.py b/cirq-ft/cirq_ft/algos/swap_network.py index 279ab33be38..1dd5ca88879 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.py +++ b/cirq-ft/cirq_ft/algos/swap_network.py @@ -152,7 +152,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - return (infra.Register('target', (self.n_target_registers, self.target_bitsize)),) + return ( + infra.Register('target', bitsize=self.target_bitsize, shape=self.n_target_registers), + ) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb index 6afb6d49d4f..ef72a1e1479 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb @@ -39,7 +39,7 @@ "source": [ "## `Registers`\n", "\n", - "`Register` objects have a name and a shape. `Registers` is an ordered collection of `Register` with some helpful methods." + "`Register` objects have a name, a bitsize and a shape. `Registers` is an ordered collection of `Register` with some helpful methods." ] }, { @@ -51,8 +51,8 @@ "source": [ "from cirq_ft import Register, Registers, infra\n", "\n", - "control_reg = Register(name='control', shape=(2,))\n", - "target_reg = Register(name='target', shape=(3,))\n", + "control_reg = Register(name='control', bitsize=2)\n", + "target_reg = Register(name='target', bitsize=3)\n", "control_reg, target_reg" ] }, diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index 7139b66a65a..b4567591c67 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -32,8 +32,9 @@ class Register: """ name: str + bitsize: int shape: Tuple[int, ...] = attr.field( - converter=lambda v: (v,) if isinstance(v, int) else tuple(v) + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() ) def all_idxs(self) -> Iterable[Tuple[int, ...]]: @@ -45,15 +46,14 @@ def total_bits(self) -> int: This is the product of each of the dimensions in `shape`. """ - return int(np.product(self.shape)) + return self.bitsize * int(np.product(self.shape)) def __repr__(self): - return f'cirq_ft.Register(name="{self.name}", shape={self.shape})' + return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})' def total_bits(registers: Iterable[Register]) -> int: """Sum of `reg.total_bits()` for each register `reg` in input `registers`.""" - return sum(reg.total_bits() for reg in registers) @@ -65,7 +65,9 @@ def split_qubits( qubit_regs = {} base = 0 for reg in registers: - qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(reg.shape) + qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape( + reg.shape + (reg.bitsize,) + ) base += reg.total_bits() return qubit_regs @@ -82,9 +84,10 @@ def merge_qubits( raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs") qubits = qubit_regs[reg.name] qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) - if qubits.shape != reg.shape: + full_shape = reg.shape + (reg.bitsize,) + if qubits.shape != full_shape: raise ValueError( - f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' + f'{reg.name} register must of shape {full_shape} but is of shape {qubits.shape}' ) ret += qubits.flatten().tolist() return ret @@ -94,13 +97,16 @@ def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qi """Returns a dictionary of appropriately shaped named qubit registers for input `registers`.""" def _qubit_array(reg: Register): - qubits = np.empty(reg.shape, dtype=object) + qubits = np.empty(reg.shape + (reg.bitsize,), dtype=object) for ii in reg.all_idxs(): - qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') + for j in range(reg.bitsize): + prefix = "" if not ii else f'[{", ".join(str(i) for i in ii)}]' + suffix = "" if reg.bitsize == 1 else f"[{j}]" + qubits[ii + (j,)] = cirq.NamedQubit(reg.name + prefix + suffix) return qubits def _qubits_for_reg(reg: Register): - if len(reg.shape) > 1: + if len(reg.shape) > 0: return _qubit_array(reg) return np.array( @@ -130,8 +136,8 @@ def __repr__(self): return f'cirq_ft.Registers({self._registers})' @classmethod - def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers': - return cls(Register(name=k, shape=v) for k, v in registers.items()) + def build(cls, **registers: int) -> 'Registers': + return cls(Register(name=k, bitsize=v) for k, v in registers.items()) @overload def __getitem__(self, key: int) -> Register: @@ -216,23 +222,29 @@ class SelectionRegister(Register): >>> assert len(flat_indices) == N * M * L """ + name: str + bitsize: int iteration_length: int = attr.field() + shape: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) @iteration_length.default def _default_iteration_length(self): - return 2 ** self.shape[0] + return 2**self.bitsize @iteration_length.validator def validate_iteration_length(self, attribute, value): - if len(self.shape) != 1: + if len(self.shape) != 0: raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') - if not (0 <= value <= 2 ** self.shape[0]): - raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') + if not (0 <= value <= 2**self.bitsize): + raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]') def __repr__(self) -> str: return ( f'cirq_ft.SelectionRegister(' f'name="{self.name}", ' + f'bitsize={self.bitsize}, ' f'shape={self.shape}, ' f'iteration_length={self.iteration_length})' ) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 7560cb7a357..57af2354e48 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -21,8 +21,9 @@ def test_register(): - r = cirq_ft.Register("my_reg", 5) - assert r.shape == (5,) + r = cirq_ft.Register("my_reg", 5, (1, 2)) + assert r.bitsize == 5 + assert r.shape == (1, 2) def test_registers(): @@ -103,12 +104,12 @@ def test_selection_registers_consistent(): _ = cirq_ft.SelectionRegister('a', 3, 10) with pytest.raises(ValueError, match="should be flat"): - _ = cirq_ft.SelectionRegister('a', (3, 5), 5) + _ = cirq_ft.SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5) selection_reg = cirq_ft.Registers( [ - cirq_ft.SelectionRegister('n', shape=3, iteration_length=5), - cirq_ft.SelectionRegister('m', shape=4, iteration_length=12), + cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5), + cirq_ft.SelectionRegister('m', bitsize=4, iteration_length=12), ] ) assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5) @@ -122,7 +123,9 @@ def test_registers_getitem_raises(): with pytest.raises(IndexError, match="must be of the type"): _ = g[2.5] - selection_reg = cirq_ft.Registers([cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)]) + selection_reg = cirq_ft.Registers( + [cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5)] + ) with pytest.raises(IndexError, match='must be of the type'): _ = selection_reg[2.5]