diff --git a/airflow_dbt/dbt_command_params.py b/airflow_dbt/dbt_command_params.py new file mode 100644 index 0000000..9b4b5e9 --- /dev/null +++ b/airflow_dbt/dbt_command_params.py @@ -0,0 +1,71 @@ +# Python versions older than 3.8 have the TypedDict in a different namespace. +# In case we find ourselves in that situation, we use the `older` import +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + + +class DbtGlobalParamsConfig(TypedDict, total=False): + """ + Holds the structure of a dictionary containing dbt config. Provides the + types and names for each one, and also helps shortening the constructor + since we can nest it and reuse it + """ + record_timing_info: bool + debug: bool + log_format: str # either 'text', 'json' or 'default' + write_json: bool + warn_error: bool + partial_parse: bool + use_experimental_parser: bool + use_colors: bool + verbose: bool + no_use_colors: bool + + +class DbtCommandParamsConfig(TypedDict, total=False): + """ + Holds the structure of a dictionary containing dbt config. Provides the + types and names for each one, and also helps shortening the constructor + since we can nest it and reuse it + """ + profiles_dir: str + project_dir: str + target: str + vars: dict + models: str + exclude: str + + # run specific + full_refresh: bool + profile: str + + # docs specific + no_compile: bool + + # debug specific + config_dir: str + + # ls specific + resource_type: str # models, snapshots, seeds, tests, and sources. + select: str + models: str + exclude: str + selector: str + output: str + output_keys: str + + # rpc specific + host: str + port: int + + # run specific + fail_fast: bool + + # run-operation specific + args: dict + + # test specific + data: bool + schema: bool diff --git a/airflow_dbt/hooks/dbt_hook.py b/airflow_dbt/hooks/dbt_hook.py index a16e53f..3fb91fc 100644 --- a/airflow_dbt/hooks/dbt_hook.py +++ b/airflow_dbt/hooks/dbt_hook.py @@ -1,139 +1,131 @@ from __future__ import print_function + +import json import os import signal import subprocess -import json +from typing import Any, Dict, List, Union + from airflow.exceptions import AirflowException -from airflow.hooks.base_hook import BaseHook +from airflow.hooks.base import BaseHook + +from airflow_dbt.dbt_command_params import DbtCommandParamsConfig + + +def render_config(config: Dict[str, Union[str, bool]]) -> List[str]: + """Renders a dictionary of options into a list of cli strings""" + dbt_command_config_annotations = DbtCommandParamsConfig.__annotations__ + command_params = [] + for key, value in config.items(): + if key not in dbt_command_config_annotations: + raise ValueError(f"{key} is not a valid key") + if value is not None: + param_value_type = type(value) + # check that the value has the correct type from dbt_command_config_annotations + if param_value_type != dbt_command_config_annotations[key]: + raise TypeError(f"{key} has to be of type {dbt_command_config_annotations[key]}") + # if the param is not bool it must have a non-null value + flag_prefix = '' + if param_value_type is bool and not value: + flag_prefix = 'no-' + cli_param_from_kwarg = "--" + flag_prefix + key.replace("_", "-") + command_params.append(cli_param_from_kwarg) + if param_value_type is str: + command_params.append(value) + elif param_value_type is int: + command_params.append(str(value)) + elif param_value_type is dict: + command_params.append(json.dumps(value)) + return command_params + + +def generate_dbt_cli_command( + dbt_bin: str, + command: str, + global_config: Dict[str, Union[str, bool]], + command_config: Dict[str, Union[str, bool]], +) -> List[str]: + """ + Creates a CLI string from the keys in the dictionary. If the key is none + it is ignored. If the key is of type boolean the name of the key is added. + If the key is of type string it adds the the key prefixed with tow dashes. + If the key is of type integer it adds the the key prefixed with three + dashes. + dbt_bin and command are mandatory. + Boolean flags must always be positive. + Available params are: + :param command_config: Specific params for the commands + :type command_config: dict + :param global_config: Params that apply to the `dbt` program regardless of + the command it is running + :type global_config: dict + :param command: The dbt sub-command to run + :type command: str + :param dbt_bin: Path to the dbt binary, defaults to `dbt` assumes it is + available in the PATH. + :type dbt_bin: str + :param command: The dbt sub command to run, for example for `dbt run` + the base_command will be `run`. If any other flag not contemplated + must be included it can also be added to this string + :type command: str + """ + if not dbt_bin: + raise ValueError("dbt_bin is mandatory") + if not command: + raise ValueError("command mandatory") + base_params = render_config(global_config) + command_params = render_config(command_config) + # commands like 'dbt docs generate' need the command to be split in two + command_pieces = command.split(" ") + return [dbt_bin, *base_params, *command_pieces, *command_params] class DbtCliHook(BaseHook): """ Simple wrapper around the dbt CLI. - :param env: If set, passes the env variables to the subprocess handler + :param env: Environment variables to pass to the dbt process :type env: dict - :param profiles_dir: If set, passed as the `--profiles-dir` argument to the `dbt` command - :type profiles_dir: str - :param target: If set, passed as the `--target` argument to the `dbt` command - :type dir: str - :param dir: The directory to run the CLI in - :type vars: str - :param vars: If set, passed as the `--vars` argument to the `dbt` command - :type vars: dict - :param full_refresh: If `True`, will fully-refresh incremental models. - :type full_refresh: bool - :param models: If set, passed as the `--models` argument to the `dbt` command - :type models: str - :param warn_error: If `True`, treat warnings as errors. - :type warn_error: bool - :param exclude: If set, passed as the `--exclude` argument to the `dbt` command - :type exclude: str - :param select: If set, passed as the `--select` argument to the `dbt` command - :type select: str - :param selector: If set, passed as the `--selector` argument to the `dbt` command - :type selector: str - :param dbt_bin: The `dbt` CLI. Defaults to `dbt`, so assumes it's on your `PATH` + :param dbt_bin: Path to the dbt binary :type dbt_bin: str - :param output_encoding: Output encoding of bash command. Defaults to utf-8 + :param global_flags: Global flags to pass to the dbt process + :type global_flags: dict + :param command_flags: Command flags to pass to the dbt process + :type command_flags: dict + :param command: The dbt command to run + :type command: str + :param output_encoding: The encoding of the output :type output_encoding: str - :param verbose: The operator will log verbosely to the Airflow logs - :type verbose: bool """ - def __init__(self, - env=None, - profiles_dir=None, - target=None, - dir='.', - vars=None, - full_refresh=False, - data=False, - schema=False, - models=None, - exclude=None, - select=None, - selector=None, - dbt_bin='dbt', - output_encoding='utf-8', - verbose=True, - warn_error=False): + def __init__( + self, + env: dict = None, + output_encoding: str = 'utf-8', + ): + super().__init__() self.env = env or {} - self.profiles_dir = profiles_dir - self.dir = dir - self.target = target - self.vars = vars - self.full_refresh = full_refresh - self.data = data - self.schema = schema - self.models = models - self.exclude = exclude - self.select = select - self.selector = selector - self.dbt_bin = dbt_bin - self.verbose = verbose - self.warn_error = warn_error self.output_encoding = output_encoding + self.sp = None # declare the terminal to be user later on - def _dump_vars(self): - # The dbt `vars` parameter is defined using YAML. Unfortunately the standard YAML library - # for Python isn't very good and I couldn't find an easy way to have it formatted - # correctly. However, as YAML is a super-set of JSON, this works just fine. - return json.dumps(self.vars) + def get_conn(self) -> Any: + """Implements the get_conn method of the BaseHook class""" + pass - def run_cli(self, *command): + def run_cli(self, dbt_cmd: List[str]): """ - Run the dbt cli + Run the rendered dbt command - :param command: The dbt command to run - :type command: str + :param dbt_cmd: The dbt command to run + :type dbt_cmd: list """ - - dbt_cmd = [self.dbt_bin, *command] - - if self.profiles_dir is not None: - dbt_cmd.extend(['--profiles-dir', self.profiles_dir]) - - if self.target is not None: - dbt_cmd.extend(['--target', self.target]) - - if self.vars is not None: - dbt_cmd.extend(['--vars', self._dump_vars()]) - - if self.data: - dbt_cmd.extend(['--data']) - - if self.schema: - dbt_cmd.extend(['--schema']) - - if self.models is not None: - dbt_cmd.extend(['--models', self.models]) - - if self.exclude is not None: - dbt_cmd.extend(['--exclude', self.exclude]) - - if self.select is not None: - dbt_cmd.extend(['--select', self.select]) - - if self.selector is not None: - dbt_cmd.extend(['--selector', self.selector]) - - if self.full_refresh: - dbt_cmd.extend(['--full-refresh']) - - if self.warn_error: - dbt_cmd.insert(1, '--warn-error') - - if self.verbose: - self.log.info(" ".join(dbt_cmd)) - sp = subprocess.Popen( - dbt_cmd, + args=dbt_cmd, env=self.env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - cwd=self.dir, - close_fds=True) + close_fds=True + ) self.sp = sp self.log.info("Output:") line = '' @@ -150,5 +142,6 @@ def run_cli(self, *command): raise AirflowException("dbt command failed") def on_kill(self): + """Called when the task is killed by Airflow. This will kill the dbt process and wait for it to exit""" self.log.info('Sending SIGTERM signal to dbt command') os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM) diff --git a/airflow_dbt/operators/dbt_operator.py b/airflow_dbt/operators/dbt_operator.py index 5f8d632..88e85d3 100644 --- a/airflow_dbt/operators/dbt_operator.py +++ b/airflow_dbt/operators/dbt_operator.py @@ -1,7 +1,12 @@ -from airflow_dbt.hooks.dbt_hook import DbtCliHook +import json +import warnings + from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults +from airflow_dbt.dbt_command_params import DbtCommandParamsConfig, DbtGlobalParamsConfig +from airflow_dbt.hooks.dbt_hook import DbtCliHook, generate_dbt_cli_command + class DbtBaseOperator(BaseOperator): """ @@ -13,9 +18,14 @@ class DbtBaseOperator(BaseOperator): :param profiles_dir: If set, passed as the `--profiles-dir` argument to the `dbt` command :type profiles_dir: str :param target: If set, passed as the `--target` argument to the `dbt` command + :type target: str + :param dir: + The directory to run the CLI in + ..deprecated:: 0.4.0 + Use ``project_dir`` instead :type dir: str - :param dir: The directory to run the CLI in - :type vars: str + :param project_dir: The directory to run the CLI in + :type project_dir: str :param vars: If set, passed as the `--vars` argument to the `dbt` command :type vars: dict :param full_refresh: If `True`, will fully-refresh incremental models. @@ -34,130 +44,169 @@ class DbtBaseOperator(BaseOperator): :type dbt_bin: str :param verbose: The operator will log verbosely to the Airflow logs :type verbose: bool + :param record_timing_info: stores runtime info in a text file to be analyzed later + :type record_timing_info: str + :param debug: if true prints debug information during the runtime + :type debug: bool + :param log_format: Determines the log format for the debug information. Only allowed value is 'json' + :type log_format: str. Must be 'json'. If set all output will be json instead of plain text + :param write_json: determines whether dbt writes JSON artifacts (eg. manifest.json, run_results.json) to the + target/ directory. JSON serialization can be slow, and turning this flag off might make invocations of dbt + faster. Alternatively, you might disable this config if you want to perform a dbt operation and avoid overwriting + artifacts from a previous run step. + :type write_json: bool + :param partial_parse: turn partial parsing on or off in your project + :type partial_parse: bool + :param use_experimental_parser: use experimental parser + :type use_experimental_parser: bool + :param use_colors: display logs using escaped colors in the terminal + :type use_colors: bool + :param fail_fast: stop execution as soon as one error is found + :type fail_fast: bool + :param command: the main command to use for dbt. Can be used to invoke the Operator raw with an arbitrary command + :type command: str + :param version: print the version of dbt installed + :type version: bool """ ui_color = '#d6522a' - template_fields = ['env', 'vars'] + template_fields = ['env', 'dbt_bin', 'command', 'command_config', 'global_config'] @apply_defaults - def __init__(self, - env=None, - profiles_dir=None, - target=None, - dir='.', - vars=None, - models=None, - exclude=None, - select=None, - selector=None, - dbt_bin='dbt', - verbose=True, - warn_error=False, - full_refresh=False, - data=False, - schema=False, - *args, - **kwargs): + def __init__( + self, + env: dict = None, + profiles_dir: str = None, + target=None, + dir: str = None, + project_dir: str = '.', + vars: dict = None, + models: str = None, + exclude: str = None, + select: str = None, + selector: str = None, + dbt_bin: str = 'dbt', + verbose: bool = None, + warn_error: bool = None, + full_refresh: bool = None, + data=None, + schema=None, + record_timing_info: bool = None, + debug: bool = None, + log_format: str = None, + write_json: bool = None, + partial_parse: bool = None, + use_experimental_parser: bool = None, + use_colors: bool = None, + fail_fast: bool = None, + command: str = None, + version: bool = None, + *args, + **kwargs + ): super(DbtBaseOperator, self).__init__(*args, **kwargs) + # dbt has a global param to specify the directory containing the project. Also, `dir` shadows a global + # python function for listing directory contents. + if dir is not None: + warnings.warn('"dir" param is deprecated in favor of dbt native param "project_dir"') + + # global flags + global_config: DbtGlobalParamsConfig = { + 'record_timing_info': record_timing_info, + 'debug': debug, + 'log_format': log_format, + 'warn_error': warn_error, + 'write_json': write_json, + 'partial_parse': partial_parse, + 'use_experimental_parser': use_experimental_parser, + 'use_colors': use_colors, + 'verbose': verbose, + 'target': target, + 'version': version, + } + # per command flags + command_config: DbtCommandParamsConfig = { + 'profiles_dir': profiles_dir, + 'project_dir': project_dir or dir, + 'full_refresh': full_refresh, + 'models': models, + 'exclude': exclude, + 'select': select, + 'selector': selector, + 'data': data, + 'fail_fast': fail_fast, + 'schema': schema, + 'vars': json.dumps(vars) if vars is not None else None, + } self.env = env or {} - self.profiles_dir = profiles_dir - self.target = target - self.dir = dir - self.vars = vars - self.models = models - self.full_refresh = full_refresh - self.data = data - self.schema = schema - self.exclude = exclude - self.select = select - self.selector = selector self.dbt_bin = dbt_bin - self.verbose = verbose - self.warn_error = warn_error - self.create_hook() - - def create_hook(self): - self.hook = DbtCliHook( - env=self.env, - profiles_dir=self.profiles_dir, - target=self.target, - dir=self.dir, - vars=self.vars, - full_refresh=self.full_refresh, - data=self.data, - schema=self.schema, - models=self.models, - exclude=self.exclude, - select=self.select, - selector=self.selector, - dbt_bin=self.dbt_bin, - verbose=self.verbose, - warn_error=self.warn_error) + self.command = command + # filter out None values from the config + self.global_config = {k: v for k, v in global_config.items() if v is not None} + self.command_config = {k: v for k, v in command_config.items() if v is not None} + self.hook = self.create_hook() - return self.hook + def create_hook(self) -> DbtCliHook: + """Create the hook to be used by the operator. This is useful for subclasses to override""" + return DbtCliHook(env=self.env) + + def execute(self, context): + """Execute the dbt command""" + dbt_full_command = generate_dbt_cli_command( + dbt_bin=self.dbt_bin, + command=self.command, + global_config=self.global_config, + command_config=self.command_config, + ) + self.hook.run_cli(dbt_full_command) class DbtRunOperator(DbtBaseOperator): + """ Runs a dbt run command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtRunOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) - - def execute(self, context): - self.create_hook().run_cli('run') + def __init__(self, command='', *args, **kwargs): + super(DbtRunOperator, self).__init__(command='run', *args, **kwargs) class DbtTestOperator(DbtBaseOperator): + """ Runs a dbt test command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtTestOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) - - def execute(self, context): - self.create_hook().run_cli('test') + def __init__(self, *args, **kwargs): + super(DbtTestOperator, self).__init__(command='test', *args, **kwargs) class DbtDocsGenerateOperator(DbtBaseOperator): + """ Runs a dbt docs generate command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtDocsGenerateOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, - **kwargs) - - def execute(self, context): - self.create_hook().run_cli('docs', 'generate') + def __init__(self, *args, **kwargs): + super(DbtDocsGenerateOperator, self).__init__(command='docs generate', *args, **kwargs) class DbtSnapshotOperator(DbtBaseOperator): + """ Runs a dbt snapshot command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtSnapshotOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) - - def execute(self, context): - self.create_hook().run_cli('snapshot') + def __init__(self, *args, **kwargs): + super(DbtSnapshotOperator, self).__init__(command='snapshot', *args, **kwargs) class DbtSeedOperator(DbtBaseOperator): + """ Runs a dbt seed command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtSeedOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) - - def execute(self, context): - self.create_hook().run_cli('seed') + def __init__(self, *args, **kwargs): + super(DbtSeedOperator, self).__init__(command='seed', *args, **kwargs) class DbtDepsOperator(DbtBaseOperator): + """ Runs a dbt deps command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtDepsOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) - - def execute(self, context): - self.create_hook().run_cli('deps') + def __init__(self, command='', *args, **kwargs): + super(DbtDepsOperator, self).__init__(command='deps', *args, **kwargs) class DbtCleanOperator(DbtBaseOperator): + """ Runs a dbt clean command. """ @apply_defaults - def __init__(self, profiles_dir=None, target=None, *args, **kwargs): - super(DbtCleanOperator, self).__init__(profiles_dir=profiles_dir, target=target, *args, **kwargs) - - def execute(self, context): - self.create_hook().run_cli('clean') + def __init__(self, command='', *args, **kwargs): + super(DbtCleanOperator, self).__init__(command='clean', *args, **kwargs) diff --git a/tests/hooks/test_dbt_hook.py b/tests/hooks/test_dbt_hook.py index 383a953..626da3b 100644 --- a/tests/hooks/test_dbt_hook.py +++ b/tests/hooks/test_dbt_hook.py @@ -1,6 +1,6 @@ -from unittest import TestCase -from unittest import mock import subprocess +from unittest import TestCase, mock + from airflow_dbt.hooks.dbt_hook import DbtCliHook @@ -15,67 +15,29 @@ def test_sub_commands(self, mock_subproc_popen): .stdout.readline.side_effect = [b"placeholder"] hook = DbtCliHook() - hook.run_cli('docs', 'generate') - - mock_subproc_popen.assert_called_once_with( - [ - 'dbt', - 'docs', - 'generate' - ], - env={}, - close_fds=True, - cwd='.', - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT - ) - - @mock.patch('subprocess.Popen') - def test_vars(self, mock_subproc_popen): - mock_subproc_popen.return_value \ - .communicate.return_value = ('output', 'error') - mock_subproc_popen.return_value.returncode = 0 - mock_subproc_popen.return_value \ - .stdout.readline.side_effect = [b"placeholder"] - - hook = DbtCliHook(vars={"foo": "bar", "baz": "true"}) - hook.run_cli('run') + hook.run_cli(['dbt', 'docs', 'generate']) mock_subproc_popen.assert_called_once_with( - [ - 'dbt', - 'run', - '--vars', - '{"foo": "bar", "baz": "true"}' - ], + args=['dbt', 'docs', 'generate'], env={}, close_fds=True, - cwd='.', stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) + ) @mock.patch('subprocess.Popen') def test_envs(self, mock_subproc_popen): - mock_subproc_popen.return_value \ - .communicate.return_value = ('output', 'error') + mock_subproc_popen.return_value.communicate.return_value = ('output', 'error') mock_subproc_popen.return_value.returncode = 0 - mock_subproc_popen.return_value \ - .stdout.readline.side_effect = [b"placeholder"] + mock_subproc_popen.return_value.stdout.readline.side_effect = [b"placeholder"] - hook = DbtCliHook(vars={"foo": "bar", "baz": "true"}, env={"foo": "bar", "baz": "true"}) - hook.run_cli('run') + hook = DbtCliHook(env={"foo": "bar", "baz": "true"}) + hook.run_cli(['dbt', 'run']) mock_subproc_popen.assert_called_once_with( - [ - 'dbt', - 'run', - '--vars', - '{"foo": "bar", "baz": "true"}' - ], + args=['dbt', 'run'], env={"foo": "bar", "baz": "true"}, close_fds=True, - cwd='.', stdout=subprocess.PIPE, stderr=subprocess.STDOUT ) diff --git a/tests/operators/test_dbt_operator.py b/tests/operators/test_dbt_operator.py index 78604d1..0b5f247 100644 --- a/tests/operators/test_dbt_operator.py +++ b/tests/operators/test_dbt_operator.py @@ -28,7 +28,7 @@ def test_dbt_run(self, mock_run_cli): dag=self.dag ) operator.execute(None) - mock_run_cli.assert_called_once_with('run') + mock_run_cli.assert_called_once_with(['dbt', 'run', '--project-dir', '.']) @mock.patch.object(DbtCliHook, 'run_cli') def test_dbt_test(self, mock_run_cli): @@ -37,7 +37,7 @@ def test_dbt_test(self, mock_run_cli): dag=self.dag ) operator.execute(None) - mock_run_cli.assert_called_once_with('test') + mock_run_cli.assert_called_once_with(['dbt', 'test', '--project-dir', '.']) @mock.patch.object(DbtCliHook, 'run_cli') def test_dbt_snapshot(self, mock_run_cli): @@ -46,7 +46,7 @@ def test_dbt_snapshot(self, mock_run_cli): dag=self.dag ) operator.execute(None) - mock_run_cli.assert_called_once_with('snapshot') + mock_run_cli.assert_called_once_with(['dbt', 'snapshot', '--project-dir', '.']) @mock.patch.object(DbtCliHook, 'run_cli') def test_dbt_seed(self, mock_run_cli): @@ -55,7 +55,7 @@ def test_dbt_seed(self, mock_run_cli): dag=self.dag ) operator.execute(None) - mock_run_cli.assert_called_once_with('seed') + mock_run_cli.assert_called_once_with(['dbt', 'seed', '--project-dir', '.']) @mock.patch.object(DbtCliHook, 'run_cli') def test_dbt_deps(self, mock_run_cli): @@ -64,7 +64,7 @@ def test_dbt_deps(self, mock_run_cli): dag=self.dag ) operator.execute(None) - mock_run_cli.assert_called_once_with('deps') + mock_run_cli.assert_called_once_with(['dbt', 'deps', '--project-dir', '.']) @mock.patch.object(DbtCliHook, 'run_cli') def test_dbt_clean(self, mock_run_cli): @@ -73,4 +73,4 @@ def test_dbt_clean(self, mock_run_cli): dag=self.dag ) operator.execute(None) - mock_run_cli.assert_called_once_with('clean') + mock_run_cli.assert_called_once_with(['dbt', 'clean', '--project-dir', '.'])