Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DepsTask with Click #6260

Merged
merged 9 commits into from
Nov 22, 2022
9 changes: 8 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dbt.profiler import profiler
from dbt.tracking import initialize_from_flags, track_run
from dbt.config.runtime import load_project
from dbt.task.deps import DepsTask


def cli_runner():
Expand Down Expand Up @@ -228,7 +229,13 @@ def debug(ctx, **kwargs):
def deps(ctx, **kwargs):
"""Pull the most recent version of the dependencies listed in packages.yml"""
flags = Flags()
click.echo(f"`{inspect.stack()[0][3]}` called\n flags: {flags}")
project = ctx.obj["project"]

task = DepsTask.from_project(project, flags.VARS)

results = task.run()
success = task.interpret_results(results)
return results, success


# dbt init
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@
envvar=None,
help="Supply variables to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'",
type=YAML(),
default="{}",
)

version = click.option(
Expand Down
11 changes: 9 additions & 2 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
Project as ProjectContract,
SemverString,
)
from dbt.contracts.project import PackageConfig
from dbt.contracts.project import PackageConfig, ProjectPackageMetadata
from dbt.dataclass_schema import ValidationError
from .renderer import DbtProjectYamlRenderer
from .renderer import DbtProjectYamlRenderer, PackageRenderer
from .selectors import (
selector_config_from_data,
selector_data_from_root,
Expand Down Expand Up @@ -289,6 +289,13 @@ def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
exc.path = os.path.join(self.project_root, "dbt_project.yml")
raise

def render_package_metadata(self, renderer: PackageRenderer) -> ProjectPackageMetadata:
packages_data = renderer.render_data(self.packages_dict)
packages_config = package_config_from_data(packages_data)
if not self.project_name:
raise DbtProjectError(DbtProjectError("Package dbt_project.yml must have a name!"))
return ProjectPackageMetadata(self.project_name, packages_config.packages)

def check_config_path(self, project_dict, deprecated_path, exp_path):
if deprecated_path in project_dict:
if exp_path in project_dict:
Expand Down
11 changes: 7 additions & 4 deletions core/dbt/deps/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import List, Optional

from dbt.clients import git, system
from dbt.config import Project
from dbt.config.project import PartialProject, Project
from dbt.config.renderer import PackageRenderer
from dbt.contracts.project import (
ProjectPackageMetadata,
GitPackage,
Expand Down Expand Up @@ -89,7 +90,9 @@ def _checkout(self):
raise
return os.path.join(get_downloads_path(), dir_)

def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata:
def _fetch_metadata(
self, project: Project, renderer: PackageRenderer
) -> ProjectPackageMetadata:
path = self._checkout()

if self.unpinned_msg() and self.warn_unpinned:
Expand All @@ -100,8 +103,8 @@ def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata:
),
log_fmt=ui.yellow("WARNING: {}"),
)
loaded = Project.from_project_root(path, renderer)
return ProjectPackageMetadata.from_project(loaded)
partial = PartialProject.from_project_root(path)
return partial.render_package_metadata(renderer)

def install(self, project, renderer):
dest_path = self.get_installation_path(project, renderer)
Expand Down
10 changes: 7 additions & 3 deletions core/dbt/deps/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
)
from dbt.events.functions import fire_event
from dbt.events.types import DepsCreatingLocalSymlink, DepsSymlinkNotAvailable
from dbt.config.project import PartialProject, Project
from dbt.config.renderer import PackageRenderer


class LocalPackageMixin:
Expand Down Expand Up @@ -39,9 +41,11 @@ def resolve_path(self, project):
project.project_root,
)

def _fetch_metadata(self, project, renderer):
loaded = project.from_project_root(self.resolve_path(project), renderer)
return ProjectPackageMetadata.from_project(loaded)
def _fetch_metadata(
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
self, project: Project, renderer: PackageRenderer
) -> ProjectPackageMetadata:
partial = PartialProject.from_project_root(self.resolve_path(project))
return partial.render_package_metadata(renderer)

def install(self, project, renderer):
src_path = self.resolve_path(project)
Expand Down
25 changes: 14 additions & 11 deletions core/dbt/deps/resolver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
from typing import Dict, List, NoReturn, Union, Type, Iterator, Set
from typing import Dict, List, NoReturn, Union, Type, Iterator, Set, Any

from dbt.exceptions import raise_dependency_error, InternalException

from dbt.config import Project, RuntimeConfig
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.config import Project
from dbt.config.renderer import PackageRenderer
from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps.local import LocalUnpinnedPackage
from dbt.deps.git import GitUnpinnedPackage
Expand Down Expand Up @@ -94,19 +94,19 @@ def __iter__(self) -> Iterator[UnpinnedPackage]:

