Skip to content

Commit

Permalink
DBTContext with invocation_args
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Dec 8, 2022
1 parent d37d023 commit 3efaf94
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 38 deletions.
61 changes: 29 additions & 32 deletions core/dbt/cli/context.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,41 @@
from typing import List
from click import Context
from click import Context, Group, Command
from click.exceptions import NoSuchOption, UsageError
from dbt.cli.main import cli
from dbt.cli.flags import Flags
from dbt.config.project import Project
import sys


class DBTUsageException(Exception):
pass


class DBTContext(Context):
def __init__(self, args: List[str]) -> None:
try:
ctx = cli.make_context(cli.name, args)
if args:
cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args)
cmd.make_context(cmd_name, cmd_args, parent=ctx)
def __init__(self, command: Command, **kwargs) -> None:
if isinstance(kwargs.get("parent"), DBTContext):
self.invocation_args = kwargs["parent"].invocation_args
else:
self.invocation_args = kwargs.pop("args", sys.argv[1:])

super().__init__(command, **kwargs)

# Bubble up validation errors for top-level commands
if not self.parent:
self._validate_args(command, self.invocation_args)

self.obj = self.obj or {}
self.flags = Flags(self)

def _validate_args(self, command, args) -> None:
try:
command.parse_args(self, args)
if isinstance(command, Group):
_, cmd, cmd_args = command.resolve_command(self, args)
self._validate_args(cmd, cmd_args)
except (NoSuchOption, UsageError) as e:
raise DBTUsageException(e.message)

ctx.obj = {}
# yikes?
self.__dict__.update(ctx.__dict__)
# TODO: consider initializing Flags, ctx.obj here.

# @classmethod
# def from_args(cls, args: List[str]) -> "DBTContext":
# try:
# ctx = cli.make_context(cli.name, args)
# if args:
# cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args)
# cmd.make_context(cmd_name, cmd_args, parent=ctx)
# except (NoSuchOption, UsageError) as e:
# raise DBTUsageException(e.message)

# ctx.obj = {}
# # yikes
# ctx.__class__ = cls
# return ctx
raise DBTUsageException(e.message)

def set_project(self, project: Project):
if not isinstance(project, Project):
raise ValueError(f"{project} is a {type(project)}, expected a Project object.")
self.obj["project"] = project

self.obj["project"] = project
7 changes: 6 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dbt.adapters.factory import adapter_management
from dbt.cli import params as p
from dbt.cli.flags import Flags
from dbt.cli.context import DBTContext
from dbt.events.functions import setup_event_logger
from dbt.profiler import profiler
from dbt.tracking import initialize_from_flags, track_run
Expand All @@ -19,6 +20,11 @@ def cli_runner():
ls.hidden = True
cli.add_command(ls, "ls")

# TODO: set context_class this on all commands outside this method
cli.context_class = DBTContext
for command in cli.commands.values():
command.context_class = DBTContext

# Run the cli
cli()

Expand Down Expand Up @@ -57,7 +63,6 @@ def cli(ctx, **kwargs):
"""An ELT tool for managing your SQL transformations and data models.
For more documentation on these commands, visit: docs.getdbt.com
"""
ctx.obj = {}
flags = Flags()
# Logging
# N.B. Legacy logger is not supported
Expand Down
9 changes: 4 additions & 5 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from contextlib import contextmanager
from dbt.adapters.factory import Adapter

from dbt.cli.context import DBTContext, cli as dbt
from dbt.cli.context import DBTContext
from dbt.cli.main import cli
from dbt.logger import log_manager
from dbt.contracts.graph.manifest import Manifest
from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars
Expand Down Expand Up @@ -73,11 +74,9 @@ def run_dbt(args: List[str] = None, expect_pass=True):
args = ["run"]

print("\n\nInvoking dbt with {}".format(args))
# ctx = DBTContext.from_args(args)
ctx = DBTContext(args)

ctx = DBTContext(cli, args=args)
# ctx.set_project('test')
res, success = dbt.invoke(ctx)
res, success = cli.invoke(ctx)

if expect_pass is not None:
assert success == expect_pass, "dbt exit state did not match expected"
Expand Down

0 comments on commit 3efaf94

Please sign in to comment.