Skip to content

Commit

Permalink
✨ feat: add register decorator in OperatorFns
Browse files Browse the repository at this point in the history
  • Loading branch information
prajeeshag committed Nov 27, 2024
1 parent 063061d commit b352fac
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 92 deletions.
11 changes: 6 additions & 5 deletions clios/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .cli.app import Clios
from .core.operator import OperatorError
from .core.param_info import Input, Output, Param

__all__ = ["Clios", "OperatorError", "Input", "Output", "Param"]
from .cli.app import Clios as Clios
from .cli.app import OperatorFns as OperatorFns
from .core.operator import OperatorError as OperatorError
from .core.param_info import Input as Input
from .core.param_info import Output as Output
from .core.param_info import Param as Param
79 changes: 38 additions & 41 deletions clios/cli/app.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import sys
from typing import Annotated, Any, Callable, Literal
import typing as t

import click
from rich import print

from ..core.operator_fn import OperatorFn, OperatorFns
from clios.core.param_parser import ParamParserAbc

from ..core.operator_fn import OperatorFn
from ..core.operator_fn import OperatorFns as OperatorFns_
from ..core.param_info import Input
from .main_parser import CliParser
from .param_parser import CliParamParser
from .param_parser import StandardParamParser
from .presenter import CliPresenter


def output(input: Annotated[Any, Input()]) -> None:
def output(input: t.Annotated[t.Any, Input()]) -> None:
"""
Print the given input data to the terminal.
Expand All @@ -21,51 +24,31 @@ def output(input: Annotated[Any, Input()]) -> None:
print(input)


@click.command(
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
@click.option("--list", type=bool, help="List all available operators", is_flag=True)
@click.option(
"--show", type=str, help="Show the help information for the given operator", nargs=1
)
@click.option(
"--dry-run", type=bool, help="Dry run: prints the call tree", is_flag=True
)
@click.pass_context
def _click_app(ctx: Any, **kwargs: Any) -> tuple[list[str], dict[str, Any]]:
return ctx.args, kwargs
standard_param_parser = StandardParamParser()


default_param_parser = CliParamParser()
class OperatorFns(OperatorFns_):
@t.override
def register(
self,
*,
name: str = "",
param_parser: ParamParserAbc = standard_param_parser,
implicit: t.Literal["input", "param"] = "input",
) -> t.Callable[..., t.Any]:
return super().register(name=name, param_parser=param_parser, implicit=implicit)


class Clios:
def __init__(self) -> None:
self._operators = OperatorFns()
self._parser = CliParser()
def __init__(self, operator_fns: OperatorFns_) -> None:
self._operators = OperatorFns_()
self._operators["print"] = OperatorFn.validate(
output, param_parser=default_param_parser
output, param_parser=standard_param_parser
)
self._operators.update(operator_fns)
self._parser = CliParser()
self._presenter = CliPresenter(self._operators, self._parser)

def operator(
self,
name: str = "",
param_parser: CliParamParser = default_param_parser,
implicit: Literal["input", "param"] = "input",
) -> Any:
def decorator(func: Callable[..., Any]):
op_obj = OperatorFn.validate(
func,
param_parser=param_parser,
implicit=implicit,
)
key = name if name else func.__name__
self._operators[key] = op_obj
return func

return decorator

def __call__(self):
try:
args, options = _click_app(standalone_mode=False)
Expand All @@ -84,4 +67,18 @@ def __call__(self):

with click.Context(_click_app) as ctx:
click.echo(_click_app.get_help(ctx))
# self.presenter.print_list()


@click.command(
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
@click.option("--list", type=bool, help="List all available operators", is_flag=True)
@click.option(
"--show", type=str, help="Show the help information for the given operator", nargs=1
)
@click.option(
"--dry-run", type=bool, help="Dry run: prints the call tree", is_flag=True
)
@click.pass_context
def _click_app(ctx: t.Any, **kwargs: t.Any) -> tuple[list[str], dict[str, t.Any]]:
return ctx.args, kwargs
2 changes: 1 addition & 1 deletion clios/cli/param_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@dataclass(frozen=True)
class CliParamParser(ParamParserAbc):
class StandardParamParser(ParamParserAbc):
arg_sep: str = ","
kw_sep: str = "="
"""
Expand Down
26 changes: 12 additions & 14 deletions examples/calc/calc.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,41 @@
# [start]
from typing import Annotated, Any

