From 7bf96bdffd089e6cd55a7b0eaa3b3c1d110af65d Mon Sep 17 00:00:00 2001 From: Stu Kilgore Date: Wed, 5 Apr 2023 10:38:53 -0500 Subject: [PATCH] Handle internal exceptions in postflight --- .../unreleased/Fixes-20230329-113203.yaml | 6 + core/dbt/cli/README.md | 41 ++++++- core/dbt/cli/example.py | 4 +- core/dbt/cli/exceptions.py | 43 +++++++ core/dbt/cli/flags.py | 10 +- core/dbt/cli/main.py | 115 ++++++++++++------ core/dbt/cli/requires.py | 51 +++----- core/dbt/docs/source/index.rst | 4 +- core/dbt/tests/util.py | 12 +- .../tests/adapter/dbt_debug/test_dbt_debug.py | 4 +- tests/functional/cli/test_cli_exit_codes.py | 4 +- tests/functional/cli/test_error_handling.py | 3 + .../functional/dbt_runner/test_dbt_runner.py | 51 ++++---- .../defer_state/test_defer_state.py | 4 +- tests/functional/fail_fast/fixtures.py | 24 ---- .../fail_fast/test_fail_fast_run.py | 59 ++++----- tests/unit/test_cli_flags.py | 13 +- tests/unit/test_dbt_runner.py | 38 ++++++ 18 files changed, 310 insertions(+), 176 deletions(-) create mode 100644 .changes/unreleased/Fixes-20230329-113203.yaml create mode 100644 core/dbt/cli/exceptions.py delete mode 100644 tests/functional/fail_fast/fixtures.py create mode 100644 tests/unit/test_dbt_runner.py diff --git a/.changes/unreleased/Fixes-20230329-113203.yaml b/.changes/unreleased/Fixes-20230329-113203.yaml new file mode 100644 index 00000000000..acde149c12b --- /dev/null +++ b/.changes/unreleased/Fixes-20230329-113203.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Handle internal exceptions +time: 2023-03-29T11:32:03.259072-05:00 +custom: + Author: stu-k + Issue: "7118" diff --git a/core/dbt/cli/README.md b/core/dbt/cli/README.md index 1333ed77b7e..d4a3c7a6d4a 100644 --- a/core/dbt/cli/README.md +++ b/core/dbt/cli/README.md @@ -1 +1,40 @@ -TODO +# Exception Handling + +## `requires.py` + +### `postflight` +In the postflight decorator, the click command is invoked (i.e. `func(*args, **kwargs)`) and wrapped in a `try/except` block to handle any exceptions thrown. +Any exceptions thrown from `postflight` are wrapped by custom exceptions from the `dbt.cli.exceptions` module (i.e. `ResultExit`, `ExceptionExit`) to instruct click to complete execution with a particular exit code. + +Some `dbt-core` handled exceptions have an attribute named `results` which contains results from running nodes (e.g. `FailFastError`). These are wrapped in the `ResultExit` exception to represent runs that have failed in a way that `dbt-core` expects. +If the invocation of the command does not throw any exceptions but does not succeed, `postflight` will still raise the `ResultExit` exception to make use of the exit code. +These exceptions produce an exit code of `1`. + +Exceptions wrapped with `ExceptionExit` may be thrown by `dbt-core` intentionally (i.e. an exception that inherits from `dbt.exceptions.Exception`) or unintentionally (i.e. exceptions thrown by the python runtime). In either case these are considered errors that `dbt-core` did not expect and are treated as genuine exceptions. +These exceptions produce an exit code of `2`. + +If no exceptions are thrown from invoking the command and the command succeeds, `postflight` will not raise any exceptions. +When no exceptions are raised an exit code of `0` is produced. + +## `main.py` + +### `dbtRunner` +`dbtRunner` provides a programmatic interface for our click CLI and wraps the invocation of the click commands to handle any exceptions thrown. + +`dbtRunner.invoke` should ideally only ever return an instantiated `DbtRunnerResult` which contains the following fields: +- `success`: A boolean representing whether the command invocation was successful +- `result`: The optional result of the command invoked. This attribute can have many types, please see the definition of `DbtRunnerResult` for more information +- `exception`: If an exception was thrown during command invocation it will be saved here, otherwise it will be `None`. Please note that the exceptions held in this attribute are not the exceptions thrown by `preflight` but instead the exceptions that `ResultExit` and `ExceptionExit` wrap + +## `dbt/tests/util.py` + +### `run_dbt` +In many of our functional and integration tests, we want to be sure that an invocation of `dbt` raises a certain exception. +A common pattern for these assertions: +```python +class TestSomething: + def test_something(self, project): + with pytest.raises(SomeException): + run_dbt(["run"]) +``` +To allow these tests to assert that exceptions have been thrown, the `run_dbt` function will raise any exceptions it recieves from the invocation of a `dbt` command. diff --git a/core/dbt/cli/example.py b/core/dbt/cli/example.py index afa6820efc8..71a99b646dd 100644 --- a/core/dbt/cli/example.py +++ b/core/dbt/cli/example.py @@ -8,7 +8,7 @@ # initialize the dbt runner dbt = dbtRunner() # run the command - res, success = dbt.invoke(cli_args) + res = dbt.invoke(cli_args) # preload profile and project profile = load_profile(project_dir, {}, "testing-postgres") @@ -17,4 +17,4 @@ # initialize the runner with pre-loaded profile and project, you can also pass in a preloaded manifest dbt = dbtRunner(profile=profile, project=project) # run the command, this will use the pre-loaded profile and project instead of loading - res, success = dbt.invoke(cli_args) + res = dbt.invoke(cli_args) diff --git a/core/dbt/cli/exceptions.py b/core/dbt/cli/exceptions.py new file mode 100644 index 00000000000..d88f91c01ac --- /dev/null +++ b/core/dbt/cli/exceptions.py @@ -0,0 +1,43 @@ +from typing import Optional, IO + +from click.exceptions import ClickException +from dbt.utils import ExitCodes + + +class DbtUsageException(Exception): + pass + + +class DbtInternalException(Exception): + pass + + +class CliException(ClickException): + """The base exception class for our implementation of the click CLI. + The exit_code attribute is used by click to determine which exit code to produce + after an invocation.""" + + def __init__(self, exit_code: ExitCodes) -> None: + self.exit_code = exit_code.value + + # the typing of _file is to satisfy the signature of ClickException.show + # overriding this method prevents click from printing any exceptions to stdout + def show(self, _file: Optional[IO] = None) -> None: + pass + + +class ResultExit(CliException): + """This class wraps any exception that contains results while invoking dbt, or the + results of an invocation that did not succeed but did not throw any exceptions.""" + + def __init__(self, result) -> None: + super().__init__(ExitCodes.ModelError) + self.result = result + + +class ExceptionExit(CliException): + """This class wraps any exception that does not contain results thrown while invoking dbt.""" + + def __init__(self, exception: Exception) -> None: + super().__init__(ExitCodes.UnhandledError) + self.exception = exception diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index f50e3a4e605..24e542df42c 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -7,11 +7,12 @@ from pprint import pformat as pf from typing import Callable, Dict, List, Set -from click import Context, get_current_context, BadOptionUsage +from click import Context, get_current_context from click.core import ParameterSource, Command, Group from dbt.config.profile import read_user_config from dbt.contracts.project import UserConfig +from dbt.cli.exceptions import DbtUsageException from dbt.deprecations import renamed_env_var from dbt.helper_types import WarnErrorOptions from dbt.cli.resolvers import default_project_dir, default_log_path @@ -137,8 +138,7 @@ def assign_params(ctx, params_assigned_from_default, deprecated_env_vars): if param_source == ParameterSource.DEFAULT: continue elif param_source != ParameterSource.ENVIRONMENT: - raise BadOptionUsage( - param_name, + raise DbtUsageException( "Deprecated parameters can only be set via environment variables", ) @@ -268,8 +268,8 @@ def _assert_mutually_exclusive( for flag in group: flag_set_by_user = flag.lower() not in params_assigned_from_default if flag_set_by_user and set_flag: - raise BadOptionUsage( - flag.lower(), f"{flag.lower()}: not allowed with argument {set_flag.lower()}" + raise DbtUsageException( + f"{flag.lower()}: not allowed with argument {set_flag.lower()}" ) elif flag_set_by_user: set_flag = flag diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 4f7d7440be9..7379a9354ba 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -1,12 +1,29 @@ from copy import copy -from typing import Callable, List, Tuple, Optional +from dataclasses import dataclass +from typing import Callable, List, Optional, Union import click +from click.exceptions import ( + Exit as ClickExit, + BadOptionUsage, + NoSuchOption, + UsageError, +) from dbt.cli import requires, params as p +from dbt.cli import requires, params as p +from dbt.cli.exceptions import ( + DbtInternalException, + DbtUsageException, +) from dbt.config.profile import Profile from dbt.config.project import Project from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.results import ( + CatalogArtifact, + RunExecutionResult, + RunOperationResultsArtifact, +) from dbt.events.base_types import EventMsg from dbt.task.build import BuildTask from dbt.task.clean import CleanTask @@ -26,12 +43,22 @@ from dbt.task.test import TestTask -class dbtUsageException(Exception): - pass +@dataclass +class DbtRunnerResult: + """Contains the result of an invocation of the dbtRunner""" + success: bool -class dbtInternalException(Exception): - pass + exception: Optional[BaseException] = None + result: Optional[ # None: clean, deps, init, parse, source + Union[ + bool, # debug + CatalogArtifact, # docs generate + List[str], # list/ls + RunExecutionResult, # build, compile, run, seed, snapshot, test + RunOperationResultsArtifact, # run-operation + ] + ] = None # Programmatic invocation @@ -48,7 +75,7 @@ def __init__( self.manifest = manifest self.callbacks = callbacks - def invoke(self, args: List[str], **kwargs) -> Tuple[Optional[List], bool]: + def invoke(self, args: List[str], **kwargs) -> DbtRunnerResult: try: dbt_ctx = cli.make_context(cli.name, args) dbt_ctx.obj = { @@ -63,18 +90,38 @@ def invoke(self, args: List[str], **kwargs) -> Tuple[Optional[List], bool]: # Hack to set parameter source to custom string dbt_ctx.set_parameter_source(key, "kwargs") # type: ignore - return cli.invoke(dbt_ctx) - except requires.HandledExit as e: - return (e.result, e.success) - except requires.UnhandledExit as e: - raise e.exception - except click.exceptions.Exit as e: - # 0 exit code, expected for --version early exit + result, success = cli.invoke(dbt_ctx) + return DbtRunnerResult( + result=result, + success=success, + ) + except requires.ResultExit as e: + return DbtRunnerResult( + result=e.result, + success=False, + ) + except requires.ExceptionExit as e: + return DbtRunnerResult( + exception=e.exception, + success=False, + ) + except (BadOptionUsage, NoSuchOption, UsageError) as e: + return DbtRunnerResult( + exception=DbtUsageException(e.message), + success=False, + ) + except ClickExit as e: if str(e) == "0": - return [], True - raise dbtInternalException(f"unhandled exit code {str(e)}") - except (click.NoSuchOption, click.UsageError) as e: - raise dbtUsageException(e.message) + return DbtRunnerResult(success=True) + return DbtRunnerResult( + exception=DbtInternalException(f"unhandled exit code {str(e)}"), + success=False, + ) + except BaseException as e: + return DbtRunnerResult( + exception=e, + success=False, + ) # dbt @@ -145,12 +192,12 @@ def cli(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def build(ctx, **kwargs): """Run all Seeds, Models, Snapshots, and tests in DAG order""" task = BuildTask( @@ -172,10 +219,10 @@ def build(ctx, **kwargs): @p.project_dir @p.target @p.vars +@requires.postflight @requires.preflight @requires.unset_profile @requires.project -@requires.postflight def clean(ctx, **kwargs): """Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)""" task = CleanTask(ctx.obj["flags"], ctx.obj["project"]) @@ -213,12 +260,12 @@ def docs(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest(write=False) -@requires.postflight def docs_generate(ctx, **kwargs): """Generate the documentation website for your project""" task = GenerateTask( @@ -242,12 +289,12 @@ def docs_generate(ctx, **kwargs): @p.project_dir @p.target @p.vars +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def docs_serve(ctx, **kwargs): """Serve the documentation website for your project""" task = ServeTask( @@ -287,12 +334,12 @@ def docs_serve(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def compile(ctx, **kwargs): """Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the target/ directory.""" @@ -334,12 +381,12 @@ def compile(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def show(ctx, **kwargs): """Generates executable SQL for a named resource or inline query, runs that SQL, and returns a preview of the results. Does not materialize anything to the warehouse.""" @@ -364,8 +411,8 @@ def show(ctx, **kwargs): @p.target @p.vars @p.version_check -@requires.preflight @requires.postflight +@requires.preflight def debug(ctx, **kwargs): """Show some helpful information about dbt for debugging. Not to be confused with the --debug option which increases verbosity.""" task = DebugTask( @@ -386,10 +433,10 @@ def debug(ctx, **kwargs): @p.project_dir @p.target @p.vars +@requires.postflight @requires.preflight @requires.unset_profile @requires.project -@requires.postflight def deps(ctx, **kwargs): """Pull the most recent version of the dependencies listed in packages.yml""" task = DepsTask(ctx.obj["flags"], ctx.obj["project"]) @@ -409,8 +456,8 @@ def deps(ctx, **kwargs): @p.skip_profile_setup @p.target @p.vars -@requires.preflight @requires.postflight +@requires.preflight def init(ctx, **kwargs): """Initialize a new dbt project.""" task = InitTask(ctx.obj["flags"], None) @@ -438,12 +485,12 @@ def init(ctx, **kwargs): @p.deprecated_state @p.target @p.vars +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def list(ctx, **kwargs): """List the resources in your project""" task = ListTask( @@ -476,12 +523,12 @@ def list(ctx, **kwargs): @p.vars @p.version_check @p.write_manifest +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest(write_perf_info=True) -@requires.postflight def parse(ctx, **kwargs): """Parses the project and provides information on performance""" # manifest generation and writing happens in @requires.manifest @@ -510,12 +557,12 @@ def parse(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def run(ctx, **kwargs): """Compile SQL and execute against the current target database.""" task = RunTask( @@ -539,12 +586,12 @@ def run(ctx, **kwargs): @p.project_dir @p.target @p.vars +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def run_operation(ctx, **kwargs): """Run the named macro with any supplied arguments.""" task = RunOperationTask( @@ -576,12 +623,12 @@ def run_operation(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def seed(ctx, **kwargs): """Load data from csv files into your data warehouse.""" task = SeedTask( @@ -612,12 +659,12 @@ def seed(ctx, **kwargs): @p.target @p.threads @p.vars +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def snapshot(ctx, **kwargs): """Execute snapshots defined in your project""" task = SnapshotTask( @@ -653,12 +700,12 @@ def source(ctx, **kwargs): @p.target @p.threads @p.vars +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def freshness(ctx, **kwargs): """check the current freshness of the project's sources""" task = FreshnessTask( @@ -701,12 +748,12 @@ def freshness(ctx, **kwargs): @p.threads @p.vars @p.version_check +@requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config @requires.manifest -@requires.postflight def test(ctx, **kwargs): """Runs tests on data in deployed models. Run this after `dbt run`""" task = TestTask( diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index a879436e186..73c526cf8f5 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -2,6 +2,10 @@ from dbt.version import installed as installed_version from dbt.adapters.factory import adapter_management, register_adapter from dbt.flags import set_flags, get_flag_dict +from dbt.cli.exceptions import ( + ExceptionExit, + ResultExit, +) from dbt.cli.flags import Flags from dbt.config import RuntimeConfig from dbt.config.runtime import load_project, load_profile, UnsetProfile @@ -13,40 +17,19 @@ MainTrackingUserState, ) from dbt.events.helpers import get_json_string_utcnow -from dbt.exceptions import DbtProjectError +from dbt.events.types import MainEncounteredError, MainStackTrace +from dbt.exceptions import Exception as DbtException, DbtProjectError, FailFastError from dbt.parser.manifest import ManifestLoader, write_manifest from dbt.profiler import profiler from dbt.tracking import active_user, initialize_from_flags, track_run -from dbt.utils import cast_dict_to_dict_of_strings, ExitCodes +from dbt.utils import cast_dict_to_dict_of_strings from click import Context -from click.exceptions import ClickException from functools import update_wrapper import time import traceback -class HandledExit(ClickException): - def __init__(self, result, success, exit_code: ExitCodes) -> None: - self.result = result - self.success = success - self.exit_code = exit_code - - def show(self): - pass - - -class UnhandledExit(ClickException): - exit_code = ExitCodes.UnhandledError.value - - def __init__(self, exception: Exception, message: str) -> None: - self.exception = exception - self.message = message - - def format_message(self) -> str: - return self.message - - def preflight(func): def wrapper(*args, **kwargs): ctx = args[0] @@ -91,6 +74,9 @@ def wrapper(*args, **kwargs): def postflight(func): + """The decorator that handles all exception handling for the click commands. + This decorator must be used before any other decorators that may throw an exception.""" + def wrapper(*args, **kwargs): ctx = args[0] start_func = time.perf_counter() @@ -98,8 +84,15 @@ def wrapper(*args, **kwargs): try: result, success = func(*args, **kwargs) - except Exception as e: - raise UnhandledExit(e, message=traceback.format_exc()) + except BaseException as e: + fire_event(MainEncounteredError(exc=str(e))) + + if not isinstance(e, DbtException): + fire_event(MainStackTrace(stack_trace=traceback.format_exc())) + elif isinstance(e, FailFastError): + raise ResultExit(e.result) + + raise ExceptionExit(e) finally: fire_event( CommandCompleted( @@ -111,11 +104,7 @@ def wrapper(*args, **kwargs): ) if not success: - raise HandledExit( - result=result, - success=success, - exit_code=ExitCodes.ModelError.value, - ) + raise ResultExit(result) return (result, success) diff --git a/core/dbt/docs/source/index.rst b/core/dbt/docs/source/index.rst index dcd1c82499f..476fab207df 100644 --- a/core/dbt/docs/source/index.rst +++ b/core/dbt/docs/source/index.rst @@ -13,7 +13,7 @@ Right now the best way to invoke a command from python runtime is to use the `db # initialize the dbt runner dbt = dbtRunner() # run the command - res, success = dbt.invoke(args) + res = dbt.invoke(args) You can also pass in pre constructed object into dbtRunner, and we will use those objects instead of loading up from the disk. @@ -26,7 +26,7 @@ You can also pass in pre constructed object into dbtRunner, and we will use thos # initialize the runner with pre-loaded profile and project dbt = dbtRunner(profile=profile, project=project) # run the command, this will use the pre-loaded profile and project instead of loading - res, success = dbt.invoke(cli_args) + res = dbt.invoke(cli_args) For the full example code, you can refer to `core/dbt/cli/example.py` diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index d8fb37dfa50..4969c319e2b 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -92,13 +92,19 @@ def run_dbt(args: List[str] = None, expect_pass=True): args.extend(["--project-dir", project_dir]) if profiles_dir and "--profiles-dir" not in args: args.extend(["--profiles-dir", profiles_dir]) + dbt = dbtRunner() - res, success = dbt.invoke(args) + res = dbt.invoke(args) + + # the exception is immediately raised to be caught in tests + # using a pattern like `with pytest.raises(SomeException):` + if res.exception is not None: + raise res.exception if expect_pass is not None: - assert success == expect_pass, "dbt exit state did not match expected" + assert res.success == expect_pass, "dbt exit state did not match expected" - return res + return res.result # Use this if you need to capture the command logs in a test. diff --git a/tests/adapter/dbt/tests/adapter/dbt_debug/test_dbt_debug.py b/tests/adapter/dbt/tests/adapter/dbt_debug/test_dbt_debug.py index e2b9beea351..eb973b91728 100644 --- a/tests/adapter/dbt/tests/adapter/dbt_debug/test_dbt_debug.py +++ b/tests/adapter/dbt/tests/adapter/dbt_debug/test_dbt_debug.py @@ -3,7 +3,7 @@ import re import yaml -from dbt.cli.main import dbtUsageException +from dbt.cli.exceptions import DbtUsageException from dbt.tests.util import run_dbt MODELS__MODEL_SQL = """ @@ -88,7 +88,7 @@ def test_badproject(self, project): self.check_project(splitout) def test_not_found_project(self, project): - with pytest.raises(dbtUsageException): + with pytest.raises(DbtUsageException): run_dbt(["debug", "--project-dir", "nopass"]) def test_invalid_project_outside_current_dir(self, project): diff --git a/tests/functional/cli/test_cli_exit_codes.py b/tests/functional/cli/test_cli_exit_codes.py index e3067e5e42e..71c1097ba6a 100644 --- a/tests/functional/cli/test_cli_exit_codes.py +++ b/tests/functional/cli/test_cli_exit_codes.py @@ -1,7 +1,7 @@ import pytest +from dbt.cli.exceptions import ResultExit from dbt.cli.main import cli -from dbt.cli.requires import HandledExit good_sql = """ @@ -34,5 +34,5 @@ def models(self): return {"model_one.sql": bad_sql} def test_exc_thrown(self, project): - with pytest.raises(HandledExit): + with pytest.raises(ResultExit): self.run_cli() diff --git a/tests/functional/cli/test_error_handling.py b/tests/functional/cli/test_error_handling.py index 26bafaa1c1a..83c8a6fc47c 100644 --- a/tests/functional/cli/test_error_handling.py +++ b/tests/functional/cli/test_error_handling.py @@ -15,3 +15,6 @@ def models(self): def test_failed_run_does_not_throw(self, project): run_dbt(["run"], expect_pass=False) + + def test_fail_fast_failed_run_does_not_throw(self, project): + run_dbt(["--fail-fast", "run"], expect_pass=False) diff --git a/tests/functional/dbt_runner/test_dbt_runner.py b/tests/functional/dbt_runner/test_dbt_runner.py index 312db37b99b..20041f05952 100644 --- a/tests/functional/dbt_runner/test_dbt_runner.py +++ b/tests/functional/dbt_runner/test_dbt_runner.py @@ -2,7 +2,8 @@ import pytest -from dbt.cli.main import dbtRunner, dbtUsageException +from dbt.cli.exceptions import DbtUsageException +from dbt.cli.main import dbtRunner from dbt.exceptions import DbtProjectError @@ -12,20 +13,20 @@ def dbt(self) -> dbtRunner: return dbtRunner() def test_group_invalid_option(self, dbt: dbtRunner) -> None: - with pytest.raises(dbtUsageException): - dbt.invoke(["--invalid-option"]) + res = dbt.invoke(["--invalid-option"]) + assert type(res.exception) == DbtUsageException def test_command_invalid_option(self, dbt: dbtRunner) -> None: - with pytest.raises(dbtUsageException): - dbt.invoke(["deps", "--invalid-option"]) + res = dbt.invoke(["deps", "--invalid-option"]) + assert type(res.exception) == DbtUsageException def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None: - with pytest.raises(dbtUsageException): - dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"]) + res = dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"]) + assert type(res.exception) == DbtUsageException def test_invalid_command(self, dbt: dbtRunner) -> None: - with pytest.raises(dbtUsageException): - dbt.invoke(["invalid-command"]) + res = dbt.invoke(["invalid-command"]) + assert type(res.exception) == DbtUsageException def test_invoke_version(self, dbt: dbtRunner) -> None: dbt.invoke(["--version"]) @@ -39,7 +40,7 @@ def test_callbacks(self) -> None: mock_callback.assert_called() def test_invoke_kwargs(self, project, dbt): - results, success = dbt.invoke( + res = dbt.invoke( ["run"], log_format="json", log_path="some_random_path", @@ -47,23 +48,25 @@ def test_invoke_kwargs(self, project, dbt): profile_name="some_random_profile_name", target_dir="some_random_target_dir", ) - assert results.args["log_format"] == "json" - assert results.args["log_path"] == "some_random_path" - assert results.args["version_check"] is False - assert results.args["profile_name"] == "some_random_profile_name" - assert results.args["target_dir"] == "some_random_target_dir" + assert res.result.args["log_format"] == "json" + assert res.result.args["log_path"] == "some_random_path" + assert res.result.args["version_check"] is False + assert res.result.args["profile_name"] == "some_random_profile_name" + assert res.result.args["target_dir"] == "some_random_target_dir" def test_invoke_kwargs_project_dir(self, project, dbt): - with pytest.raises( - DbtProjectError, - match="No dbt_project.yml found at expected path some_random_project_dir", - ): - dbt.invoke(["run"], project_dir="some_random_project_dir") + res = dbt.invoke(["run"], project_dir="some_random_project_dir") + assert type(res.exception) == DbtProjectError + + msg = "No dbt_project.yml found at expected path some_random_project_dir" + assert msg in res.exception.msg def test_invoke_kwargs_profiles_dir(self, project, dbt): - with pytest.raises(DbtProjectError, match="Could not find profile named 'test'"): - dbt.invoke(["run"], profiles_dir="some_random_profiles_dir") + res = dbt.invoke(["run"], profiles_dir="some_random_profiles_dir") + assert type(res.exception) == DbtProjectError + msg = "Could not find profile named 'test'" + assert msg in res.exception.msg def test_invoke_kwargs_and_flags(self, project, dbt): - results, success = dbt.invoke(["--log-format=text", "run"], log_format="json") - assert results.args["log_format"] == "json" + res = dbt.invoke(["--log-format=text", "run"], log_format="json") + assert res.result.args["log_format"] == "json" diff --git a/tests/functional/defer_state/test_defer_state.py b/tests/functional/defer_state/test_defer_state.py index 7b88ba69e8b..a50f09af0d1 100644 --- a/tests/functional/defer_state/test_defer_state.py +++ b/tests/functional/defer_state/test_defer_state.py @@ -5,8 +5,8 @@ import pytest +from dbt.cli.exceptions import DbtUsageException from dbt.tests.util import run_dbt, write_file, rm_file -from dbt.cli.main import dbtUsageException from dbt.exceptions import DbtRuntimeError @@ -99,7 +99,7 @@ def run_and_save_state(self): class TestDeferStateUnsupportedCommands(BaseDeferState): def test_unsupported_commands(self, project): # make sure these commands don"t work with --defer - with pytest.raises(dbtUsageException): + with pytest.raises(DbtUsageException): run_dbt(["seed", "--defer"]) def test_no_state(self, project): diff --git a/tests/functional/fail_fast/fixtures.py b/tests/functional/fail_fast/fixtures.py deleted file mode 100644 index 695cd73bc83..00000000000 --- a/tests/functional/fail_fast/fixtures.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest -from dbt.tests.fixtures.project import write_project_files - - -models__one_sql = """ -select 1 /failed -""" - -models__two_sql = """ -select 1 /failed -""" - - -@pytest.fixture(scope="class") -def models(): - return {"one.sql": models__one_sql, "two.sql": models__two_sql} - - -@pytest.fixture(scope="class") -def project_files( - project_root, - models, -): - write_project_files(project_root, "models", models) diff --git a/tests/functional/fail_fast/test_fail_fast_run.py b/tests/functional/fail_fast/test_fail_fast_run.py index 5c0c8cf849d..4fb7821d072 100644 --- a/tests/functional/fail_fast/test_fail_fast_run.py +++ b/tests/functional/fail_fast/test_fail_fast_run.py @@ -1,54 +1,36 @@ import pytest +from dbt.contracts.results import RunResult from dbt.tests.util import run_dbt -from tests.functional.fail_fast.fixtures import models, project_files # noqa: F401 -from dbt.exceptions import FailFastError -def check_audit_table(project, count=1): - query = "select * from {schema}.audit".format(schema=project.test_schema) +models__one_sql = """ +select 1 /failed +""" - vals = project.run_sql(query, fetch="all") - assert not (len(vals) == count), "Execution was not stopped before run end" +models__two_sql = """ +select 1 /failed +""" -class TestFastFailingDuringRun: +class FailFastBase: @pytest.fixture(scope="class") - def project_config_update(self): - return { - "config-version": 2, - "on-run-start": "create table if not exists {{ target.schema }}.audit (model text)", - "models": { - "test": { - "pre-hook": [ - { - # we depend on non-deterministic nature of tasks execution - # there is possibility to run next task in-between - # first task failure and adapter connections cancellations - # if you encounter any problems with these tests please report - # the sleep command with random time minimize the risk - "sql": "select pg_sleep(random())", - "transaction": False, - }, - { - "sql": "insert into {{ target.schema }}.audit values ('{{ this }}')", - "transaction": False, - }, - ], - } - }, - } + def models(self): + return {"one.sql": models__one_sql, "two.sql": models__two_sql} + +class TestFastFailingDuringRun(FailFastBase): def test_fail_fast_run( self, project, + models, # noqa: F811 ): - with pytest.raises(FailFastError): - run_dbt(["run", "--threads", "1", "--fail-fast"]) - check_audit_table(project) + res = run_dbt(["run", "--fail-fast", "--threads", "1"], expect_pass=False) + # a RunResult contains only one node so we can be sure only one model was run + assert type(res) == RunResult -class TestFailFastFromConfig(TestFastFailingDuringRun): +class TestFailFastFromConfig(FailFastBase): @pytest.fixture(scope="class") def profiles_config_update(self): return { @@ -61,7 +43,8 @@ def profiles_config_update(self): def test_fail_fast_run_user_config( self, project, + models, # noqa: F811 ): - with pytest.raises(FailFastError): - run_dbt(["run", "--threads", "1"]) - check_audit_table(project) + res = run_dbt(["run", "--threads", "1"], expect_pass=False) + # a RunResult contains only one node so we can be sure only one model was run + assert type(res) == RunResult diff --git a/tests/unit/test_cli_flags.py b/tests/unit/test_cli_flags.py index 12a5a6d649c..1146a91e18d 100644 --- a/tests/unit/test_cli_flags.py +++ b/tests/unit/test_cli_flags.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import List +from dbt.cli.exceptions import DbtUsageException from dbt.cli.main import cli from dbt.contracts.project import UserConfig from dbt.cli.flags import Flags @@ -148,7 +149,7 @@ def test_mutually_exclusive_options_from_cli(self): "run", ["--warn-error", "--warn-error-options", '{"include": "all"}', "run"] ) - with pytest.raises(click.BadOptionUsage): + with pytest.raises(DbtUsageException): Flags(context) @pytest.mark.parametrize("warn_error", [True, False]) @@ -158,7 +159,7 @@ def test_mutually_exclusive_options_from_user_config(self, warn_error, user_conf "run", ["--warn-error-options", '{"include": "all"}', "run"] ) - with pytest.raises(click.BadOptionUsage): + with pytest.raises(DbtUsageException): Flags(context, user_config) @pytest.mark.parametrize("warn_error", ["True", "False"]) @@ -167,7 +168,7 @@ def test_mutually_exclusive_options_from_envvar(self, warn_error, monkeypatch): monkeypatch.setenv("DBT_WARN_ERROR_OPTIONS", '{"include":"all"}') context = self.make_dbt_context("run", ["run"]) - with pytest.raises(click.BadOptionUsage): + with pytest.raises(DbtUsageException): Flags(context) @pytest.mark.parametrize("warn_error", [True, False]) @@ -177,7 +178,7 @@ def test_mutually_exclusive_options_from_cli_and_user_config(self, warn_error, u "run", ["--warn-error-options", '{"include": "all"}', "run"] ) - with pytest.raises(click.BadOptionUsage): + with pytest.raises(DbtUsageException): Flags(context, user_config) @pytest.mark.parametrize("warn_error", ["True", "False"]) @@ -187,7 +188,7 @@ def test_mutually_exclusive_options_from_cli_and_envvar(self, warn_error, monkey "run", ["--warn-error-options", '{"include": "all"}', "run"] ) - with pytest.raises(click.BadOptionUsage): + with pytest.raises(DbtUsageException): Flags(context) @pytest.mark.parametrize("warn_error", ["True", "False"]) @@ -198,7 +199,7 @@ def test_mutually_exclusive_options_from_user_config_and_envvar( monkeypatch.setenv("DBT_WARN_ERROR_OPTIONS", '{"include": "all"}') context = self.make_dbt_context("run", ["run"]) - with pytest.raises(click.BadOptionUsage): + with pytest.raises(DbtUsageException): Flags(context, user_config) @pytest.mark.parametrize( diff --git a/tests/unit/test_dbt_runner.py b/tests/unit/test_dbt_runner.py new file mode 100644 index 00000000000..165de5020cd --- /dev/null +++ b/tests/unit/test_dbt_runner.py @@ -0,0 +1,38 @@ +import pytest + +from dbt.cli.exceptions import DbtUsageException +from dbt.cli.main import dbtRunner +from unittest import mock + + +class TestDbtRunner: + @pytest.fixture + def dbt(self) -> dbtRunner: + return dbtRunner() + + def test_group_invalid_option(self, dbt: dbtRunner) -> None: + res = dbt.invoke(["--invalid-option"]) + assert type(res.exception) == DbtUsageException + + def test_command_invalid_option(self, dbt: dbtRunner) -> None: + res = dbt.invoke(["deps", "--invalid-option"]) + assert type(res.exception) == DbtUsageException + + def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None: + res = dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"]) + assert type(res.exception) == DbtUsageException + + def test_invalid_command(self, dbt: dbtRunner) -> None: + res = dbt.invoke(["invalid-command"]) + assert type(res.exception) == DbtUsageException + + def test_invoke_version(self, dbt: dbtRunner) -> None: + dbt.invoke(["--version"]) + + def test_callbacks(self) -> None: + mock_callback = mock.MagicMock() + dbt = dbtRunner(callbacks=[mock_callback]) + # the `debug` command is one of the few commands wherein you don't need + # to have a project to run it and it will emit events + dbt.invoke(["debug"]) + mock_callback.assert_called()