diff --git a/usecases/usecase/usecase.py b/usecases/usecase/usecase.py index c52e8af..46688fc 100644 --- a/usecases/usecase/usecase.py +++ b/usecases/usecase/usecase.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Dict, Type -from utils.configurable import get_parameters, ParameterDefinitions, build_parser, get_arguments +from utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters class UseCase(abc.ABC): @@ -66,7 +66,7 @@ def use_case(name: str, desc: str): def inner(cls: Type[UseCase]): if name in use_cases: raise IndexError(f"Use case with name {name} already exists") - use_cases[name] = _WrappedUseCase(name, desc, cls, get_parameters(cls.__init__, name)) + use_cases[name] = _WrappedUseCase(name, desc, cls, get_class_parameters(cls, name)) return cls diff --git a/utils/configurable.py b/utils/configurable.py index 7c17113..d2d7dae 100644 --- a/utils/configurable.py +++ b/utils/configurable.py @@ -1,4 +1,5 @@ import argparse +import dataclasses import inspect import os from dataclasses import dataclass @@ -12,6 +13,16 @@ load_dotenv() +def parameter(*, desc: str, default=dataclasses.MISSING, init: bool = True, repr: bool = True, hash=None, + compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING) -> dataclasses.Field: + if metadata is None: + metadata = dict() + metadata["desc"] = desc + + return dataclasses.field(default=default, default_factory=dataclasses.MISSING, init=init, repr=repr, hash=hash, + compare=compare, metadata=metadata, kw_only=kw_only) + + def get_default(key, default): return os.getenv(key, os.getenv(key.upper(), os.getenv(key.replace(".", "_"), os.getenv(key.replace(".", "_").upper(), default)))) @@ -24,12 +35,14 @@ class ParameterDefinition: name: str type: Type default: Any + description: str def parser(self, basename: str, parser: argparse.ArgumentParser): name = f"{basename}{self.name}" default = get_default(name, self.default) - parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None) + parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None, + help=self.description) def get(self, basename: str, args: argparse.Namespace): return getattr(args, f"{basename}{self.name}") @@ -62,7 +75,18 @@ def get(self, basename: str, args: argparse.Namespace): return parameter -def get_parameters(fun, basename: str) -> ParameterDefinitions: +def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions: + if name is None: + name = cls.__name__ + if fields is None and hasattr(cls, "__dataclass_fields__"): + fields = cls.__dataclass_fields__ + return get_parameters(cls.__init__, name, fields) + + +def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions: + if fields is None: + fields = dict() + sig = inspect.signature(fun) params: ParameterDefinitions = {} for name, param in sig.parameters.items(): @@ -73,13 +97,27 @@ def get_parameters(fun, basename: str) -> ParameterDefinitions: raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have a type annotation") default = param.default if param.default != inspect.Parameter.empty else None - - if hasattr(param.annotation, "__parameters__"): - params[name] = ComplexParameterDefinition(name, param.annotation, default, get_parameters(param.annotation, f"{basename}.{fun.__name__}")) - elif param.annotation in (str, int, bool): - params[name] = ParameterDefinition(name, param.annotation, default) + description = None + type = param.annotation + + field = None + if isinstance(default, dataclasses.Field): + field = default + default = field.default + elif name in fields: + field = fields[name] + + if field is not None: + description = field.metadata.get("desc", None) + if field.type is not None: + type = field.type + + if hasattr(type, "__parameters__"): + params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}")) + elif type in (str, int, bool): + params[name] = ParameterDefinition(name, type, default, description) else: - raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {param.annotation}") + raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}") return params @@ -106,7 +144,7 @@ def inner(cls) -> Configurable: cls.name = service_name cls.description = service_desc cls.__service__ = True - cls.__parameters__ = get_parameters(cls.__init__, cls.__name__) + cls.__parameters__ = get_class_parameters(cls) return cls diff --git a/utils/db_storage/db_storage.py b/utils/db_storage/db_storage.py index 07972a9..dec6eb6 100644 --- a/utils/db_storage/db_storage.py +++ b/utils/db_storage/db_storage.py @@ -1,11 +1,11 @@ import sqlite3 -from utils.configurable import configurable +from utils.configurable import configurable, parameter @configurable("db_storage", "Stores the results of the experiments in a SQLite database") class DbStorage: - def __init__(self, connection_string: str = ":memory:"): + def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:")): self.connection_string = connection_string def init(self): diff --git a/utils/openai/openai_llm.py b/utils/openai/openai_llm.py index c67fb0d..024115d 100644 --- a/utils/openai/openai_llm.py +++ b/utils/openai/openai_llm.py @@ -6,7 +6,7 @@ import tiktoken -from utils.configurable import configurable +from utils.configurable import configurable, parameter from utils.llm_util import LLMResult, LLM @@ -20,13 +20,13 @@ class OpenAIConnection(LLM): If you really must use it, you can import it directly from the utils.openai.openai_llm module, which will later on show you, that you did not specialize yet. """ - api_key: str - model: str - context_size: int - api_url: str = "https://api.openai.com" - api_timeout: int = 240 - api_backoff: int = 60 - api_retries: int = 3 + api_key: str = parameter(desc="OpenAI API Key") + model: str = parameter(desc="OpenAI model name") + context_size: int = parameter(desc="Maximum context size for the model, only used internally for things like trimming to the context size") + api_url: str = parameter(desc="URL of the OpenAI API", default="https://api.openai.com") + api_timeout: int = parameter(desc="Timeout for the API request", default=240) + api_backoff: int = parameter(desc="Backoff time in seconds when running into rate-limits", default=60) + api_retries: int = parameter(desc="Number of retries when running into rate-limits", default=3) def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult: if retry >= self.api_retries: diff --git a/wintermute.py b/wintermute.py index d64fee9..d3e1aac 100644 --- a/wintermute.py +++ b/wintermute.py @@ -10,7 +10,7 @@ def main(): for name, use_case in use_cases.items(): use_case.build_parser(subparser.add_parser( name=use_case.name, - description=use_case.description + help=use_case.description )) parsed = parser.parse_args(sys.argv[1:])