From 3efaf9474e0f68c010a988950cd0dcff86c211f5 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 8 Dec 2022 15:41:48 -0500 Subject: [PATCH] DBTContext with invocation_args --- core/dbt/cli/context.py | 61 ++++++++++++++++++++--------------------- core/dbt/cli/main.py | 7 ++++- core/dbt/tests/util.py | 9 +++--- 3 files changed, 39 insertions(+), 38 deletions(-) diff --git a/core/dbt/cli/context.py b/core/dbt/cli/context.py index e0cf27624ca..f26b5156fde 100644 --- a/core/dbt/cli/context.py +++ b/core/dbt/cli/context.py @@ -1,44 +1,41 @@ -from typing import List -from click import Context +from click import Context, Group, Command from click.exceptions import NoSuchOption, UsageError -from dbt.cli.main import cli +from dbt.cli.flags import Flags from dbt.config.project import Project +import sys + class DBTUsageException(Exception): pass + class DBTContext(Context): - def __init__(self, args: List[str]) -> None: - try: - ctx = cli.make_context(cli.name, args) - if args: - cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args) - cmd.make_context(cmd_name, cmd_args, parent=ctx) + def __init__(self, command: Command, **kwargs) -> None: + if isinstance(kwargs.get("parent"), DBTContext): + self.invocation_args = kwargs["parent"].invocation_args + else: + self.invocation_args = kwargs.pop("args", sys.argv[1:]) + + super().__init__(command, **kwargs) + + # Bubble up validation errors for top-level commands + if not self.parent: + self._validate_args(command, self.invocation_args) + + self.obj = self.obj or {} + self.flags = Flags(self) + + def _validate_args(self, command, args) -> None: + try: + command.parse_args(self, args) + if isinstance(command, Group): + _, cmd, cmd_args = command.resolve_command(self, args) + self._validate_args(cmd, cmd_args) except (NoSuchOption, UsageError) as e: - raise DBTUsageException(e.message) - - ctx.obj = {} - # yikes? - self.__dict__.update(ctx.__dict__) - # TODO: consider initializing Flags, ctx.obj here. - - # @classmethod - # def from_args(cls, args: List[str]) -> "DBTContext": - # try: - # ctx = cli.make_context(cli.name, args) - # if args: - # cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args) - # cmd.make_context(cmd_name, cmd_args, parent=ctx) - # except (NoSuchOption, UsageError) as e: - # raise DBTUsageException(e.message) - - # ctx.obj = {} - # # yikes - # ctx.__class__ = cls - # return ctx + raise DBTUsageException(e.message) def set_project(self, project: Project): if not isinstance(project, Project): raise ValueError(f"{project} is a {type(project)}, expected a Project object.") - - self.obj["project"] = project \ No newline at end of file + + self.obj["project"] = project diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index f16e4c737e6..779eb1d596b 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -6,6 +6,7 @@ from dbt.adapters.factory import adapter_management from dbt.cli import params as p from dbt.cli.flags import Flags +from dbt.cli.context import DBTContext from dbt.events.functions import setup_event_logger from dbt.profiler import profiler from dbt.tracking import initialize_from_flags, track_run @@ -19,6 +20,11 @@ def cli_runner(): ls.hidden = True cli.add_command(ls, "ls") + # TODO: set context_class this on all commands outside this method + cli.context_class = DBTContext + for command in cli.commands.values(): + command.context_class = DBTContext + # Run the cli cli() @@ -57,7 +63,6 @@ def cli(ctx, **kwargs): """An ELT tool for managing your SQL transformations and data models. For more documentation on these commands, visit: docs.getdbt.com """ - ctx.obj = {} flags = Flags() # Logging # N.B. Legacy logger is not supported diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index dbff80d111a..7fe22e2c2fe 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -8,7 +8,8 @@ from contextlib import contextmanager from dbt.adapters.factory import Adapter -from dbt.cli.context import DBTContext, cli as dbt +from dbt.cli.context import DBTContext +from dbt.cli.main import cli from dbt.logger import log_manager from dbt.contracts.graph.manifest import Manifest from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars @@ -73,11 +74,9 @@ def run_dbt(args: List[str] = None, expect_pass=True): args = ["run"] print("\n\nInvoking dbt with {}".format(args)) - # ctx = DBTContext.from_args(args) - ctx = DBTContext(args) - + ctx = DBTContext(cli, args=args) # ctx.set_project('test') - res, success = dbt.invoke(ctx) + res, success = cli.invoke(ctx) if expect_pass is not None: assert success == expect_pass, "dbt exit state did not match expected"