diff --git a/core/dbt/cli/context.py b/core/dbt/cli/context.py new file mode 100644 index 00000000000..e0cf27624ca --- /dev/null +++ b/core/dbt/cli/context.py @@ -0,0 +1,44 @@ +from typing import List +from click import Context +from click.exceptions import NoSuchOption, UsageError +from dbt.cli.main import cli +from dbt.config.project import Project + +class DBTUsageException(Exception): + pass + +class DBTContext(Context): + def __init__(self, args: List[str]) -> None: + try: + ctx = cli.make_context(cli.name, args) + if args: + cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args) + cmd.make_context(cmd_name, cmd_args, parent=ctx) + except (NoSuchOption, UsageError) as e: + raise DBTUsageException(e.message) + + ctx.obj = {} + # yikes? + self.__dict__.update(ctx.__dict__) + # TODO: consider initializing Flags, ctx.obj here. + + # @classmethod + # def from_args(cls, args: List[str]) -> "DBTContext": + # try: + # ctx = cli.make_context(cli.name, args) + # if args: + # cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args) + # cmd.make_context(cmd_name, cmd_args, parent=ctx) + # except (NoSuchOption, UsageError) as e: + # raise DBTUsageException(e.message) + + # ctx.obj = {} + # # yikes + # ctx.__class__ = cls + # return ctx + + def set_project(self, project: Project): + if not isinstance(project, Project): + raise ValueError(f"{project} is a {type(project)}, expected a Project object.") + + self.obj["project"] = project \ No newline at end of file diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index af837c18b17..b446ebd7faa 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -9,6 +9,7 @@ from dbt.adapters.factory import Adapter from dbt.main import handle_and_check +from dbt.cli.context import DBTContext, cli as dbt from dbt.logger import log_manager from dbt.contracts.graph.manifest import Manifest from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars @@ -73,7 +74,11 @@ def run_dbt(args: List[str] = None, expect_pass=True): args = ["run"] print("\n\nInvoking dbt with {}".format(args)) - res, success = handle_and_check(args) + # ctx = DBTContext.from_args(args) + ctx = DBTContext(args) + + # ctx.set_project('test') + res, success = dbt.invoke(ctx) if expect_pass is not None: assert success == expect_pass, "dbt exit state did not match expected"