from clios import Clios, Param
from clios.core.param_info import Output
from clios import OperatorFns, Output, Param

app = Clios()
# [app_created]
operators = OperatorFns()


@app.operator()
@operators.register()
def add(input1: float, input2: float) -> float:
"""Add two numbers."""
return input1 + input2


@app.operator()
@operators.register()
def sub(input1: float, input2: float) -> float:
"""Subtract two numbers."""
return input1 - input2


@app.operator()
@operators.register()
def mul(input1: float, input2: float) -> float:
"""Multiply two numbers."""
return input1 * input2


@app.operator()
@operators.register()
def div(input1: float, input2: float) -> float:
"""Divide two numbers."""
return input1 / input2


@app.operator()
@operators.register()
def sqrt(input: float) -> float:
"""Calculate the square root of a number."""
return input**0.5


@app.operator()
@operators.register()
def mean(*inputs: float) -> float:
"""Calculate the mean of a list of numbers."""
return sum(inputs) / len(inputs)
Expand All @@ -57,13 +54,14 @@ def file_writer(content: Any, file_path: str) -> None:
floatOutput = Annotated[list[float], Output(callback=file_writer)]


@app.operator(name="range", implicit="param")
@operators.register(name="range", implicit="param")
def range_(start: floatParam, end: floatParam, step: floatParam = 1) -> floatOutput:
"""Generate a range of numbers."""
return list(range(int(start), int(end), int(step)))


# [main_start]
if __name__ == "__main__":
from clios import Clios

app = Clios(operators)
app()
# [main_end]
6 changes: 4 additions & 2 deletions examples/calc/test_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import sys

import pytest
from calc import app
from calc import operators
from parameters import parameters

from clios import Clios


@pytest.mark.parametrize("input, output", parameters)
def test(input: list[str], output: str, capsys):
sys.argv = ["calc", *input]
app()
Clios(operators)()
captured = capsys.readouterr()
assert captured.out == f"{output}\n"
26 changes: 9 additions & 17 deletions tests/cli/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,45 @@

import pytest

from clios.cli.app import Clios
from clios.cli.app import Clios, OperatorFns


@pytest.fixture
def app():
return Clios()


def test_operator_registration(app):
@app.operator(name="test_op")
def test_op():
return "test"

assert app._operators.get("test_op")
return OperatorFns()


def test_click_app_list(app):
sys.argv = ["cli", "--list"]
result = app()
result = Clios(app)()
assert result is None


def test_click_app_show(app):
@app.operator(name="test_op")
@app.register(name="test_op")
def test_op():
return "test"

sys.argv = ["cli", "--show", "test_op"]
result = app()
result = Clios(app)()
assert result is None


def test_click_app_dry_run(app):
@app.operator(name="test_op")
@app.register(name="test_op")
def test_op():
return "test"

sys.argv = ["cli", "--dry-run", "test_op"]
result = app()
result = Clios(app)()
assert result is None


def test_click_app_run(app):
@app.operator(name="test_op")
@app.register(name="test_op")
def test_op():
return None

sys.argv = ["cli", "test_op"]
result = app()
result = Clios(app)()
assert result is None
4 changes: 2 additions & 2 deletions tests/cli/test_main_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import ValidationError

