Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebase Refactor passes module & clean-up exports with v0.10.0-rc #1432

Merged
merged 15 commits into from
Jan 9, 2025
Merged
9 changes: 3 additions & 6 deletions frontend/catalyst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@
"mlir_quantum._mlir_libs._quantumDialects.mitigation"
)

from catalyst import debug, logging
from catalyst import debug, logging, passes
from catalyst.api_extensions import *
from catalyst.api_extensions import __all__ as _api_extension_list
from catalyst.autograph import *
from catalyst.autograph import __all__ as _autograph_functions
from catalyst.compiler import CompileOptions
from catalyst.debug.assertion import debug_assert
from catalyst.jit import QJIT, qjit
from catalyst.passes import Pass, PassPlugin, apply_pass, apply_pass_plugin, pipeline
from catalyst.passes.pass_api import pipeline
from catalyst.utils.exceptions import (
AutoGraphError,
CompileError,
Expand Down Expand Up @@ -187,11 +187,8 @@
"debug_assert",
"CompileOptions",
"debug",
"apply_pass",
"apply_pass_plugin",
"passes",
"pipeline",
"Pass",
"PassPlugin",
*_api_extension_list,
*_autograph_functions,
)
3 changes: 2 additions & 1 deletion frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,9 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
with ir.InsertionPoint(bb_named_sequence):
target = bb_named_sequence.arguments[0]
for _pass in pipeline:
options = _pass.get_options()
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type, target=target, pass_name=_pass.name
result=transform_mod_type, target=target, pass_name=_pass.name, options=options
)
target = apply_registered_pass_op.result
transform_yield_op = YieldOp(operands_=[]) # pylint: disable=unused-variable
Expand Down
Loading