Skip to content

Commit

Permalink
DBTContext
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored and ChenyuLInx committed Dec 8, 2022
1 parent 3ee0790 commit b247973
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
44 changes: 44 additions & 0 deletions core/dbt/cli/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import List
from click import Context
from click.exceptions import NoSuchOption, UsageError
from dbt.cli.main import cli
from dbt.config.project import Project

class DBTUsageException(Exception):
pass

class DBTContext(Context):
def __init__(self, args: List[str]) -> None:
try:
ctx = cli.make_context(cli.name, args)
if args:
cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args)
cmd.make_context(cmd_name, cmd_args, parent=ctx)
except (NoSuchOption, UsageError) as e:
raise DBTUsageException(e.message)

ctx.obj = {}
# yikes?
self.__dict__.update(ctx.__dict__)
# TODO: consider initializing Flags, ctx.obj here.

# @classmethod
# def from_args(cls, args: List[str]) -> "DBTContext":
# try:
# ctx = cli.make_context(cli.name, args)
# if args:
# cmd_name, cmd, cmd_args = cli.resolve_command(ctx, args)
# cmd.make_context(cmd_name, cmd_args, parent=ctx)
# except (NoSuchOption, UsageError) as e:
# raise DBTUsageException(e.message)

# ctx.obj = {}
# # yikes
# ctx.__class__ = cls
# return ctx

def set_project(self, project: Project):
if not isinstance(project, Project):
raise ValueError(f"{project} is a {type(project)}, expected a Project object.")

self.obj["project"] = project
7 changes: 6 additions & 1 deletion core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dbt.adapters.factory import Adapter

from dbt.main import handle_and_check
from dbt.cli.context import DBTContext, cli as dbt
from dbt.logger import log_manager
from dbt.contracts.graph.manifest import Manifest
from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars
Expand Down Expand Up @@ -73,7 +74,11 @@ def run_dbt(args: List[str] = None, expect_pass=True):
args = ["run"]

print("\n\nInvoking dbt with {}".format(args))
res, success = handle_and_check(args)
# ctx = DBTContext.from_args(args)
ctx = DBTContext(args)

# ctx.set_project('test')
res, success = dbt.invoke(ctx)

if expect_pass is not None:
assert success == expect_pass, "dbt exit state did not match expected"
Expand Down

0 comments on commit b247973

Please sign in to comment.