From e298039773bcf93ed2e61b57c472cbe629bb171a Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 12 Dec 2022 17:12:37 -0500 Subject: [PATCH] dbtRunner, initialize context in subcommand --- core/dbt/cli/context.py | 39 ---------------------- core/dbt/cli/flags.py | 7 ++-- core/dbt/cli/main.py | 71 +++++++++++++++++++++++++++++++++-------- core/dbt/tests/util.py | 7 ++-- 4 files changed, 63 insertions(+), 61 deletions(-) delete mode 100644 core/dbt/cli/context.py diff --git a/core/dbt/cli/context.py b/core/dbt/cli/context.py deleted file mode 100644 index a9f373f85da..00000000000 --- a/core/dbt/cli/context.py +++ /dev/null @@ -1,39 +0,0 @@ -from click import Context, Group, Command -from click.exceptions import NoSuchOption, UsageError -from dbt.config.runtime import load_project, load_profile -from dbt.cli.flags import Flags -import sys - - -class DBTUsageException(Exception): - pass - - -class DBTContext(Context): - def __init__(self, command: Command, **kwargs) -> None: - invocation_args = kwargs.pop("args", sys.argv[1:]) - super().__init__(command, **kwargs) - - # Bubble up validation errors for top-level commands - if not self.parent: - self._validate_args(command, invocation_args) - - if not self.obj: - flags = Flags(self, args=invocation_args) - # TODO: fix flags.THREADS access - # TODO: set accept pluggable profile, project objects - profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, None) # type: ignore - project = load_project(flags.PROJECT_DIR, flags.VERSION_CHECK, profile, flags.VARS) # type: ignore - self.obj = {} - self.obj["flags"] = flags - self.obj["profile"] = profile - self.obj["project"] = project - - def _validate_args(self, command, args) -> None: - try: - command.parse_args(self, args) - if isinstance(command, Group): - _, cmd, cmd_args = command.resolve_command(self, args) - self._validate_args(cmd, cmd_args) - except (NoSuchOption, UsageError) as e: - raise DBTUsageException(e.message) diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index b206f181c6d..a8af62bf61d 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -12,7 +12,6 @@ from dbt.config.profile import read_user_config from dbt.contracts.project import UserConfig -from typing import List if os.name != "nt": # https://bugs.python.org/issue41567 @@ -21,9 +20,7 @@ @dataclass(frozen=True) class Flags: - def __init__( - self, ctx: Context = None, user_config: UserConfig = None, args: List[str] = sys.argv - ) -> None: + def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None: if ctx is None: ctx = get_current_context() @@ -53,7 +50,7 @@ def assign_params(ctx, params_assigned_from_default): invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name) invoked_subcommand.allow_extra_args = True invoked_subcommand.ignore_unknown_options = True - invoked_subcommand_ctx = invoked_subcommand.make_context(None, args) + invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv) assign_params(invoked_subcommand_ctx, params_assigned_from_default) if not user_config: diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index f9f354f3ee1..e67a9262577 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -7,7 +7,9 @@ from dbt.cli import params as p from dbt.cli.flags import Flags from dbt.config import RuntimeConfig -from dbt.cli.context import DBTContext +from dbt.config.project import Project +from dbt.config.profile import Profile +from dbt.config.runtime import load_project, load_profile from dbt.events.functions import setup_event_logger from dbt.profiler import profiler from dbt.task.deps import DepsTask @@ -15,18 +17,63 @@ from dbt.tracking import initialize_from_flags, track_run +# CLI invocation def cli_runner(): # Alias "list" to "ls" ls = copy(cli.commands["list"]) ls.hidden = True cli.add_command(ls, "ls") - cli.context_class = DBTContext - # Run the cli cli() +class dbtUsageException(Exception): + pass + + +# Programmatic invocation +class dbtRunner: + def __init__(self, project: Project = None, profile: Profile = None): + self.project = project + self.profile = profile + + def invoke(self, args): + dbt_ctx = cli.make_context(cli.name, args) + dbt_ctx.obj = {} + dbt_ctx.obj["project"] = self.project + dbt_ctx.obj["profile"] = self.profile + + try: + return cli.invoke(dbt_ctx) + except (click.NoSuchOption, click.UsageError) as e: + raise dbtUsageException(e.message) + + +# TODO: refactor - consider decorator, or post-init on a dbtContext +def _initialize_context(ctx): + flags = Flags() + + # Tracking + initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR) + ctx.with_resource(track_run(run_command=flags.WHICH)) + + ctx.obj = ctx.obj or {} + + # Profile + # TODO: generalize safe access to threads + threads = getattr(flags, "THREADS", None) + profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads) + # Project + if ctx.obj.get("project") is None: + project = load_project(flags.PROJECT_DIR, flags.VERSION_CHECK, profile, flags.VARS) + ctx.obj["project"] = project + + # Context for downstream commands + ctx.obj["flags"] = flags + ctx.obj["profile"] = profile + + # dbt @click.group( context_settings={"help_option_names": ["-h", "--help"]}, @@ -63,7 +110,12 @@ def cli(ctx, **kwargs): For more documentation on these commands, visit: docs.getdbt.com """ # Get primatives - flags = ctx.obj["flags"] + flags = Flags() + + # Version info + if flags.VERSION: + click.echo(f"`version` called\n ctx.params: {pf(ctx.params)}") + return # Logging # N.B. Legacy logger is not supported @@ -74,10 +126,6 @@ def cli(ctx, **kwargs): flags.DEBUG, ) - # Tracking - initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR) - ctx.with_resource(track_run(run_command=ctx.invoked_subcommand)) - # Profiling if flags.RECORD_TIMING_INFO: ctx.with_resource(profiler(enable=True, outfile=flags.RECORD_TIMING_INFO)) @@ -85,11 +133,6 @@ def cli(ctx, **kwargs): # Adapter management ctx.with_resource(adapter_management()) - # Version info - if flags.VERSION: - click.echo(f"`version` called\n ctx.params: {pf(ctx.params)}") - return - # dbt build @cli.command("build") @@ -230,6 +273,7 @@ def debug(ctx, **kwargs): @p.vars def deps(ctx, **kwargs): """Pull the most recent version of the dependencies listed in packages.yml""" + _initialize_context(ctx) flags = ctx.obj["flags"] project = ctx.obj["project"] @@ -316,6 +360,7 @@ def parse(ctx, **kwargs): @p.version_check def run(ctx, **kwargs): """Compile SQL and execute against the current target database.""" + _initialize_context(ctx) config = RuntimeConfig.from_parts(ctx.obj["project"], ctx.obj["profile"], ctx.obj["flags"]) task = RunTask(ctx.obj["flags"], config) diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index 2d8c51124d4..6cdc4ee5b77 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -8,8 +8,7 @@ from contextlib import contextmanager from dbt.adapters.factory import Adapter -from dbt.cli.context import DBTContext -from dbt.cli.main import cli +from dbt.cli.main import dbtRunner from dbt.logger import log_manager from dbt.contracts.graph.manifest import Manifest from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars @@ -74,8 +73,8 @@ def run_dbt(args: List[str] = None, expect_pass=True): args = ["run"] print("\n\nInvoking dbt with {}".format(args)) - ctx = DBTContext(cli, args=args) - res, success = cli.invoke(ctx) + dbt = dbtRunner() + res, success = dbt.invoke(args) if expect_pass is not None: assert success == expect_pass, "dbt exit state did not match expected"