Skip to content

Commit

Permalink
utilize function wrapper (#807)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Nov 20, 2023
1 parent e189311 commit 1f76b13
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions src/braket/experimental/autoqasm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def main(
*,
num_qubits: Optional[int] = None,
device: Optional[Union[Device, str]] = None,
) -> (
aq_program.Program | Callable[[Optional[int], Optional[Union[Device, str]]], aq_program.Program]
):
) -> aq_program.Program | functools.partial:
"""Decorator that converts a function into a callable that returns
a Program object containing the quantum program.
Expand All @@ -68,33 +66,33 @@ def main(
program. Can be either an Device object or a valid Amazon Braket device ARN.
Returns:
Program | Callable[[Optional[int], Optional[Union[Device, str]]], Program]: A callable
Program | partial: A callable
which returns the converted quantum program when called.
"""
if isinstance(device, str):
device = AwsDevice(device)

bound_convert_main = functools.partial(
_convert_main,
options=converter.ConversionOptions(
user_requested=True,
optional_features=_autograph_optional_features(),
),
user_config=aq_program.UserConfig(
num_qubits=num_qubits,
device=device,
),
)

# decorator is called on a Program
if isinstance(func, aq_program.Program):
return func

# decorator is used with parentheses
# (see _function_wrapper for more details)
if not (func and callable(func)):
# decorator is used with parentheses
# (see _function_wrapper for more details)
return bound_convert_main
return functools.partial(main, num_qubits=num_qubits, device=device)

program_builder = _function_wrapper(
func,
converter_callback=_convert_main,
converter_args={
"user_config": aq_program.UserConfig(
num_qubits=num_qubits,
device=device,
)
},
)

return bound_convert_main(func)
return program_builder()


def subroutine(func: Optional[Callable] = None) -> Callable[..., aq_program.Program]:
Expand Down Expand Up @@ -196,7 +194,10 @@ def _wrapper(*args, **kwargs) -> Callable:
optional_features=_autograph_optional_features(),
)
# Call the appropriate function converter
return converter_callback(func, options, args, kwargs, **converter_args)
if converter_callback == _convert_main:
# main doesn't take args or kwargs at call time
return converter_callback(func, options=options, **converter_args)
return converter_callback(func, options=options, args=args, kwargs=kwargs, **converter_args)

if inspect.isfunction(func) or inspect.ismethod(func):
_wrapper = functools.update_wrapper(_wrapper, func)
Expand All @@ -214,6 +215,7 @@ def _autograph_optional_features() -> tuple[converter.Feature]:

def _convert_main(
f: Callable,
*,
options: converter.ConversionOptions,
user_config: aq_program.UserConfig,
) -> aq_program.Program:
Expand Down

0 comments on commit 1f76b13

Please sign in to comment.