Skip to content

Commit

Permalink
build self.obj in DBTContext
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Dec 12, 2022
1 parent a045ca5 commit 889cb06
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 40 deletions.
30 changes: 14 additions & 16 deletions core/dbt/cli/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from click import Context, Group, Command
from click.exceptions import NoSuchOption, UsageError
from dbt.config.runtime import load_project, load_profile
from dbt.cli.flags import Flags
from dbt.config.project import Project
import sys


Expand All @@ -11,19 +11,23 @@ class DBTUsageException(Exception):

class DBTContext(Context):
def __init__(self, command: Command, **kwargs) -> None:
if isinstance(kwargs.get("parent"), DBTContext):
self.invocation_args = kwargs["parent"].invocation_args
else:
self.invocation_args = kwargs.pop("args", sys.argv[1:])

invocation_args = kwargs.pop("args", sys.argv[1:])
super().__init__(command, **kwargs)

# Bubble up validation errors for top-level commands
if not self.parent:
self._validate_args(command, self.invocation_args)

self.obj = self.obj or {}
self.flags = Flags(self)
self._validate_args(command, invocation_args)

if not self.obj:
flags = Flags(self, args=invocation_args)
# TODO: fix flags.THREADS access
# TODO: set accept pluggable profile, project objects
profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, None) # type: ignore
project = load_project(flags.PROJECT_DIR, flags.VERSION_CHECK, profile, flags.VARS) # type: ignore
self.obj = {}
self.obj["flags"] = flags
self.obj["profile"] = profile
self.obj["project"] = project

def _validate_args(self, command, args) -> None:
try:
Expand All @@ -33,9 +37,3 @@ def _validate_args(self, command, args) -> None:
self._validate_args(cmd, cmd_args)
except (NoSuchOption, UsageError) as e:
raise DBTUsageException(e.message)

def set_project(self, project: Project):
if not isinstance(project, Project):
raise ValueError(f"{project} is a {type(project)}, expected a Project object.")

self.obj["project"] = project
7 changes: 5 additions & 2 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from typing import List

if os.name != "nt":
# https://bugs.python.org/issue41567
Expand All @@ -20,7 +21,9 @@

@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:
def __init__(
self, ctx: Context = None, user_config: UserConfig = None, args: List[str] = sys.argv
) -> None:

if ctx is None:
ctx = get_current_context()
Expand Down Expand Up @@ -50,7 +53,7 @@ def assign_params(ctx, params_assigned_from_default):
invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
invoked_subcommand_ctx = invoked_subcommand.make_context(None, args)
assign_params(invoked_subcommand_ctx, params_assigned_from_default)

if not user_config:
Expand Down
23 changes: 2 additions & 21 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dbt.cli import params as p
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile
from dbt.cli.context import DBTContext
from dbt.events.functions import setup_event_logger
from dbt.profiler import profiler
Expand All @@ -22,10 +21,7 @@ def cli_runner():
ls.hidden = True
cli.add_command(ls, "ls")

# TODO: set context_class this on all commands outside this method
cli.context_class = DBTContext
for command in cli.commands.values():
command.context_class = DBTContext

# Run the cli
cli()
Expand Down Expand Up @@ -67,7 +63,7 @@ def cli(ctx, **kwargs):
For more documentation on these commands, visit: docs.getdbt.com
"""
# Get primatives
flags = Flags()
flags = ctx.obj["flags"]

# Logging
# N.B. Legacy logger is not supported
Expand All @@ -94,21 +90,6 @@ def cli(ctx, **kwargs):
click.echo(f"`version` called\n ctx.params: {pf(ctx.params)}")
return

# Profile
# TODO: fix flags.THREADS access
profile = load_profile(
flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, None
)

# Project
project = load_project(flags.PROJECT_DIR, flags.VERSION_CHECK, profile, flags.VARS)

# Context for downstream commands
ctx.obj = {}
ctx.obj["flags"] = flags
ctx.obj["profile"] = profile
ctx.obj["project"] = project


# dbt build
@cli.command("build")
Expand Down Expand Up @@ -249,7 +230,7 @@ def debug(ctx, **kwargs):
@p.vars
def deps(ctx, **kwargs):
"""Pull the most recent version of the dependencies listed in packages.yml"""
flags = Flags()
flags = ctx.obj["flags"]
project = ctx.obj["project"]

task = DepsTask.from_project(project, flags.VARS)
Expand Down
1 change: 0 additions & 1 deletion core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def run_dbt(args: List[str] = None, expect_pass=True):

print("\n\nInvoking dbt with {}".format(args))
ctx = DBTContext(cli, args=args)
# ctx.set_project('test')
res, success = cli.invoke(ctx)

if expect_pass is not None:
Expand Down

0 comments on commit 889cb06

Please sign in to comment.