diff --git a/.gitignore b/.gitignore
index 762cc89..f2b3eaf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,3 +133,11 @@ dmypy.json

 # pytype static type analyzer
 .pytype/
+
+.run/
+
+.terraform/
+
+.user.yaml
+
+.idea/
diff --git a/README.md b/README.md
index 9a4368d..df197aa 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ There are five operators currently implemented:
   * Calls [`dbt test`](https://docs.getdbt.com/docs/test)

-Each of the above operators accept the following arguments:
+Each of the above operators accepts the arguments listed in [dbt_command_config](airflow_dbt/dbt_command_config.py). The main ones are:

 * `profiles_dir`
   * If set, passed as the `--profiles-dir` argument to the `dbt` command
@@ -96,6 +96,68 @@ Typically you will want to use the `DbtRunOperator`, followed by the `DbtTestOpe

 You can also use the hook directly. Typically this can be used when you need to combine the `dbt` command with another task in the same operator, for example running `dbt docs` and uploading the docs to somewhere they can be served from.

+## A more advanced example
+
+If you want to run your `dbt` project somewhere other than in the Airflow
+worker, you can use the `DbtCloudBuildHook` and pass it to the
+`DbtBaseOperator`, or simply use the provided `DbtCloudBuildOperator`:
+
+```python
+from airflow_dbt.hooks import DbtCloudBuildHook
+from airflow_dbt.operators import DbtBaseOperator, DbtCloudBuildOperator
+
+DbtBaseOperator(
+    task_id='provide_hook',
+    command='run',
+    use_colors=False,
+    config={
+        'profiles_dir': './jaffle-shop',
+        'project_dir': './jaffle-shop',
+    },
+    dbt_hook=DbtCloudBuildHook(
+        gcs_staging_location='gs://my-bucket/compressed-dbt-project.tar.gz'
+    )
+)
+
+DbtCloudBuildOperator(
+    task_id='default_hook_cloudbuild',
+    gcs_staging_location='gs://my-bucket/compressed-dbt-project.tar.gz',
+    command='run',
+    use_colors=False,
+    config={
+        'profiles_dir': './jaffle-shop',
+        'project_dir': './jaffle-shop',
+    },
+)
+```
+
+You can either define the dbt params/config/flags directly on the operator,
+or group them into a `config` param. Both are validated, but only the
+`config` param is templated.
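+
+The `config` param is stored in the templated field `dbt_config`, so its
+values can contain Jinja that Airflow renders at runtime. A minimal sketch,
+assuming the standard Airflow `ds` macro and the paths used above:
+
+```python
+from airflow_dbt.operators.dbt_operator import DbtBaseOperator
+
+DbtBaseOperator(
+    task_id='templated_config',
+    command='run',
+    config={
+        'profiles_dir': './jaffle-shop',
+        'project_dir': './jaffle-shop',
+        # rendered by Airflow before the dbt CLI command is generated
+        'vars': {'execution_date': '{{ ds }}'},
+    },
+)
+```
+
+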
The following two tasks are equivalent: + +```python +from airflow_dbt.operators.dbt_operator import DbtBaseOperator + +DbtBaseOperator( + task_id='config_param', + command='run', + config={ + 'profiles_dir': './jaffle-shop', + 'project_dir': './jaffle-shop', + 'dbt_bin': '/usr/local/airflow/.local/bin/dbt', + 'use_colors': False + } +) + +DbtBaseOperator( + task_id='flat_config', + command='run', + profiles_dir='./jaffle-shop', + project_dir='./jaffle-shop', + dbt_bin='/usr/local/airflow/.local/bin/dbt', + use_colors=False +) +``` + ## Building Locally To install from the repository: @@ -147,7 +209,9 @@ If you use MWAA, you just need to update the `requirements.txt` file and add `ai Then you can have your dbt code inside a folder `{DBT_FOLDER}` in the dags folder on S3 and configure the dbt task like below: ```python -dbt_run = DbtRunOperator( +from airflow_dbt.operators.dbt_operator import DbtRunOperator + +dbt_run=DbtRunOperator( task_id='dbt_run', dbt_bin='/usr/local/airflow/.local/bin/dbt', profiles_dir='/usr/local/airflow/dags/{DBT_FOLDER}/', diff --git a/airflow_dbt/__init__.py b/airflow_dbt/__init__.py index 9419f4f..e69de29 100644 --- a/airflow_dbt/__init__.py +++ b/airflow_dbt/__init__.py @@ -1,9 +0,0 @@ -from .hooks import DbtCliHook -from .operators import ( - DbtSeedOperator, - DbtSnapshotOperator, - DbtRunOperator, - DbtTestOperator, - DbtDocsGenerateOperator, - DbtDepsOperator -) diff --git a/airflow_dbt/__version__.py b/airflow_dbt/__version__.py index 9f59e6c..db55ef1 100644 --- a/airflow_dbt/__version__.py +++ b/airflow_dbt/__version__.py @@ -1,3 +1 @@ -VERSION = (0, 4, 0) - -__version__ = '.'.join(map(str, VERSION)) +__version__ = "0.5.10" diff --git a/airflow_dbt/dbt_command_config.py b/airflow_dbt/dbt_command_config.py new file mode 100644 index 0000000..9a51751 --- /dev/null +++ b/airflow_dbt/dbt_command_config.py @@ -0,0 +1,69 @@ +import sys + +# Python versions older than 3.8 have the TypedDict in a different namespace. +# In case we find ourselves in that situation, we use the `older` import +if sys.version_info[0] == 3 and sys.version_info[1] >= 8: + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class DbtCommandConfig(TypedDict, total=False): + """ + Holds the structure of a dictionary containing dbt config. Provides the + types and names for each one, and also helps shortening the constructor + since we can nest it and reuse it + """ + # global flags + version: bool + record_timing_info: bool + debug: bool + log_format: str # either 'text', 'json' or 'default' + write_json: bool + strict: bool + warn_error: bool + partial_parse: bool + use_experimental_parser: bool + use_colors: bool + no_use_colors: bool + + # per command flags + profiles_dir: str + project_dir: str + target: str + vars: dict + models: str + exclude: str + + # run specific + full_refresh: bool + profile: str + + # docs specific + no_compile: bool + + # debug specific + config_dir: str + + # ls specific + resource_type: str # models, snapshots, seeds, tests, and sources. 
+    select: str
+    models: str
+    exclude: str
+    selector: str
+    output: str
+    output_keys: str
+
+    # rpc specific
+    host: str
+    port: int
+
+    # run specific
+    fail_fast: bool
+
+    # run-operation specific
+    args: dict
+
+    # test specific
+    data: bool
+    schema: bool
diff --git a/airflow_dbt/hooks/__init__.py b/airflow_dbt/hooks/__init__.py
index 8644e88..e69de29 100644
--- a/airflow_dbt/hooks/__init__.py
+++ b/airflow_dbt/hooks/__init__.py
@@ -1 +0,0 @@
-from .dbt_hook import DbtCliHook
diff --git a/airflow_dbt/hooks/base.py b/airflow_dbt/hooks/base.py
new file mode 100644
index 0000000..66936a8
--- /dev/null
+++ b/airflow_dbt/hooks/base.py
@@ -0,0 +1,98 @@
+import json
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Union
+
+# noinspection PyDeprecation
+from airflow.hooks.base_hook import BaseHook
+
+from airflow_dbt.dbt_command_config import DbtCommandConfig
+
+
+def render_config(config: Dict[str, Union[str, bool]]) -> List[str]:
+    """Renders a dictionary of dbt options into a list of CLI strings"""
+    dbt_command_config_annotations = DbtCommandConfig.__annotations__
+    command_params = []
+    for key, value in config.items():
+        if key not in dbt_command_config_annotations:
+            raise ValueError(f"{key} is not a valid key")
+        if value is not None:
+            param_value_type = type(value)
+            # check that the value has the type declared in DbtCommandConfig
+            if param_value_type != dbt_command_config_annotations[key]:
+                raise TypeError(
+                    f"{key} has to be of type "
+                    f"{dbt_command_config_annotations[key]}"
+                )
+            # boolean flags render as `--flag`, or `--no-flag` when False
+            flag_prefix = ''
+            if param_value_type is bool and not value:
+                flag_prefix = 'no-'
+            cli_param_from_kwarg = "--" + flag_prefix + key.replace("_", "-")
+            command_params.append(cli_param_from_kwarg)
+            # non-boolean values are appended right after their flag
+            if param_value_type is str:
+                command_params.append(value)
+            elif param_value_type is int:
+                command_params.append(str(value))
+            elif param_value_type is dict:
+                command_params.append(json.dumps(value))
+    return command_params
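+
+# For illustration, render_config maps a config dict to CLI fragments
+# (a sketch of the behaviour implemented above, not part of the public API):
+#   render_config({'full_refresh': True, 'target': 'dev'})
+#     -> ['--full-refresh', '--target', 'dev']
+#   render_config({'use_colors': False})
+#     -> ['--no-use-colors']
+#   render_config({'vars': {'foo': 'bar'}})
+#     -> ['--vars', '{"foo": "bar"}']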
+
+
+def generate_dbt_cli_command(
+    dbt_bin: str,
+    command: str,
+    base_config: Dict[str, Union[str, bool]],
+    command_config: Dict[str, Union[str, bool]],
+) -> List[str]:
+    """
+    Creates a CLI command list from the keys in the config dictionaries.
+    Keys whose value is None are ignored. Boolean keys render as `--key`
+    flags (`--no-key` when False). String, integer and dict values render
+    as the key prefixed with two dashes, followed by the value.
+    dbt_bin and command are mandatory.
+
+    Available params are:
+    :param command_config: Specific params for the command
+    :type command_config: dict
+    :param base_config: Params that apply to the `dbt` program regardless of
+        the command it is running
+    :type base_config: dict
+    :param dbt_bin: Path to the dbt binary. Defaults to `dbt`, which assumes
+        it is available in the PATH.
+    :type dbt_bin: str
+    :param command: The dbt sub-command to run, for example for `dbt run`
+        the command will be `run`. Any flag not covered by the config
+        dictionaries can also be appended to this string
+    :type command: str
+    """
+    if not dbt_bin:
+        raise ValueError("dbt_bin is mandatory")
+    if not command:
+        raise ValueError("command is mandatory")
+    base_params = render_config(base_config)
+    command_params = render_config(command_config)
+    # commands like 'dbt docs generate' need the command to be split in two
+    command_pieces = command.split(" ")
+    return [dbt_bin, *base_params, *command_pieces, *command_params]
+
+
+class DbtBaseHook(BaseHook, ABC):
+    """
+    Base abstract class for all DbtHooks, providing a common interface and
+    forcing the implementation of the mandatory `run_dbt()` function.
+    """
+
+    def __init__(self, env: Optional[Dict] = None):
+        """
+        :param env: If set, it will be passed as environment variables to
+            the dbt execution
+        :type env: dict
+        """
+        super().__init__()
+        self.env = env or {}
+
+    @abstractmethod
+    def run_dbt(self, dbt_cmd: Union[str, List[str]]):
+        """Run the dbt command"""
+        pass
diff --git a/airflow_dbt/hooks/cli.py b/airflow_dbt/hooks/cli.py
new file mode 100644
index 0000000..93c5ce6
--- /dev/null
+++ b/airflow_dbt/hooks/cli.py
@@ -0,0 +1,51 @@
+from __future__ import print_function
+
+from typing import Any, Dict, List, Optional, Union
+
+from airflow import AirflowException
+from airflow.hooks.subprocess import SubprocessHook
+
+from airflow_dbt.hooks.base import DbtBaseHook
+
+
+class DbtCliHook(DbtBaseHook):
+    """
+    Runs the dbt command in the same Airflow worker the task runs on. This
+    requires the `dbt` python package to be installed in the worker first.
+    """
+
+    def __init__(self, env: Optional[Dict] = None):
+        """
+        :param env: Environment variables that will be passed to the
+            subprocess. Must be a dictionary of key-values
+        :type env: dict
+        """
+        self.sp = SubprocessHook()
+        super().__init__(env=env)
+
+    def get_conn(self) -> Any:
+        """
+        Return the subprocess connection, which isn't implemented, just for
+        conformity
+        """
+        return self.sp.get_conn()
+
+    def run_dbt(self, dbt_cmd: Union[str, List[str]]):
+        """
+        Run the dbt cli
+
+        :param dbt_cmd: The whole dbt command to run
+        :type dbt_cmd: List[str]
+        """
+        result = self.sp.run_command(
+            command=dbt_cmd,
+            env=self.env,
+        )
+
+        if result.exit_code != 0:
+            raise AirflowException(
+                f'Error executing the dbt command: {result.output}'
+            )
+
+    def on_kill(self):
+        """Kill the open subprocess if the task gets killed by Airflow"""
+        self.sp.send_sigterm()
diff --git a/airflow_dbt/hooks/dbt_hook.py b/airflow_dbt/hooks/dbt_hook.py
deleted file mode 100644
index 0b4caf0..0000000
--- a/airflow_dbt/hooks/dbt_hook.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from __future__ import print_function
-import os
-import signal
-import subprocess
-import json
-from airflow.exceptions import AirflowException
-from airflow.hooks.base_hook import BaseHook
-
-
-class DbtCliHook(BaseHook):
-    """
-    Simple wrapper around the dbt CLI.
-
-    :param profiles_dir: If set, passed as the `--profiles-dir` argument to the `dbt` command
-    :type profiles_dir: str
-    :param target: If set, passed as the `--target` argument to the `dbt` command
-    :type dir: str
-    :param dir: The directory to run the CLI in
-    :type vars: str
-    :param vars: If set, passed as the `--vars` argument to the `dbt` command
-    :type vars: dict
-    :param full_refresh: If `True`, will fully-refresh incremental models.
-    :type full_refresh: bool
-    :param models: If set, passed as the `--models` argument to the `dbt` command
-    :type models: str
-    :param warn_error: If `True`, treat warnings as errors.
- :type warn_error: bool - :param exclude: If set, passed as the `--exclude` argument to the `dbt` command - :type exclude: str - :param select: If set, passed as the `--select` argument to the `dbt` command - :type select: str - :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your `PATH` - :type dbt_bin: str - :param output_encoding: Output encoding of bash command. Defaults to utf-8 - :type output_encoding: str - :param verbose: The operator will log verbosely to the Airflow logs - :type verbose: bool - """ - - def __init__(self, - profiles_dir=None, - target=None, - dir='.', - vars=None, - full_refresh=False, - data=False, - schema=False, - models=None, - exclude=None, - select=None, - dbt_bin='dbt', - output_encoding='utf-8', - verbose=True, - warn_error=False): - self.profiles_dir = profiles_dir - self.dir = dir - self.target = target - self.vars = vars - self.full_refresh = full_refresh - self.data = data - self.schema = schema - self.models = models - self.exclude = exclude - self.select = select - self.dbt_bin = dbt_bin - self.verbose = verbose - self.warn_error = warn_error - self.output_encoding = output_encoding - - def _dump_vars(self): - # The dbt `vars` parameter is defined using YAML. Unfortunately the standard YAML library - # for Python isn't very good and I couldn't find an easy way to have it formatted - # correctly. However, as YAML is a super-set of JSON, this works just fine. - return json.dumps(self.vars) - - def run_cli(self, *command): - """ - Run the dbt cli - - :param command: The dbt command to run - :type command: str - """ - - dbt_cmd = [self.dbt_bin, *command] - - if self.profiles_dir is not None: - dbt_cmd.extend(['--profiles-dir', self.profiles_dir]) - - if self.target is not None: - dbt_cmd.extend(['--target', self.target]) - - if self.vars is not None: - dbt_cmd.extend(['--vars', self._dump_vars()]) - - if self.data: - dbt_cmd.extend(['--data']) - - if self.schema: - dbt_cmd.extend(['--schema']) - - if self.models is not None: - dbt_cmd.extend(['--models', self.models]) - - if self.exclude is not None: - dbt_cmd.extend(['--exclude', self.exclude]) - - if self.select is not None: - dbt_cmd.extend(['--select', self.select]) - - if self.full_refresh: - dbt_cmd.extend(['--full-refresh']) - - if self.warn_error: - dbt_cmd.insert(1, '--warn-error') - - if self.verbose: - self.log.info(" ".join(dbt_cmd)) - - sp = subprocess.Popen( - dbt_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=self.dir, - close_fds=True) - self.sp = sp - self.log.info("Output:") - line = '' - for line in iter(sp.stdout.readline, b''): - line = line.decode(self.output_encoding).rstrip() - self.log.info(line) - sp.wait() - self.log.info( - "Command exited with return code %s", - sp.returncode - ) - - if sp.returncode: - raise AirflowException("dbt command failed") - - def on_kill(self): - self.log.info('Sending SIGTERM signal to dbt command') - os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM) diff --git a/airflow_dbt/hooks/google.py b/airflow_dbt/hooks/google.py new file mode 100644 index 0000000..305127e --- /dev/null +++ b/airflow_dbt/hooks/google.py @@ -0,0 +1,189 @@ +"""Provides hooks and helper functions to allow running dbt in GCP""" + +import logging +import traceback +from pathlib import Path +from typing import Any, Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.cloud_build import ( + CloudBuildHook, +) +from airflow.providers.google.cloud.hooks.gcs import GCSHook, 
_parse_gcs_url
+from airflow.utils.yaml import dump
+from google.cloud.devtools.cloudbuild_v1 import (
+    Build,
+)
+
+from airflow_dbt.hooks.base import DbtBaseHook
+
+
+class DbtCloudBuildHook(DbtBaseHook):
+    """
+    Connects to GCP Cloud Build, creates a build config, submits it and waits
+    for results.
+    """
+
+    def __init__(
+        self,
+        project_id: Optional[str] = None,
+        gcs_staging_location: str = None,
+        gcp_conn_id: str = None,
+        env: Optional[Dict] = None,
+        service_account: Optional[str] = None,
+        dbt_version: str = 'latest',
+        dbt_image: str = 'ghcr.io/dbt-labs/dbt-bigquery',
+        dbt_project_dir: str = None,
+        dbt_artifacts_dest: str = None,
+    ):
+        """
+        Runs the dbt command in a Cloud Build job in GCP
+
+        :param dbt_artifacts_dest: GCS destination folder for the artifacts.
+            For example `gs://my-bucket/path/to/artifacts/`
+        :type dbt_artifacts_dest: str
+        :param env: If set, passed as environment variables to the dbt step
+        :type env: dict
+        :param project_id: GCP Project ID as stated in the console
+        :type project_id: str
+        :param gcp_conn_id: The connection ID to use when fetching connection
+            info.
+        :type gcp_conn_id: str
+        :param gcs_staging_location: GCS URL of the compressed file containing
+            the dbt project sources, to be fetched later by the Cloud Build
+            job. For example: `gs://my-bucket/dbt-project.tar.gz`
+        :type gcs_staging_location: str
+        :param dbt_version: The dbt image tag to pull. Defaults to 'latest'.
+            If you provide a custom `dbt_image`, it must also be a valid tag
+            for that image
+        :type dbt_version: str
+        :param dbt_image: The dbt container image to run. Defaults to
+            `ghcr.io/dbt-labs/dbt-bigquery`
+        :type dbt_image: str
+        :param dbt_project_dir: Path to the dbt project inside the uploaded
+            sources, used to locate the generated artifacts
+        :type dbt_project_dir: str
+        :param service_account: Email of the service account to run the build
+            as. If set, it must be accompanied by the project_id
+        :type service_account: str
+        """
+        staging_bucket, staging_blob = _parse_gcs_url(gcs_staging_location)
+        # we expect something like 'gs://<bucket>/<path/to/file.tar.gz>'
+        if Path(staging_blob).suffix not in ['.gz', '.gzip', '.zip']:
+            raise AirflowException(
+                f'The provided blob "{staging_blob}" to a compressed file '
+                'does not have a supported extension: ".gz", ".gzip" or '
+                '".zip"'
+            )
+        # gcp config
+        self.gcs_staging_bucket = staging_bucket
+        self.gcs_staging_blob = staging_blob
+        if gcp_conn_id is None:
+            self.cloud_build_hook = CloudBuildHook()
+        else:
+            self.cloud_build_hook = CloudBuildHook(gcp_conn_id=gcp_conn_id)
+        self.gcp_conn_id = gcp_conn_id
+        self.project_id = project_id or self.cloud_build_hook.project_id
+        self.service_account = service_account
+        # dbt config
+        self.dbt_version = dbt_version
+        self.dbt_image = dbt_image
+        self.dbt_project_dir = dbt_project_dir
+        self.dbt_artifacts_dest = dbt_artifacts_dest
+
+        super().__init__(env=env)
+
+    def get_conn(self) -> Any:
+        """Returns the Cloud Build connection, which is a GCP connection"""
+        return self.cloud_build_hook.get_conn()
+
+    def _get_cloud_build_config(self, dbt_cmd: List[str]) -> Dict:
+        cloud_build_config = {
+            'steps': [{
+                'name': f'{self.dbt_image}:{self.dbt_version}',
+                'entrypoint': dbt_cmd[0],
+                'args': dbt_cmd[1:],
+                'env': [f'{k}={v}' for k, v in self.env.items()],
+            }],
+            'source': {
+                'storage_source': {
+                    "bucket": self.gcs_staging_bucket,
+                    "object_": self.gcs_staging_blob,
+                }
+            },
+            'options': {
+                # default is legacy and its behaviour is subject to change
+                'logging': 'GCS_ONLY',
+            },
+            'logs_bucket': self.gcs_staging_bucket,
+        }
+
+        if self.service_account:
+            cloud_build_config['service_account'] = (
+                f'projects/{self.project_id}/serviceAccounts/'
+                f'{self.service_account}'
+            )
+
+        if self.dbt_artifacts_dest:
+            # ensure the destination ends with a single trailing slash,
+            # as it should if it's a folder
+            gcs_dest_url = self.dbt_artifacts_dest.rstrip('/') + '/'
+            artifacts_step = {
+                'name': 'gcr.io/cloud-builders/gsutil',
+                'args': [
+                    '-m', 'cp', '-r',
+                    f'{self.dbt_project_dir}/target/**',
+                    gcs_dest_url
+                ]
+            }
+            cloud_build_config['steps'].append(artifacts_step)
+
+        return cloud_build_config
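+
+    # For illustration: for `dbt run` with the default image and version the
+    # method above renders roughly (a sketch, values depend on the instance):
+    #   {'steps': [{'name': 'ghcr.io/dbt-labs/dbt-bigquery:latest',
+    #               'entrypoint': 'dbt', 'args': ['run'], 'env': []}],
+    #    'source': {'storage_source': {'bucket': ..., 'object_': ...}},
+    #    'options': {'logging': 'GCS_ONLY'},
+    #    'logs_bucket': ...}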
+
+    def run_dbt(self, dbt_cmd: List[str]):
+        """
+        Runs the dbt command as a Cloud Build job
+
+        :param dbt_cmd: The whole dbt command to run
+        :type dbt_cmd: List[str]
+        """
+        # See: https://cloud.google.com/cloud-build/docs/api/reference/rest/v1/projects.builds
+        cloud_build_config = self._get_cloud_build_config(dbt_cmd)
+        logging.info(
+            f'Running the following cloud build'
+            f' config:\n{dump(cloud_build_config)}'
+        )
+
+        try:
+            self.log.info("Creating build")
+            result_build: Build = self.cloud_build_hook.create_build(
+                cloud_build_config
+            )
+            self.log.info(
+                f"Build has been created: {result_build.id}.\n"
+                f'Build logs available at: {result_build.log_url} and the '
+                f'file gs://{result_build.logs_bucket}/log-'
+                f'{result_build.id}.txt'
+            )
+            self.build_id = result_build.id
+            # print logs from GCS
+            with GCSHook().provide_file(
+                bucket_name=result_build.logs_bucket,
+                object_name=f'log-{result_build.id}.txt',
+            ) as log_file_handle:
+                clean_lines = [
+                    line.decode('utf-8').strip()
+                    for line in log_file_handle if line
+                ]
+                log_block = '\n'.join(clean_lines)
+                hr = '-' * 80
+                logging.info(
+                    f'Logs from the build {result_build.id}:\n'
+                    f'{hr}\n'
+                    f'{log_block}\n'
+                    f'{hr}'
+                )
+            return result_build
+        except Exception as ex:
+            traceback.print_exc()
+            raise AirflowException(f"Exception running the build: {ex}")
+
+    def on_kill(self):
+        """Cancels the running build if the task gets killed by Airflow"""
+        self.cloud_build_hook.cancel_build(self.build_id)
diff --git a/airflow_dbt/operators/__init__.py b/airflow_dbt/operators/__init__.py
index 295d53b..e69de29 100644
--- a/airflow_dbt/operators/__init__.py
+++ b/airflow_dbt/operators/__init__.py
@@ -1,8 +0,0 @@
-from .dbt_operator import (
-    DbtSeedOperator,
-    DbtSnapshotOperator,
-    DbtRunOperator,
-    DbtTestOperator,
-    DbtDocsGenerateOperator,
-    DbtDepsOperator
-)
diff --git a/airflow_dbt/operators/dbt_operator.py b/airflow_dbt/operators/dbt_operator.py
index 6233d8d..84f85e1 100644
--- a/airflow_dbt/operators/dbt_operator.py
+++ b/airflow_dbt/operators/dbt_operator.py
@@ -1,144 +1,378 @@
-from airflow_dbt.hooks.dbt_hook import DbtCliHook
+import logging
+import warnings
+from typing import Any, Dict, List, Optional
+
 from airflow.models import BaseOperator
+# noinspection PyDeprecation
 from airflow.utils.decorators import apply_defaults

+from airflow_dbt.dbt_command_config import DbtCommandConfig
+from airflow_dbt.hooks.base import DbtBaseHook, generate_dbt_cli_command
+from airflow_dbt.hooks.cli import DbtCliHook
+

 class DbtBaseOperator(BaseOperator):
     """
-    Base dbt operator
-    All other dbt operators are derived from this operator.
-
-    :param profiles_dir: If set, passed as the `--profiles-dir` argument to the `dbt` command
-    :type profiles_dir: str
-    :param target: If set, passed as the `--target` argument to the `dbt` command
-    :type dir: str
-    :param dir: The directory to run the CLI in
-    :type vars: str
-    :param vars: If set, passed as the `--vars` argument to the `dbt` command
-    :type vars: dict
-    :param full_refresh: If `True`, will fully-refresh incremental models.
- :type full_refresh: bool - :param models: If set, passed as the `--models` argument to the `dbt` command - :type models: str - :param warn_error: If `True`, treat warnings as errors. - :type warn_error: bool - :param exclude: If set, passed as the `--exclude` argument to the `dbt` command - :type exclude: str - :param select: If set, passed as the `--select` argument to the `dbt` command - :type select: str - :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your `PATH` - :type dbt_bin: str - :param verbose: The operator will log verbosely to the Airflow logs - :type verbose: bool + Base dbt operator. All other dbt operators should inherit from this one. + + It receives all possible dbt options in the constructor. If no hook is + provided it uses the DbtCliHook to run the generated command. """ ui_color = '#d6522a' + ui_fgcolor = "white" + # add all the str/dict params to the templates + template_fields = ['dbt_env', 'dbt_bin', 'dbt_command', 'dbt_config'] + template_fields_renderers = { + 'dbt_env': 'json', + 'dbt_config': 'json', + } - template_fields = ['vars'] + dbt_env: Dict + dbt_bin: str + dbt_command: str + dbt_config: DbtCommandConfig + dbt_hook: DbtBaseHook + dbt_cli_command: List[str] + # noinspection PyShadowingBuiltins, PyDeprecation @apply_defaults - def __init__(self, - profiles_dir=None, - target=None, - dir='.', - vars=None, - models=None, - exclude=None, - select=None, - dbt_bin='dbt', - verbose=True, - warn_error=False, - full_refresh=False, - data=False, - schema=False, - *args, - **kwargs): - super(DbtBaseOperator, self).__init__(*args, **kwargs) - - self.profiles_dir = profiles_dir - self.target = target - self.dir = dir - self.vars = vars - self.models = models - self.full_refresh = full_refresh - self.data = data - self.schema = schema - self.exclude = exclude - self.select = select + def __init__( + self, + env: Dict = None, + dbt_bin: Optional[str] = 'dbt', + dbt_hook=None, + command: Optional[str] = None, + config: DbtCommandConfig = None, + # dir deprecated in favor of dbt native project and profile directories + dir: str = None, + # if config was not provided we un-flatten them from the kwargs + # global flags + version: bool = None, + record_timing_info: bool = None, + debug: bool = None, + log_format: str = None, # either 'text', 'json' or 'default' + write_json: bool = None, + strict: bool = None, + warn_error: bool = None, + partial_parse: bool = None, + use_experimental_parser: bool = None, + use_colors: bool = None, + # command specific config + profiles_dir: str = None, + project_dir: str = None, + profile: str = None, + target: str = None, + config_dir: str = None, + resource_type: str = None, + vars: Dict = None, + # run specific + full_refresh: bool = None, + # ls specific + data: bool = None, + schema: bool = None, + models: str = None, + exclude: str = None, + select: str = None, + selector: str = None, + output: str = None, + output_keys: str = None, + # rpc specific + host: str = None, + port: str = None, + # test specific + fail_fast: bool = None, + args: dict = None, + no_compile: bool = None, + + *vargs, + **kwargs + ): + """ + :param env: Dictionary with environment variables to be used in the + runtime + :type env: dict + :param dbt_bin: Path to the dbt binary, defaults to `dbt` assumes it is + available in the PATH. + :type dbt_bin: str + :param dbt_hook: The dbt hook to use as executor. For now the + implemented ones are: DbtCliHook, DbtCloudBuildHook. 
It should be an
+        instance of one of those, or of another class that inherits from
+        DbtBaseHook. If not provided, a DbtCliHook will be instantiated by
+        default with the provided params
+    :type dbt_hook: DbtBaseHook
+    :param command: The dbt sub-command to run, for example for `dbt run`
+        the command will be `run`. Any flag not covered by the config can
+        also be appended to this string
+    :type command: str
+    :param config: TypedDict grouping all the dbt-specific flags, keeping
+        them apart from the params meant for the Airflow operator itself
+    :type config: DbtCommandConfig
+    :param dir: Legacy param to set the dbt project directory
+    :type dir: str
+    :param version: If `True`, adds the dbt `--version` flag
+    :type version: bool
+    :param record_timing_info: Dbt flag to add '--record-timing-info'
+    :type record_timing_info: bool
+    :param debug: Dbt flag to add '--debug'
+    :type debug: bool
+    :param log_format: Specifies how dbt's logs should be formatted. The
+        value for this flag can be one of: text, json, or default
+    :type log_format: str
+    :param write_json: If `False`, adds the `--no-write-json` dbt flag
+    :type write_json: bool
+    :param strict: Only for use during dbt development. It performs extra
+        validation of dbt objects and internal consistency checks during
+        compilation
+    :type strict: bool
+    :param warn_error: Converts dbt warnings into errors
+    :type warn_error: bool
+    :param partial_parse: Configures partial parsing in your project, and
+        will override the value set in `profiles.yml`
+    :type partial_parse: bool
+    :param use_experimental_parser: Statically analyze model files in your
+        project and, if possible, extract needed information 3x faster than
+        a full Jinja render
+    :type use_experimental_parser: bool
+    :param use_colors: Displays colors in dbt logs
+    :type use_colors: bool
+    :param profiles_dir: Path to the `profiles.yml` dir. Can be relative to
+        the folder the DAG is run from, which is usually the home or the
+        DAGs folder
+    :type profiles_dir: str
+    :param project_dir: Path to the dbt project you want to run. Can be
+        relative to the path the DAG is run from
+    :type project_dir: str
+    :param profile: Which profile to load. Overrides setting in
+        dbt_project.yml
+    :type profile: str
+    :param target: Which target to load for the given profile
+    :type target: str
+    :param config_dir: Same as profiles_dir
+    :type config_dir: str
+    :param resource_type: One of: model, snapshot, source, analysis, seed,
+        exposure, test, default, all
+    :type resource_type: str
+    :param vars: Supply variables to the project. This argument overrides
+        variables defined in your dbt_project.yml file. This argument should
+        be a YAML string, eg. '{my_variable: my_value}'
+    :type vars: dict
+    :param full_refresh: If specified, dbt will drop incremental models and
+        fully-recalculate the incremental table from the model definition
+    :type full_refresh: bool
+    :param data: Run data tests defined in the "tests" directory.
+    :type data: bool
+    :param schema: Run constraint validations from schema.yml files
+    :type schema: bool
+    :param models: Flag used to choose a node or subset of nodes to apply
+        the command to (v0.20.0 and lower)
+    :type models: str
+    :param exclude: Nodes to exclude from the set defined with
+        select/models
+    :type exclude: str
+    :param select: Flag used to choose a node or subset of nodes to apply
+        the command to (v0.21.0 and higher)
+    :type select: str
+    :param selector: Config param to reference complex selects defined in
+        the config yaml
+    :type selector: str
+    :param output: Output format, one of: json, name, path, selector
+    :type output: str
+    :param output_keys: Which keys to output
+    :type output_keys: str
+    :param host: Specify the host to listen on for the rpc server
+    :type host: str
+    :param port: Specify the port number for the rpc server
+    :type port: int
+    :param fail_fast: Stop execution upon a first test failure
+    :type fail_fast: bool
+    :param args: Dictionary of arguments to supply to the invoked
+        `run-operation` macro
+    :type args: dict
+    :param no_compile: Do not run "dbt compile" as part of docs generation
+    :type no_compile: bool
+    :param vargs: rest of the positional args
+    :param kwargs: rest of the keyword args
+
+    """
+        super(DbtBaseOperator, self).__init__(*vargs, **kwargs)
+
+        if dir is not None:
+            warnings.warn(
+                '"dir" param is deprecated in favor of dbt native '
+                'param "project_dir"', PendingDeprecationWarning
+            )
+            if project_dir is None:
+                logging.warning('Using "dir" as "project_dir"')
+                project_dir = dir
+
+        self.dbt_env = env or {}
         self.dbt_bin = dbt_bin
-        self.verbose = verbose
-        self.warn_error = warn_error
-        self.create_hook()
-
-    def create_hook(self):
-        self.hook = DbtCliHook(
-            profiles_dir=self.profiles_dir,
-            target=self.target,
-            dir=self.dir,
-            vars=self.vars,
-            full_refresh=self.full_refresh,
-            data=self.data,
-            schema=self.schema,
-            models=self.models,
-            exclude=self.exclude,
-            select=self.select,
-            dbt_bin=self.dbt_bin,
-            verbose=self.verbose,
-            warn_error=self.warn_error)
+        self.dbt_command = command
+        # defaults to an empty dict
+        config = config or {}
+        # collect the dbt flags that were passed as top-level kwargs
+        kwarg_config = {
+            # global flags
+            'version': version,
+            'record_timing_info': record_timing_info,
+            'debug': debug,
+            'log_format': log_format,
+            'write_json': write_json,
+            'strict': strict,
+            'warn_error': warn_error,
+            'partial_parse': partial_parse,
+            'use_experimental_parser': use_experimental_parser,
+            'use_colors': use_colors,
+            # per command flags
+            'profiles_dir': profiles_dir,
+            'project_dir': project_dir,
+            'target': target,
+            'vars': vars,
+            # run specific
+            'full_refresh': full_refresh,
+            'profile': profile,
+            # docs specific
+            'no_compile': no_compile,
+            # debug specific
+            'config_dir': config_dir,
+            # ls specific
+            'resource_type': resource_type,
+            'select': select,
+            'models': models,
+            'exclude': exclude,
+            'selector': selector,
+            'output': output,
+            'output_keys': output_keys,
+            # rpc specific
+            'host': host,
+            'port': port,
+            # run specific
+            'fail_fast': fail_fast,
+            # run-operation specific
+            'args': args,
+            # test specific
+            'data': data,
+            'schema': schema,
+        }
+        # filter out the None values first, so unset kwargs do not clobber
+        # keys already provided in `config`; the remaining top-level kwargs
+        # then override the `config` ones
+        kwarg_config = {
+            key: val
+            for key, val in kwarg_config.items()
+            if val is not None
+        }
+        config.update(kwarg_config)
+        self.dbt_config = config
+        self.dbt_hook = dbt_hook

-        return self.hook
+    def instantiate_hook(self):
+        """
+        Instantiates the underlying dbt hook. This has to be deferred until
+        after the constructor, or the templated params won't be interpolated.
+ """ + dbt_hook = self.dbt_hook + self.dbt_hook = dbt_hook if dbt_hook is not None else DbtCliHook( + env=self.dbt_env, + ) + + def execute(self, context: Any): + """Runs the provided command in the provided execution environment""" + self.instantiate_hook() + dbt_base_params = [ + 'log_format', 'version', 'use_colors', 'warn_error', + 'partial_parse', 'use_experimental_parser', 'profiles_dir' + ] + + dbt_base_config = { + key: val + for key, val in self.dbt_config.items() + if key in dbt_base_params + } + + dbt_command_config = { + key: val + for key, val in self.dbt_config.items() + if key not in dbt_base_params + } + + self.dbt_cli_command = generate_dbt_cli_command( + dbt_bin=self.dbt_bin, + command=self.dbt_command, + base_config=dbt_base_config, + command_config=dbt_command_config, + ) + self.dbt_hook.run_dbt(self.dbt_cli_command) class DbtRunOperator(DbtBaseOperator): - @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtRunOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) + """Runs a dbt run command""" - def execute(self, context): - self.create_hook().run_cli('run') + # noinspection PyDeprecation + @apply_defaults + def __init__(self, *args, **kwargs): + super().__init__(*args, command='run', **kwargs) class DbtTestOperator(DbtBaseOperator): - @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtTestOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) + """Runs a dbt test command""" - def execute(self, context): - self.create_hook().run_cli('test') + # noinspection PyDeprecation + @apply_defaults + def __init__(self, *args, **kwargs): + super().__init__(*args, command='test', **kwargs) class DbtDocsGenerateOperator(DbtBaseOperator): - @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtDocsGenerateOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, - **kwargs) + """Runs a dbt docs generate command""" - def execute(self, context): - self.create_hook().run_cli('docs', 'generate') + # noinspection PyDeprecation + @apply_defaults + def __init__(self, *args, **kwargs): + super().__init__(*args, command='docs generate', **kwargs) class DbtSnapshotOperator(DbtBaseOperator): - @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtSnapshotOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) + """Runs a dbt snapshot command""" - def execute(self, context): - self.create_hook().run_cli('snapshot') + # noinspection PyDeprecation + @apply_defaults + def __init__(self, *args, **kwargs): + super().__init__(*args, command='snapshot', **kwargs) class DbtSeedOperator(DbtBaseOperator): - @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtSeedOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) + """Runs a dbt seed command""" - def execute(self, context): - self.create_hook().run_cli('seed') + # noinspection PyDeprecation + @apply_defaults + def __init__(self, *args, **kwargs): + super().__init__(*args, command='seed', **kwargs) class DbtDepsOperator(DbtBaseOperator): + """Runs a dbt deps command""" + + # noinspection PyDeprecation @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtDepsOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) + def 
__init__(self, *args, **kwargs):
+        super().__init__(*args, command='deps', **kwargs)
-
-    def execute(self, context):
-        self.create_hook().run_cli('deps')
+
+
+class DbtCleanOperator(DbtBaseOperator):
+    """Runs a dbt clean command"""
+
+    # noinspection PyDeprecation
+    @apply_defaults
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, command='clean', **kwargs)
+
+
+class DbtDebugOperator(DbtBaseOperator):
+    """Runs a dbt debug command"""
+
+    # noinspection PyDeprecation
+    @apply_defaults
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, command='debug', **kwargs)
diff --git a/airflow_dbt/operators/google.py b/airflow_dbt/operators/google.py
new file mode 100644
index 0000000..dd83921
--- /dev/null
+++ b/airflow_dbt/operators/google.py
@@ -0,0 +1,58 @@
+from airflow.utils.decorators import apply_defaults
+
+from airflow_dbt.hooks.google import DbtCloudBuildHook
+from airflow_dbt.operators.dbt_operator import DbtBaseOperator
+
+
+class DbtCloudBuildOperator(DbtBaseOperator):
+    """Uses the Cloud Build hook to run the provided dbt config"""
+
+    template_fields = DbtBaseOperator.template_fields + [
+        'gcs_staging_location', 'project_id', 'dbt_version',
+        'service_account', 'dbt_artifacts_dest'
+    ]
+
+    # noinspection PyDeprecation
+    @apply_defaults
+    def __init__(
+        self,
+        gcs_staging_location: str,
+        project_id: str = None,
+        gcp_conn_id: str = "google_cloud_default",
+        dbt_version: str = '1.3.latest',
+        dbt_image: str = 'ghcr.io/dbt-labs/dbt-bigquery',
+        dbt_artifacts_dest: str = None,
+        service_account: str = None,
+        *args,
+        **kwargs
+    ):
+        self.dbt_artifacts_dest = dbt_artifacts_dest
+        self.gcs_staging_location = gcs_staging_location
+        self.gcp_conn_id = gcp_conn_id
+        self.project_id = project_id
+        self.dbt_version = dbt_version
+        self.dbt_image = dbt_image
+        self.service_account = service_account
+
+        super(DbtCloudBuildOperator, self).__init__(
+            *args,
+            **kwargs
+        )
+
+    def instantiate_hook(self):
+        """
+        Instantiates a Cloud Build dbt hook. This has to be done outside the
+        constructor because, by the time the constructor runs, the params
+        have not yet been interpolated.
+ """ + self.dbt_hook = DbtCloudBuildHook( + env=self.dbt_env, + gcs_staging_location=self.gcs_staging_location, + gcp_conn_id=self.gcp_conn_id, + dbt_version=self.dbt_version, + dbt_image=self.dbt_image, + service_account=self.service_account, + project_id=self.project_id, + dbt_project_dir=self.dbt_config.get('project_dir'), + dbt_artifacts_dest=self.dbt_artifacts_dest, + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4ed21dc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "airflow-dbt-dinigo" +dynamic = ["version"] +description = "Apache Airflow integration for dbt" +readme = "README.md" +license = "MIT" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: MIT License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.7", +] +dependencies = [ + "apache-airflow >= 1.10.3", +] +packages = [ + { include = "airflow_dbt" }, +] + +[project.optional-dependencies] +google = [ + "apache-airflow-providers-google", + "google-cloud-build", +] + +[tool.hatch.version] +path = "airflow_dbt/__version__.py" diff --git a/setup.py b/setup.py deleted file mode 100644 index 7d3b2f8..0000000 --- a/setup.py +++ /dev/null @@ -1,81 +0,0 @@ -import io -import os -import sys -from shutil import rmtree -from setuptools import setup, find_packages, Command - -here = os.path.abspath(os.path.dirname(__file__)) - -# Load the package's __version__.py module as a dictionary. -about = {} -with open(os.path.join(here, 'airflow_dbt', '__version__.py')) as f: - exec(f.read(), about) - -with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: - long_description = f.read() - - -class UploadCommand(Command): - """Support setup.py upload.""" - - description = 'Build and publish the package.' - user_options = [] - - @staticmethod - def status(s): - """Prints things in bold.""" - print('\033[1m{0}\033[0m'.format(s)) - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - try: - self.status('Removing previous builds…') - rmtree(os.path.join(here, 'dist')) - except OSError: - pass - - self.status('Building Source and Wheel (universal) distribution…') - os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) - - self.status('Uploading the package to PyPI via Twine…') - os.system('twine upload dist/*') - - self.status('Pushing git tags…') - os.system('git tag v{0}'.format(about['__version__'])) - os.system('git push --tags') - - sys.exit() - - -setup( - name='airflow_dbt', - version=about['__version__'], - packages=find_packages(exclude=['tests']), - install_requires=['apache-airflow >= 1.10.3'], - author='GoCardless', - author_email='engineering@gocardless.com', - description='Apache Airflow integration for dbt', - long_description=long_description, - long_description_content_type='text/markdown', - license='MIT', - url='https://github.com/gocardless/airflow-dbt', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - - 'License :: OSI Approved :: MIT License', - - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: POSIX :: Linux', - - 'Programming Language :: Python :: 3.7', - ], - # $ setup.py upload support. 
- cmdclass={ - 'upload': UploadCommand, - }, -) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/hooks/test_dbt_hook.py b/tests/hooks/test_dbt_hook.py index 4dd39ed..400ef10 100644 --- a/tests/hooks/test_dbt_hook.py +++ b/tests/hooks/test_dbt_hook.py @@ -1,54 +1,331 @@ -from unittest import TestCase -from unittest import mock -import subprocess -from airflow_dbt.hooks.dbt_hook import DbtCliHook +from typing import Union +from unittest import TestCase, mock +from unittest.mock import patch +import pytest +from airflow import AirflowException +from airflow.hooks.subprocess import SubprocessHook, SubprocessResult +from pytest import mark -class TestDbtHook(TestCase): +from airflow_dbt.dbt_command_config import DbtCommandConfig +from airflow_dbt.hooks.base import generate_dbt_cli_command +from airflow_dbt.hooks.cli import DbtCliHook +from airflow_dbt.hooks.google import ( + DbtCloudBuildHook, + check_google_provider_version, +) - @mock.patch('subprocess.Popen') - def test_sub_commands(self, mock_subproc_popen): - mock_subproc_popen.return_value \ - .communicate.return_value = ('output', 'error') - mock_subproc_popen.return_value.returncode = 0 - mock_subproc_popen.return_value \ - .stdout.readline.side_effect = [b"placeholder"] +cli_command_from_params_data = [ + [("dbt", "run", {}, ["dbt", "run"]), "regular dbt run"], + # check it runs with empty params + [("dbt", None, {}, ValueError()), "it fails with no command"], + [(None, "run", {}, ValueError()), "it fails with no dbt_bin"], + [("dbt", "test", {}, ["dbt", "test"]), "test without params"], + [ + ("dbt", "test", {'non_existing_param'}, TypeError()), + "invalid param raises TypeError" + ], + # test invalid param + [ + ("dbt", "test", {'--models': None}, ValueError()), + "required --models value raises ValueError if not provided" + ], + # test mandatory value + [ + ("dbt", "test", {'--models': 3}, ValueError()), + "required --models value raises ValueError if not correct type" + ], + [ + ("/bin/dbt", "test", {}, ["/bin/dbt", "test"]), + "dbt_bin other than the default gets passed through" + ], + [ + ("dbt", "run", {'full_refresh': False}, ValueError()), + "flags param fails if contains False value" + ], + # test flags always positive + [('/home/airflow/.local/bin/dbt', 'run', { + 'full_refresh': True, + 'profiles_dir': '/opt/airflow/dags/dbt_project', + 'project_dir': '/opt/airflow/dags/project_dir', + 'vars': {'execution_date': '2021-01-01'}, + 'select': 'my_model', + }, ['/home/airflow/.local/bin/dbt', 'run', '--full-refresh', + '--profiles-dir', '/opt/airflow/dags/dbt_project', + '--project-dir', '/opt/airflow/dags/project_dir', + '--vars', '{"execution_date": "2021-01-01"}', '--select', + 'my_model']), + "fully fledged dbt run with all types of params" + ], + # test all the params + [ + ("dbt", "test", {'profiles_dir': '/path/profiles_folder'}, + ["dbt", "test", "--profiles-dir", "/path/profiles_folder"]), + "test profiles_dir param" + ], + [ + ("dbt", "run", {'project_dir': '/path/dbt_project_dir'}, + ["dbt", "run", "--project-dir", "/path/dbt_project_dir"]), + "test project_dir param" + ], + [ + ("dbt", "test", {'target': 'model_target'}, + ["dbt", "test", "--target", "model_target"]), + "test target param" + ], + [ + ("dbt", "test", {'vars': {"hello": "world"}}, + ["dbt", "test", "--vars", '{"hello": "world"}']), + "test vars param" + ], + [ + ("dbt", "run", {'models': 'my_model'}, + ["dbt", "run", "--models", "my_model"]), + "test models param" + ], + [ + ("dbt", 
"run", {'exclude': 'my_model'}, + ["dbt", "run", "--exclude", "my_model"]), + "test exclude param" + ], - hook = DbtCliHook() - hook.run_cli('docs', 'generate') - - mock_subproc_popen.assert_called_once_with( - [ - 'dbt', - 'docs', - 'generate' - ], - close_fds=True, - cwd='.', - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT - ) + # run specific params + [ + ("dbt", "run", {'full_refresh': True}, + ["dbt", "run", "--full-refresh"]), + "[dbt run] test full_refresh flag succeeds" + ], + [ + ("dbt", "run", {'full_refresh': 3}, TypeError()), + "[dbt run] test full_refresh param fails if not bool but integer" + ], + [ + ("dbt", "run", {'full_refresh': 'hello'}, TypeError()), + "[dbt run] test full_refresh project_dir fails if not bool but string" + ], + [ + ("dbt", "run", {'profile': 'test_profile'}, + ["dbt", "run", "--profile", "test_profile"]), + "[dbt run] test profile param" + ], + + # docs specific params + [ + ("dbt", "docs", {'no_compile': True}, + ["dbt", "docs", "--no-compile"]), + "test no_compile flag succeeds" + ], + # debug specific params + [ + ("dbt", "debug", {'config_dir': '/path/to/config_dir'}, + ["dbt", "debug", "--config-dir", '/path/to/config_dir']), + "[dbt debug] test config_dir param" + ], + + # ls specific params + [ + ("dbt", "ls", {'resource_type': '/path/to/config_dir'}, + ["dbt", "ls", "--resource-type", '/path/to/config_dir']), + "[dbt ls] test resource_type param" + ], + [ + ("dbt", "ls", {'select': 'my_model'}, + ["dbt", "ls", "--select", "my_model"]), + "[dbt ls] test select param" + ], + [ + ("dbt", "ls", {'exclude': 'my_model'}, + ["dbt", "ls", "--exclude", "my_model"]), + "[dbt ls] test exclude param" + ], + [ + ("dbt", "ls", {'output': 'my_model'}, + ["dbt", "ls", "--output", "my_model"]), + "[dbt ls] test output param" + ], + [ + ("dbt", "ls", {'output_keys': 'my_model'}, + ["dbt", "ls", "--output-keys", "my_model"]), + "[dbt ls] test output_keys param" + ], + + # rpc specific params + [ + ("dbt", "rpc", {'host': 'http://my-host-url.com'}, + ["dbt", "rpc", "--host", 'http://my-host-url.com']), + "[dbt rpc] test host param" + ], + [ + ("dbt", "rpc", {'port': '8080'}, TypeError()), + "[dbt rpc] test port param fails if not integer" + ], + [ + ("dbt", "rpc", {'port': 8080}, ["dbt", "rpc", "--port", '8080']), + "[dbt rpc] test port param" + ], + + # run specific params + [ + ("dbt", "run", {'fail_fast': True}, ["dbt", "run", "--fail-fast"]), + "[dbt run] test fail_fast flag succeeds" + ], + + # test specific params + [ + ("dbt", "test", {'data': True}, ["dbt", "test", '--data']), + "[dbt test] test data flag succeeds" + ], + [ + ("dbt", "test", {'schema': True}, ["dbt", "test", '--schema']), + "[dbt test] test schema flag succeeds" + ], + ] + + +@mark.parametrize( + ["dbt_bin", "command", "params", "expected_command"], + [test_params[0] for test_params in cli_command_from_params_data], + ids=[test_params[1] for test_params in cli_command_from_params_data] +) +def test_create_cli_command_from_params( + dbt_bin: str, + command: str, + params: DbtCommandConfig, + expected_command: Union[list[str], Exception] +): + """ + Test that the function create_cli_command_from_params returns the + correct + command or raises the correct exception + :type expected_command: object + """ + if isinstance(expected_command, Exception): + with pytest.raises(expected_command.__class__): + generate_dbt_cli_command(dbt_bin, command, **params) + else: + assert generate_dbt_cli_command(dbt_bin, command, **params) \ + == expected_command + + +class TestDbtCliHook(TestCase): + 
@mock.patch.object( + SubprocessHook, + 'run_command', + return_value=SubprocessResult(exit_code=0, output='all good') + ) + def test_sub_commands(self, mock_run_command): + """ + Test that sub commands are called with the right params + """ + hook = DbtCliHook(env={'GOOGLE_APPLICATION_CREDENTIALS': 'my_creds'}) + hook.run_dbt(['dbt', 'docs', 'generate']) + mock_run_command.assert_called_once_with( + command=['dbt', 'docs', 'generate'], + env={'GOOGLE_APPLICATION_CREDENTIALS': 'my_creds'} + ) + + @mock.patch.object( + SubprocessHook, + 'run_command', + return_value=SubprocessResult(exit_code=1, output='some error') + ) + def test_run_dbt(self, mock_run_command): + """ + Patch SubProcessHook to return a non-0 exit code and check we raise + an exception for such a result + """ - @mock.patch('subprocess.Popen') - def test_vars(self, mock_subproc_popen): - mock_subproc_popen.return_value \ - .communicate.return_value = ('output', 'error') - mock_subproc_popen.return_value.returncode = 0 - mock_subproc_popen.return_value \ - .stdout.readline.side_effect = [b"placeholder"] - - hook = DbtCliHook(vars={"foo": "bar", "baz": "true"}) - hook.run_cli('run') - - mock_subproc_popen.assert_called_once_with( - [ - 'dbt', - 'run', - '--vars', - '{"foo": "bar", "baz": "true"}' - ], - close_fds=True, - cwd='.', - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT + with pytest.raises(AirflowException): + hook = DbtCliHook(env={'GOOGLE_APPLICATION_CREDENTIALS': 'my_creds'}) + hook.run_dbt(['dbt', 'run']) + mock_run_command.assert_called_once_with( + command=['dbt', 'run'], + env={'GOOGLE_APPLICATION_CREDENTIALS': 'my_creds'} ) + + @mock.patch.object(SubprocessHook, 'get_conn') + def test_subprocess_kill_called(self, mock_get_conn): + hook = DbtCliHook() + hook.get_conn() + mock_get_conn.assert_called_once() + + @mock.patch.object(SubprocessHook, 'send_sigterm') + def test_subprocess_get_conn_called(self, mock_send_sigterm): + hook = DbtCliHook() + hook.on_kill() + mock_send_sigterm.assert_called_once() + + +class TestDbtCloudBuildHook(TestCase): + @patch('airflow_dbt.hooks.google.CloudBuildHook') + @patch('airflow_dbt.hooks.google.GCSHook') + def test_create_build(self, _, MockCloudBuildHook): + mock_create_build = MockCloudBuildHook().create_build + mock_create_build.return_value = { + 'id': 'test_id', 'logUrl': 'http://testurl.com' + } + hook = DbtCloudBuildHook( + project_id='test_project_id', + gcs_staging_location='gs://hello/file.tar.gz', + dbt_version='0.10.10', + env={'TEST_ENV_VAR': 'test'}, + service_account='robot@mail.com' + ) + hook.run_dbt(['docs', 'generate']) + + expected_body = { + 'steps': [{ + 'name': 'fishtownanalytics/dbt:0.10.10', + 'args': ['docs', 'generate'], + 'env': ['TEST_ENV_VAR=test'] + }], + 'source': { + 'storageSource': { + 'bucket': 'hello', + 'object': 'file.tar.gz', + } + }, + 'serviceAccount': 'projects/test_project_id/serviceAccounts/robot@mail.com', + 'options': { + 'logging': 'GCS_ONLY', + + }, + 'logsBucket': 'hello', + } + + mock_create_build.assert_called_once_with( + body=expected_body, + project_id='test_project_id' + ) + + +@pytest.mark.parametrize( + ['min_version', 'max_version', 'versions', 'expected_result'], + [ + ('5.0.0', '6.0.0', ['5.0.0', '4.0.0'], None), + ('5.0.0', '6.0.0', ['4.0.0', '3.0.0'], Exception), + ('5.0.0', '6.0.0', ['6.0.0', '5.0.0'], Exception), + ], + ids=[ + 'provider version within min and max allowed versions', + 'provider version below min allowed versions', + 'provider version above max allowed versions', + ] +) 
+@patch('airflow_dbt.hooks.google.get_provider_info') +def test_check_google_provider_version( + mock_get_provider_info, + min_version, + max_version, + versions, + expected_result +): + mock_get_provider_info.return_value = {'versions': versions} + if expected_result is None: + check_google_provider_version( + min_version, + max_version + ) + else: + with pytest.raises(expected_result): + check_google_provider_version(min_version, max_version) diff --git a/tests/operators/test_dbt_operator.py b/tests/operators/test_dbt_operator.py index 8ce2c5f..52d5a0d 100644 --- a/tests/operators/test_dbt_operator.py +++ b/tests/operators/test_dbt_operator.py @@ -1,66 +1,110 @@ import datetime -from unittest import TestCase, mock +from typing import Union +from unittest.mock import MagicMock, patch + +import pytest from airflow import DAG, configuration -from airflow_dbt.hooks.dbt_hook import DbtCliHook +from pytest import fixture, mark + +from airflow_dbt.hooks.cli import DbtCliHook +from airflow_dbt.hooks.google import DbtCloudBuildHook from airflow_dbt.operators.dbt_operator import ( + DbtBaseOperator, + DbtCleanOperator, + DbtDepsOperator, + DbtDocsGenerateOperator, + DbtRunOperator, DbtSeedOperator, DbtSnapshotOperator, - DbtRunOperator, DbtTestOperator, - DbtDepsOperator ) +from airflow_dbt.operators.google import DbtCloudBuildOperator -class TestDbtOperator(TestCase): - def setUp(self): - configuration.conf.load_test_config() - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2020, 2, 27) - } - self.dag = DAG('test_dag_id', default_args=args) +@fixture +def airflow_dag(): + """Instantiates an Airflow DAG to be used as a test fixture""" + configuration.conf.load_test_config() + args = { + 'owner': 'airflow', + 'start_date': datetime.datetime(2020, 2, 27) + } + yield DAG('test_dag_id', default_args=args) - @mock.patch.object(DbtCliHook, 'run_cli') - def test_dbt_run(self, mock_run_cli): - operator = DbtRunOperator( - task_id='run', - dag=self.dag - ) - operator.execute(None) - mock_run_cli.assert_called_once_with('run') - @mock.patch.object(DbtCliHook, 'run_cli') - def test_dbt_test(self, mock_run_cli): - operator = DbtTestOperator( - task_id='test', - dag=self.dag - ) +@mark.parametrize( + ['Operator', 'expected_command'], [ + (DbtBaseOperator, ValueError()), + (DbtDepsOperator, ['dbt', 'deps']), + (DbtRunOperator, ['dbt', 'run']), + (DbtSeedOperator, ['dbt', 'seed']), + (DbtDocsGenerateOperator, ['dbt', 'docs generate']), + (DbtSnapshotOperator, ['dbt', 'snapshot']), + (DbtCleanOperator, ['dbt', 'clean']), + (DbtTestOperator, ['dbt', 'test']), + ] +) +@patch.object(DbtCliHook, 'run_dbt') +def test_basic_dbt_operators( + mock_run_dbt: MagicMock, + Operator: DbtBaseOperator, + expected_command: Union[list[str], Exception], + airflow_dag: DAG, +): + """ + Test that all the basic Dbt{Command}Operators instantiate the right + default dbt command. 
And that the basic DbtBaseOperator raises a
+    ValueError, since it has no base command defined to be executed
+    """
+    # noinspection PyCallingNonCallable
+    operator = Operator(
+        task_id=f'{Operator.__name__}',
+        dag=airflow_dag
+    )
+    if isinstance(expected_command, Exception):
+        with pytest.raises(expected_command.__class__):
+            operator.execute(None)
+    else:
         operator.execute(None)
-        mock_run_cli.assert_called_once_with('test')
+        mock_run_dbt.assert_called_once_with(expected_command)

-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_snapshot(self, mock_run_cli):
-        operator = DbtSnapshotOperator(
-            task_id='snapshot',
-            dag=self.dag
-        )
-        operator.execute(None)
-        mock_run_cli.assert_called_once_with('snapshot')

-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_seed(self, mock_run_cli):
-        operator = DbtSeedOperator(
-            task_id='seed',
-            dag=self.dag
+
+def test_dbt_warns_about_dir_param(airflow_dag: DAG):
+    """
+    Test that the DbtBaseOperator warns about the use of the dir parameter
+    """
+    with pytest.warns(PendingDeprecationWarning):
+        DbtBaseOperator(
+            task_id='test_task_id',
+            dag=airflow_dag,
+            dir='/tmp/dbt'
         )
-        operator.execute(None)
-        mock_run_cli.assert_called_once_with('seed')

-    @mock.patch.object(DbtCliHook, 'run_cli')
-    def test_dbt_deps(self, mock_run_cli):
-        operator = DbtDepsOperator(
-            task_id='deps',
-            dag=self.dag
-        )
-        operator.execute(None)
-        mock_run_cli.assert_called_once_with('deps')
+
+@patch.object(DbtCloudBuildHook, '__init__', return_value=None)
+def test_cloud_build_operator_instantiates_hook(
+    cloud_build_hook_constructor: MagicMock,
+    airflow_dag: DAG
+):
+    operator = DbtCloudBuildOperator(
+        task_id='test_cloud_build',
+        gcs_staging_location='gs://my_bucket/dbt_proj.tar.gz',
+        env={'CONFIG_VAR': 'HELLO'},
+        config={'project_dir': 'not used'},
+        project_id='my_project',
+        gcp_conn_id='test_gcp_conn',
+        dbt_version='0.19.2',
+        service_account='dbt-sa@google.com',
+        dag=airflow_dag
+    )
+    operator.instantiate_hook()
+
+    cloud_build_hook_constructor.assert_called_once_with(
+        env={'CONFIG_VAR': 'HELLO'},
+        gcs_staging_location='gs://my_bucket/dbt_proj.tar.gz',
+        project_id='my_project',
+        gcp_conn_id='test_gcp_conn',
+        dbt_version='0.19.2',
+        service_account='dbt-sa@google.com'
+    )