From d37d023a1c86491948d8e6f31c8932323f72a85e Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 7 Dec 2022 16:50:56 -0500 Subject: [PATCH] DBTContext --- core/dbt/cli/context.py | 44 +++++++++++++++++++++++++++++++++++++++++ core/dbt/cli/main.py | 12 ----------- core/dbt/tests/util.py | 8 ++++++-- 3 files changed, 50 insertions(+), 14 deletions(-) create mode 100644 core/dbt/cli/context.py diff --git a/core/dbt/cli/context.py b/core/dbt/cli/context.py new file mode 100644 index 00000000000..e0cf27624ca --- /dev/null +++ b/core/dbt/cli/context.py @@ -0,0 +1,44 @@ +from typing import List +from click import Context +from click.exceptions import NoSuchOption, UsageError +from dbt.cli.main import cli +from dbt.config.project import Project + +class DBTUsageException(Exception): + pass + +class DBTContext(Context): + def __init__(self, args: List[str]) -> None: + try: + ctx = cli.make_context(cli.name, args) + if args: + cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args) + cmd.make_context(cmd_name, cmd_args, parent=ctx) + except (NoSuchOption, UsageError) as e: + raise DBTUsageException(e.message) + + ctx.obj = {} + # yikes? + self.__dict__.update(ctx.__dict__) + # TODO: consider initializing Flags, ctx.obj here. + + # @classmethod + # def from_args(cls, args: List[str]) -> "DBTContext": + # try: + # ctx = cli.make_context(cli.name, args) + # if args: + # cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args) + # cmd.make_context(cmd_name, cmd_args, parent=ctx) + # except (NoSuchOption, UsageError) as e: + # raise DBTUsageException(e.message) + + # ctx.obj = {} + # # yikes + # ctx.__class__ = cls + # return ctx + + def set_project(self, project: Project): + if not isinstance(project, Project): + raise ValueError(f"{project} is a {type(project)}, expected a Project object.") + + self.obj["project"] = project \ No newline at end of file diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 3942bf73419..f16e4c737e6 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -11,20 +11,8 @@ from dbt.tracking import initialize_from_flags, track_run from dbt.config.runtime import load_project from dbt.task.deps import DepsTask -from typing import Optional -def make_context(args, command) -> Optional[click.Context]: - ctx = command.make_context(command.name, args) - - ctx.invoked_subcommand = ctx.protected_args[0] if ctx.protected_args else None - return ctx - -def handle_and_check(args): - ctx = make_context(args, cli) - res, success = cli.invoke(ctx) - return res, success - def cli_runner(): # Alias "list" to "ls" ls = copy(cli.commands["list"]) diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index 6943ee1d7b7..dbff80d111a 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from dbt.adapters.factory import Adapter -from dbt.cli.main import handle_and_check +from dbt.cli.context import DBTContext, cli as dbt from dbt.logger import log_manager from dbt.contracts.graph.manifest import Manifest from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars @@ -73,7 +73,11 @@ def run_dbt(args: List[str] = None, expect_pass=True): args = ["run"] print("\n\nInvoking dbt with {}".format(args)) - res, success = handle_and_check(args) + # ctx = DBTContext.from_args(args) + ctx = DBTContext(args) + + # ctx.set_project('test') + res, success = dbt.invoke(ctx) if expect_pass is not None: assert success == expect_pass, "dbt exit state did not match expected"