Skip to content

Commit

Permalink
dbtRunner, initialize context in subcommand
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Dec 12, 2022
1 parent 889cb06 commit e298039
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 61 deletions.
39 changes: 0 additions & 39 deletions core/dbt/cli/context.py

This file was deleted.

7 changes: 2 additions & 5 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from typing import List

if os.name != "nt":
# https://bugs.python.org/issue41567
Expand All @@ -21,9 +20,7 @@

@dataclass(frozen=True)
class Flags:
def __init__(
self, ctx: Context = None, user_config: UserConfig = None, args: List[str] = sys.argv
) -> None:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:

if ctx is None:
ctx = get_current_context()
Expand Down Expand Up @@ -53,7 +50,7 @@ def assign_params(ctx, params_assigned_from_default):
invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, args)
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
assign_params(invoked_subcommand_ctx, params_assigned_from_default)

if not user_config:
Expand Down
71 changes: 58 additions & 13 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,73 @@
from dbt.cli import params as p
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.cli.context import DBTContext
from dbt.config.project import Project
from dbt.config.profile import Profile
from dbt.config.runtime import load_project, load_profile
from dbt.events.functions import setup_event_logger
from dbt.profiler import profiler
from dbt.task.deps import DepsTask
from dbt.task.run import RunTask
from dbt.tracking import initialize_from_flags, track_run


# CLI invocation
def cli_runner():
# Alias "list" to "ls"
ls = copy(cli.commands["list"])
ls.hidden = True
cli.add_command(ls, "ls")

cli.context_class = DBTContext

# Run the cli
cli()


class dbtUsageException(Exception):
pass


# Programmatic invocation
class dbtRunner:
def __init__(self, project: Project = None, profile: Profile = None):
self.project = project
self.profile = profile

def invoke(self, args):
dbt_ctx = cli.make_context(cli.name, args)
dbt_ctx.obj = {}
dbt_ctx.obj["project"] = self.project
dbt_ctx.obj["profile"] = self.profile

try:
return cli.invoke(dbt_ctx)
except (click.NoSuchOption, click.UsageError) as e:
raise dbtUsageException(e.message)


# TODO: refactor - consider decorator, or post-init on a dbtContext
def _initialize_context(ctx):
flags = Flags()

# Tracking
initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
ctx.with_resource(track_run(run_command=flags.WHICH))

ctx.obj = ctx.obj or {}

# Profile
# TODO: generalize safe access to threads
threads = getattr(flags, "THREADS", None)
profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads)
# Project
if ctx.obj.get("project") is None:
project = load_project(flags.PROJECT_DIR, flags.VERSION_CHECK, profile, flags.VARS)
ctx.obj["project"] = project

# Context for downstream commands
ctx.obj["flags"] = flags
ctx.obj["profile"] = profile


# dbt
@click.group(
context_settings={"help_option_names": ["-h", "--help"]},
Expand Down Expand Up @@ -63,7 +110,12 @@ def cli(ctx, **kwargs):
For more documentation on these commands, visit: docs.getdbt.com
"""
# Get primatives
flags = ctx.obj["flags"]
flags = Flags()

# Version info
if flags.VERSION:
click.echo(f"`version` called\n ctx.params: {pf(ctx.params)}")
return

# Logging
# N.B. Legacy logger is not supported
Expand All @@ -74,22 +126,13 @@ def cli(ctx, **kwargs):
flags.DEBUG,
)

# Tracking
initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
ctx.with_resource(track_run(run_command=ctx.invoked_subcommand))

# Profiling
if flags.RECORD_TIMING_INFO:
ctx.with_resource(profiler(enable=True, outfile=flags.RECORD_TIMING_INFO))

# Adapter management
ctx.with_resource(adapter_management())

# Version info
if flags.VERSION:
click.echo(f"`version` called\n ctx.params: {pf(ctx.params)}")
return


# dbt build
@cli.command("build")
Expand Down Expand Up @@ -230,6 +273,7 @@ def debug(ctx, **kwargs):
@p.vars
def deps(ctx, **kwargs):
"""Pull the most recent version of the dependencies listed in packages.yml"""
_initialize_context(ctx)
flags = ctx.obj["flags"]
project = ctx.obj["project"]

Expand Down Expand Up @@ -316,6 +360,7 @@ def parse(ctx, **kwargs):
@p.version_check
def run(ctx, **kwargs):
"""Compile SQL and execute against the current target database."""
_initialize_context(ctx)

config = RuntimeConfig.from_parts(ctx.obj["project"], ctx.obj["profile"], ctx.obj["flags"])
task = RunTask(ctx.obj["flags"], config)
Expand Down
7 changes: 3 additions & 4 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from contextlib import contextmanager
from dbt.adapters.factory import Adapter

from dbt.cli.context import DBTContext
from dbt.cli.main import cli
from dbt.cli.main import dbtRunner
from dbt.logger import log_manager
from dbt.contracts.graph.manifest import Manifest
from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars
Expand Down Expand Up @@ -74,8 +73,8 @@ def run_dbt(args: List[str] = None, expect_pass=True):
args = ["run"]

print("\n\nInvoking dbt with {}".format(args))
ctx = DBTContext(cli, args=args)
res, success = cli.invoke(ctx)
dbt = dbtRunner()
res, success = dbt.invoke(args)

if expect_pass is not None:
assert success == expect_pass, "dbt exit state did not match expected"
Expand Down

0 comments on commit e298039

Please sign in to comment.