from clios.cli.main_parser import CliParser, ParserError
from clios.cli.param_parser import CliParamParser, ParamParserError
from clios.cli.param_parser import ParamParserError, StandardParamParser
from clios.cli.tokenizer import CliTokenizer
from clios.core.operator import RootOperator
from clios.core.operator_fn import OperatorFn, OperatorFns
Expand Down Expand Up @@ -118,7 +118,7 @@ def op_2o() -> t.Annotated[int, Output(callback=print, num_outputs=2)]:
"op_2o": op_2o,
}
operator_fns = OperatorFns()
param_parser = CliParamParser()
param_parser = StandardParamParser()
for name, op in ops.items():
operator_fns[name] = OperatorFn.validate(op, param_parser=param_parser)

Expand Down
12 changes: 6 additions & 6 deletions tests/cli/test_param_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pydantic import ValidationError

from clios.cli.param_parser import CliParamParser
from clios.cli.param_parser import StandardParamParser
from clios.core.operator_fn import OperatorFn
from clios.core.param_info import Input, Output, Param
from clios.core.param_parser import ParamParserError as ParserError
Expand Down Expand Up @@ -178,7 +178,7 @@ def op_1i1k1o(i: intIn, *, ip: intParam) -> intOut:

@pytest.mark.parametrize("input,expected", invalid_operators)
def test_parse_arguments(input, expected):
parser = CliParamParser()
parser = StandardParamParser()
parameters = OperatorFn.validate(input[1], parser).parameters

with pytest.raises(ParserError) as e:
Expand All @@ -189,7 +189,7 @@ def test_parse_arguments(input, expected):

@pytest.mark.parametrize("input,expected", build_error)
def test_parse_arguments_build_error(input, expected):
parser = CliParamParser()
parser = StandardParamParser()
parameters = OperatorFn.validate(input[1], parser).parameters

with pytest.raises(ParserError) as e:
Expand All @@ -203,22 +203,22 @@ def test_parse_arguments_build_error(input, expected):


def test_valid():
parser = CliParamParser()
parser = StandardParamParser()
parameters = OperatorFn.validate(op_1p1k, parser, implicit="param").parameters
param_values = parser.parse_arguments("1,ik=1", parameters)
assert param_values[0] == (1,)
assert param_values[1] == (("ik", 1),)


def test_valid_single():
parser = CliParamParser()
parser = StandardParamParser()
parameters = OperatorFn.validate(op_1i1p1o, parser).parameters
param_values = parser.parse_arguments("1", parameters)
assert param_values[0] == (1,)


def test_get_synopsis():
parser = CliParamParser()
parser = StandardParamParser()

def fn(
input: int,
Expand Down
4 changes: 2 additions & 2 deletions tests/cli/test_presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Doc

from clios.cli.main_parser import CliParser
from clios.cli.param_parser import CliParamParser
from clios.cli.param_parser import StandardParamParser
from clios.cli.presenter import CliPresenter
from clios.core.main_parser import ParserError
from clios.core.operator_fn import OperatorFn, OperatorFns
Expand Down Expand Up @@ -54,7 +54,7 @@ def operator1(

@pytest.fixture
def get_presenter(mocker):
param_parser = CliParamParser()
param_parser = StandardParamParser()

def _presenter(fns):
operator_fns = OperatorFns()
Expand Down
6 changes: 4 additions & 2 deletions tests/test_operator_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import ValidationError

from clios.cli.main_parser import CliParser
from clios.cli.param_parser import CliParamParser
from clios.cli.param_parser import StandardParamParser
from clios.core.operator import OperatorError, RootOperator
from clios.core.operator_fn import OperatorFn, OperatorFns
from clios.core.param_info import Input, Output, Param
Expand Down Expand Up @@ -113,7 +113,9 @@ def list_functions():

operators = OperatorFns()
for func in list_functions():
operators[func.__name__] = OperatorFn.validate(func, param_parser=CliParamParser())
operators[func.__name__] = OperatorFn.validate(
func, param_parser=StandardParamParser()
)


execute_error = [
Expand Down

0 comments on commit b352fac

Please sign in to comment.