Skip to content

Commit

Permalink
DBTContext with invocation_args
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Dec 8, 2022
1 parent 1e5d173 commit c0d1f73
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 37 deletions.
61 changes: 29 additions & 32 deletions core/dbt/cli/context.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,41 @@
from typing import List
from click import Context
from click import Context, Group, Command
from click.exceptions import NoSuchOption, UsageError
from dbt.cli.main import cli
from dbt.cli.flags import Flags
from dbt.config.project import Project
import sys


class DBTUsageException(Exception):
pass


class DBTContext(Context):
def __init__(self, args: List[str]) -> None:
try:
ctx = cli.make_context(cli.name, args)
if args:
cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args)
cmd.make_context(cmd_name, cmd_args, parent=ctx)
def __init__(self, command: Command, **kwargs) -> None:
if isinstance(kwargs.get("parent"), DBTContext):
self.invocation_args = kwargs["parent"].invocation_args
else:
self.invocation_args = kwargs.pop("args", sys.argv[1:])

super().__init__(command, **kwargs)

# Bubble up validation errors for top-level commands
if not self.parent:
self._validate_args(command, self.invocation_args)

self.obj = self.obj or {}
self.flags = Flags(self)

def _validate_args(self, command, args) -> None:
try:
command.parse_args(self, args)
if isinstance(command, Group):
_, cmd, cmd_args = command.resolve_command(self, args)
self._validate_args(cmd, cmd_args)
except (NoSuchOption, UsageError) as e:
raise DBTUsageException(e.message)

ctx.obj = {}
# yikes?
self.__dict__.update(ctx.__dict__)
# TODO: consider initializing Flags, ctx.obj here.

# @classmethod
# def from_args(cls, args: List[str]) -> "DBTContext":
# try:
# ctx = cli.make_context(cli.name, args)
# if args:
# cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args)
# cmd.make_context(cmd_name, cmd_args, parent=ctx)
# except (NoSuchOption, UsageError) as e:
# raise DBTUsageException(e.message)

# ctx.obj = {}
# # yikes
# ctx.__class__ = cls
# return ctx
raise DBTUsageException(e.message)

def set_project(self, project: Project):
if not isinstance(project, Project):
raise ValueError(f"{project} is a {type(project)}, expected a Project object.")
self.obj["project"] = project

self.obj["project"] = project
6 changes: 6 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile
from dbt.cli.context import DBTContext
from dbt.events.functions import setup_event_logger
from dbt.profiler import profiler
from dbt.task.deps import DepsTask
Expand All @@ -21,6 +22,11 @@ def cli_runner():
ls.hidden = True
cli.add_command(ls, "ls")

# TODO: set context_class this on all commands outside this method
cli.context_class = DBTContext
for command in cli.commands.values():
command.context_class = DBTContext

# Run the cli
cli()

Expand Down
9 changes: 4 additions & 5 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from contextlib import contextmanager
from dbt.adapters.factory import Adapter

from dbt.cli.context import DBTContext, cli as dbt
from dbt.cli.context import DBTContext
from dbt.cli.main import cli
from dbt.logger import log_manager
from dbt.contracts.graph.manifest import Manifest
from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars
Expand Down Expand Up @@ -73,11 +74,9 @@ def run_dbt(args: List[str] = None, expect_pass=True):
args = ["run"]

print("\n\nInvoking dbt with {}".format(args))
# ctx = DBTContext.from_args(args)
ctx = DBTContext(args)

ctx = DBTContext(cli, args=args)
# ctx.set_project('test')
res, success = dbt.invoke(ctx)
res, success = cli.invoke(ctx)

if expect_pass is not None:
assert success == expect_pass, "dbt exit state did not match expected"
Expand Down

0 comments on commit c0d1f73

Please sign in to comment.