From 98418552dc92a039850278bc9c7e2410c5f3e625 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:55:22 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/ml/Predictor.py | 3 +-- src/mqt/predictor/rl/PredictorEnv.py | 12 ++++++------ src/mqt/predictor/rl/helper.py | 20 +++++++++++--------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/mqt/predictor/ml/Predictor.py b/src/mqt/predictor/ml/Predictor.py index 034e65888..9af324a05 100644 --- a/src/mqt/predictor/ml/Predictor.py +++ b/src/mqt/predictor/ml/Predictor.py @@ -8,14 +8,13 @@ import numpy as np from joblib import Parallel, delayed, load from qiskit import QuantumCircuit +from qiskit.qasm2 import dump from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import GridSearchCV, train_test_split from mqt.bench.devices import get_available_device_names, get_available_devices from mqt.predictor import ml, reward, rl, utils -from qiskit.qasm2 import dump - if TYPE_CHECKING: from numpy._typing import NDArray diff --git a/src/mqt/predictor/rl/PredictorEnv.py b/src/mqt/predictor/rl/PredictorEnv.py index f1b9a0750..e10d21a54 100644 --- a/src/mqt/predictor/rl/PredictorEnv.py +++ b/src/mqt/predictor/rl/PredictorEnv.py @@ -13,11 +13,10 @@ from pytket.circuit import Qubit from pytket.extensions.qiskit import qiskit_to_tk, tk_to_qiskit from qiskit import QuantumCircuit +from qiskit.passmanager.flow_controllers import DoWhileController from qiskit.transpiler import CouplingMap, PassManager, TranspileLayout from qiskit.transpiler.passes import CheckMap, GatesInBasis -from qiskit.passmanager.flow_controllers import DoWhileController - from mqt.bench.devices import get_device_by_name from mqt.predictor import reward, rl @@ -218,11 +217,12 @@ def apply_action(self, action_index: int) -> QuantumCircuit | None: pm = PassManager() pm.append( DoWhileController( - action["transpile_pass"]( - self.device.basis_gates, - CouplingMap(self.device.coupling_map) if self.layout is not None else None, + action["transpile_pass"]( + self.device.basis_gates, + CouplingMap(self.device.coupling_map) if self.layout is not None else None, + ), + do_while=action["do_while"], ), - do_while=action["do_while"]), ) else: pm = PassManager(transpile_pass) diff --git a/src/mqt/predictor/rl/helper.py b/src/mqt/predictor/rl/helper.py index b413f910e..a6b55de53 100644 --- a/src/mqt/predictor/rl/helper.py +++ b/src/mqt/predictor/rl/helper.py @@ -52,18 +52,14 @@ VF2Layout, VF2PostLayout, ) +from qiskit.transpiler.passes.layout.vf2_layout import VF2LayoutStopReason from sb3_contrib import MaskablePPO from tqdm import tqdm - -from qiskit.transpiler.passes.layout.vf2_layout import VF2LayoutStopReason - from mqt.bench.utils import calc_supermarq_features from mqt.predictor import reward, rl if TYPE_CHECKING: - from collections.abc import Callable - from numpy.typing import NDArray from mqt.bench.devices import Device @@ -190,8 +186,11 @@ def get_actions_opt() -> list[dict[str, Any]]: CommutativeCancellation(basis_gates=native_gate), GatesInBasis(native_gate), ConditionalController( - common.generate_translation_passmanager(target=None, basis_gates=native_gate, coupling_map=coupling_map).to_flow_controller(), - condition=lambda property_set: not property_set["all_gates_in_basis"]), + common.generate_translation_passmanager( + target=None, basis_gates=native_gate, coupling_map=coupling_map + ).to_flow_controller(), + condition=lambda property_set: not property_set["all_gates_in_basis"], + ), Depth(recurse=True), FixedPoint("depth"), Size(recurse=True), @@ -257,8 +256,11 @@ def get_actions_layout() -> list[dict[str, Any]]: [ FullAncillaAllocation(coupling_map=CouplingMap(device.coupling_map)), EnlargeWithAncilla(), - ApplyLayout()], - condition=lambda property_set: property_set["VF2Layout_stop_reason"] == VF2LayoutStopReason.SOLUTION_FOUND) + ApplyLayout(), + ], + condition=lambda property_set: property_set["VF2Layout_stop_reason"] + == VF2LayoutStopReason.SOLUTION_FOUND, + ), ], "origin": "qiskit", },