def _check_for_duplicate_project_names(
final_deps: List[PinnedPackage],
config: Project,
renderer: DbtProjectYamlRenderer,
project: Project,
renderer: PackageRenderer,
):
seen: Set[str] = set()
for package in final_deps:
project_name = package.get_project_name(config, renderer)
project_name = package.get_project_name(project, renderer)
if project_name in seen:
raise_dependency_error(
f'Found duplicate project "{project_name}". This occurs when '
"a dependency has the same project name as some other "
"dependency."
)
elif project_name == config.project_name:
elif project_name == project.project_name:
raise_dependency_error(
"Found a dependency with the same name as the root project "
f'"{project_name}". Package names must be unique in a project.'
Expand All @@ -116,21 +116,24 @@ def _check_for_duplicate_project_names(


def resolve_packages(
packages: List[PackageContract], config: RuntimeConfig
packages: List[PackageContract],
project: Project,
cli_vars: Dict[str, Any],
) -> List[PinnedPackage]:
pending = PackageListing.from_contracts(packages)
final = PackageListing()
renderer = DbtProjectYamlRenderer(config, config.cli_vars)

renderer = PackageRenderer(cli_vars)

while pending:
next_pending = PackageListing()
# resolve the dependency in question
for package in pending:
final.incorporate(package)
target = final[package].resolved().fetch_metadata(config, renderer)
target = final[package].resolved().fetch_metadata(project, renderer)
next_pending.update_from(target.packages)
pending = next_pending

resolved = final.resolved()
_check_for_duplicate_project_names(resolved, config, renderer)
_check_for_duplicate_project_names(resolved, project, renderer)
return resolved
20 changes: 11 additions & 9 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ def read_profiles(profiles_dir=None):
class BaseTask(metaclass=ABCMeta):
ConfigType: Union[Type[NoneConfig], Type[Project]] = NoneConfig

def __init__(self, args, config):
def __init__(self, args, config, project=None):
self.args = args
self.args.single_threaded = False
self.config = config
if hasattr(config, "args"):
self.config.args.single_threaded = False
self.project = config if isinstance(config, Project) else project

@classmethod
def pre_init_hook(cls, args):
Expand Down Expand Up @@ -140,13 +142,13 @@ def interpret_results(self, results):
return True


def get_nearest_project_dir(args):
def get_nearest_project_dir(project_dir: Optional[str]) -> str:
# If the user provides an explicit project directory, use that
# but don't look at parent directories.
if args.project_dir:
project_file = os.path.join(args.project_dir, "dbt_project.yml")
if project_dir:
project_file = os.path.join(project_dir, "dbt_project.yml")
if os.path.exists(project_file):
return args.project_dir
return project_dir
else:
raise dbt.exceptions.RuntimeException(
"fatal: Invalid --project-dir flag. Not a dbt project. "
Expand All @@ -168,8 +170,8 @@ def get_nearest_project_dir(args):
)


def move_to_nearest_project_dir(args):
nearest_project_dir = get_nearest_project_dir(args)
def move_to_nearest_project_dir(project_dir: Optional[str]) -> str:
nearest_project_dir = get_nearest_project_dir(project_dir)
os.chdir(nearest_project_dir)
return nearest_project_dir

Expand All @@ -183,7 +185,7 @@ def __init__(self, args, config):

@classmethod
def from_args(cls, args):
move_to_nearest_project_dir(args)
move_to_nearest_project_dir(args.project_dir)
return super().from_args(args)


Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def run(self):
This function takes all the paths in the target file
and cleans the project paths that are not protected.
"""
move_to_nearest_project_dir(self.args)
move_to_nearest_project_dir(self.args.project_dir)
if (
"dbt_modules" in self.config.clean_targets
and self.config.packages_install_path not in self.config.clean_targets
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, args, config):
self.profiles_dir = flags.PROFILES_DIR
self.profile_path = os.path.join(self.profiles_dir, "profiles.yml")
try:
self.project_dir = get_nearest_project_dir(self.args)
self.project_dir = get_nearest_project_dir(self.args.project_dir)
except dbt.exceptions.Exception:
# we probably couldn't find a project directory. Set project dir
# to whatever was given, or default to the current directory.
Expand Down
63 changes: 50 additions & 13 deletions core/dbt/task/deps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Dict, Any

from dbt import flags

import dbt.utils
import dbt.deprecations
import dbt.exceptions

from dbt.config import UnsetProfileConfig
from dbt.config.profile import read_user_config
from dbt.config.runtime import load_project, UnsetProfile
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.config.utils import parse_cli_vars
from dbt.deps.base import downloads_directory
from dbt.deps.resolver import resolve_packages

Expand All @@ -23,11 +29,21 @@
from dbt.task.base import BaseTask, move_to_nearest_project_dir


from dbt.config import Project
from dbt.task.base import NoneConfig


class DepsTask(BaseTask):
ChenyuLInx marked this conversation as resolved.
Show resolved Hide resolved
ConfigType = UnsetProfileConfig
ConfigType = NoneConfig

def __init__(self, args, config: UnsetProfileConfig):
super().__init__(args=args, config=config)
def __init__(
self,
args: Any,
project: Project,
cli_vars: Dict[str, Any],
):
super().__init__(args=args, config=None, project=project)
self.cli_vars = cli_vars

def track_package_install(self, package_name: str, source_type: str, version: str) -> None:
# Hub packages do not need to be hashed, as they are public
Expand All @@ -39,22 +55,22 @@ def track_package_install(self, package_name: str, source_type: str, version: st
package_name = dbt.utils.md5(package_name)
version = dbt.utils.md5(version)
dbt.tracking.track_package_install(
self.config,
self.config.args,
"deps",
self.project.hashed_name(),
{"name": package_name, "source": source_type, "version": version},
)

def run(self):
system.make_directory(self.config.packages_install_path)
packages = self.config.packages.packages
system.make_directory(self.project.packages_install_path)
packages = self.project.packages.packages
if not packages:
fire_event(DepsNoPackagesFound())
return

with downloads_directory():
final_deps = resolve_packages(packages, self.config)
final_deps = resolve_packages(packages, self.project, self.cli_vars)

renderer = DbtProjectYamlRenderer(self.config, self.config.cli_vars)
renderer = DbtProjectYamlRenderer(None, self.cli_vars)

packages_to_upgrade = []
for package in final_deps:
Expand All @@ -63,7 +79,7 @@ def run(self):
version = package.get_version()

fire_event(DepsStartPackageInstall(package_name=package_name))
package.install(self.config, renderer)
package.install(self.project, renderer)
fire_event(DepsInstallInfo(version_name=package.nice_version_name()))
if source_type == "hub":
version_latest = package.get_version_latest()
Expand All @@ -82,9 +98,30 @@ def run(self):
fire_event(EmptyLine())
fire_event(DepsNotifyUpdatesAvailable(packages=packages_to_upgrade))

@classmethod
def _get_unset_profile(cls) -> UnsetProfile:
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
profile = UnsetProfile()
# The profile (for warehouse connection) is not needed, but we want
# to get the UserConfig, which is also in profiles.yml
user_config = read_user_config(flags.PROFILES_DIR)
profile.user_config = user_config
return profile

@classmethod
def from_args(cls, args):
# deps needs to move to the project directory, as it does put files
# into the modules directory
move_to_nearest_project_dir(args)
return super().from_args(args)
nearest_project_dir = move_to_nearest_project_dir(args.project_dir)

cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, "vars", "{}"))
project_root: str = args.project_dir or nearest_project_dir
profile: UnsetProfile = cls._get_unset_profile()
project = load_project(project_root, args.version_check, profile, cli_vars)

return cls(args, project, cli_vars)

@classmethod
def from_project(cls, project: Project, cli_vars: Dict[str, Any]) -> "DepsTask":
move_to_nearest_project_dir(project.project_root)
# TODO: remove args=None once BaseTask does not require args
return cls(None, project, cli_vars)
2 changes: 1 addition & 1 deletion core/dbt/task/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def run(self):
self.create_profiles_dir(profiles_dir)

try:
move_to_nearest_project_dir(self.args)
move_to_nearest_project_dir(self.args.project_dir)
in_project = True
except dbt.exceptions.RuntimeException:
in_project = False
Expand Down
9 changes: 2 additions & 7 deletions core/dbt/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,17 +287,12 @@ def get_base_invocation_context():
}


def track_package_install(config, args, options):
def track_package_install(command_name: str, project_hashed_name: Optional[str], options):
assert active_user is not None, "Cannot track package installs when active user is None"

invocation_data = get_base_invocation_context()

invocation_data.update(
{
"project_id": None if config is None else config.hashed_name(),
"command": args.which,
}
)
invocation_data.update({"project_id": project_hashed_name, "command": command_name})

context = [
SelfDescribingJson(INVOCATION_SPEC, invocation_data),
Expand Down
Loading