Skip to content

Commit

Permalink
simpler code, thanks to Jake
Browse files Browse the repository at this point in the history
Co-authored-by: Jake Lishman <jake.lishman@ibm.com>
  • Loading branch information
1ucian0 and jakelishman committed Mar 10, 2023
1 parent 0e1ac16 commit b21c622
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
27 changes: 9 additions & 18 deletions qiskit/transpiler/preset_passmanagers/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,31 +304,22 @@ def list_stage_plugins(stage_name: str) -> List[str]:
raise TranspilerError(f"Invalid stage name: {stage_name}")


def entry_point_obj(stage_name: str, plugin_name: str) -> abc.ABCMeta:
def passmanager_stage_plugins(stage: str):
"""Return the class type of an entry point.
Args:
stage_name: The stage name to get the entrypoint for
plugin_name: The plugin name to get the entrypoint for
stage: The stage name to get the entrypoint for
Returns:
Type: Class of the entrypoint
dict Type: TODO
Raises:
TranspilerError: If an invalid stage name is specified.
"""
plugin_mgr = PassManagerStagePluginManager()
if stage_name == "init":
return plugin_mgr.init_plugins[plugin_name].obj
elif stage_name == "layout":
return plugin_mgr.layout_plugins[plugin_name].obj
elif stage_name == "routing":
return plugin_mgr.routing_plugins[plugin_name].obj
elif stage_name == "translation":
return plugin_mgr.translation_plugins[plugin_name].obj
elif stage_name == "optimization":
return plugin_mgr.optimization_plugins[plugin_name].obj
elif stage_name == "scheduling":
return plugin_mgr.scheduling_plugins[plugin_name].obj
else:
raise TranspilerError(f"Invalid stage name: {stage_name}")
try:
manager = getattr(plugin_mgr, f"{stage}_plugins")
except AttributeError as exc:
raise TranspilerError(f"Passmanager stage {stage} not found") from exc

return {name: manager[name].obj for name in manager.names()}
13 changes: 9 additions & 4 deletions test/python/transpiler/test_stage_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from qiskit.transpiler.preset_passmanagers.plugin import (
PassManagerStagePluginManager,
list_stage_plugins,
entry_point_obj,
passmanager_stage_plugins,
)
from qiskit.transpiler.exceptions import TranspilerError
from qiskit.providers.basicaer import QasmSimulatorPy
Expand All @@ -53,10 +53,15 @@ def test_list_stage_plugins_invalid_stage_name(self):
with self.assertRaises(TranspilerError):
list_stage_plugins("not_a_stage")

def test_entry_point_obj(self):
def test_passmanager_stage_plugins(self):
"""Test entry_point_obj function."""
basic_obj = entry_point_obj("routing", "basic")
self.assertIsInstance(basic_obj, BasicSwapPassManager)
basic_obj = passmanager_stage_plugins("routing")
self.assertIsInstance(basic_obj["basic"], BasicSwapPassManager)

def test_passmanager_stage_plugins_not_found(self):
"""Test entry_point_obj function with nonexistent stage"""
with self.assertRaises(TranspilerError):
passmanager_stage_plugins("foo_stage")

def test_build_pm_invalid_plugin_name_valid_stage(self):
"""Test get pm from plugin with invalid plugin name and valid stage."""
Expand Down

0 comments on commit b21c622

Please sign in to comment.