Skip to content

Commit

Permalink
mutually exclusive handling for warn_error_options and warn_error par…
Browse files Browse the repository at this point in the history
…ams in Click CLI (#6771)

warn_error_options, warn_error mutual exclusivity with click
  • Loading branch information
MichelleArk authored and stu-k committed Jan 31, 2023
1 parent b035836 commit 2ac3d38
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 39 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230130-180917.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: warn_error/warn_error_options mutual exclusivity in click
time: 2023-01-30T18:09:17.240662-05:00
custom:
Author: michelleark
Issue: "6579"
28 changes: 26 additions & 2 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Set
from typing import Set, List

from click import Context, get_current_context
from click import Context, get_current_context, BadOptionUsage
from click.core import ParameterSource

from dbt.config.profile import read_user_config
Expand Down Expand Up @@ -59,12 +59,15 @@ def assign_params(ctx, params_assigned_from_default):

# Overwrite default assignments with user config if available
if user_config:
param_assigned_from_default_copy = params_assigned_from_default.copy()
for param_assigned_from_default in params_assigned_from_default:
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self, param_assigned_from_default.upper(), user_config_param_value
)
param_assigned_from_default_copy.remove(param_assigned_from_default)
params_assigned_from_default = param_assigned_from_default_copy

# Hard coded flags
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
Expand All @@ -78,6 +81,10 @@ def assign_params(ctx, params_assigned_from_default):
if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes")
else True,
)
# Check mutual exclusivity once all flags are set
self._assert_mutually_exclusive(
params_assigned_from_default, ["WARN_ERROR", "WARN_ERROR_OPTIONS"]
)

# Support lower cased access for legacy code
params = set(
Expand All @@ -88,3 +95,20 @@ def assign_params(ctx, params_assigned_from_default):

def __str__(self) -> str:
return str(pf(self.__dict__))

def _assert_mutually_exclusive(
self, params_assigned_from_default: Set[str], group: List[str]
) -> None:
"""
Ensure no elements from group are simultaneously provided by a user, as inferred from params_assigned_from_default.
Raises click.UsageError if any two elements from group are simultaneously provided by a user.
"""
set_flag = None
for flag in group:
flag_set_by_user = flag.lower() not in params_assigned_from_default
if flag_set_by_user and set_flag:
raise BadOptionUsage(
flag.lower(), f"{flag.lower()}: not allowed with argument {set_flag.lower()}"
)
elif flag_set_by_user:
set_flag = flag
2 changes: 1 addition & 1 deletion core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@
envvar="DBT_WARN_ERROR",
help="If dbt would normally warn, instead raise an exception. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations, and missing sources/refs in tests.",
default=None,
flag_value=True,
is_flag=True,
)

warn_error_options = click.option(
Expand Down
6 changes: 1 addition & 5 deletions core/dbt/events/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,7 @@ def msg_to_dict(msg: EventMsg) -> dict:


def warn_or_error(event, node=None):
# TODO: resolve this circular import when flags.WARN_ERROR_OPTIONS is WarnErrorOptions type via click CLI.
from dbt.helper_types import WarnErrorOptions

warn_error_options = WarnErrorOptions.from_yaml_string(flags.WARN_ERROR_OPTIONS)
if flags.WARN_ERROR or warn_error_options.includes(type(event).__name__):
if flags.WARN_ERROR or flags.WARN_ERROR_OPTIONS.includes(type(event).__name__):
# TODO: resolve this circular import when at top
from dbt.exceptions import EventCompilationError

Expand Down
12 changes: 10 additions & 2 deletions core/dbt/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pathlib import Path
from typing import Optional

from dbt.helper_types import WarnErrorOptions

# PROFILES_DIR must be set before the other flags
# It also gets set in main.py and in set_from_args because the rpc server
# doesn't go through exactly the same main arg processing.
Expand Down Expand Up @@ -46,7 +48,7 @@
USE_EXPERIMENTAL_PARSER = None
VERSION_CHECK = None
WARN_ERROR = None
WARN_ERROR_OPTIONS = None
WARN_ERROR_OPTIONS = WarnErrorOptions(include=[])
WHICH = None
WRITE_JSON = None

Expand Down Expand Up @@ -170,7 +172,13 @@ def set_from_args(args, user_config):
USE_EXPERIMENTAL_PARSER = get_flag_value("USE_EXPERIMENTAL_PARSER", args, user_config)
VERSION_CHECK = get_flag_value("VERSION_CHECK", args, user_config)
WARN_ERROR = get_flag_value("WARN_ERROR", args, user_config)
WARN_ERROR_OPTIONS = get_flag_value("WARN_ERROR_OPTIONS", args, user_config)

warn_error_options_str = get_flag_value("WARN_ERROR_OPTIONS", args, user_config)
from dbt.cli.option_types import WarnErrorOptionsType

# Converting to WarnErrorOptions for consistency with dbt/cli/flags.py
WARN_ERROR_OPTIONS = WarnErrorOptionsType().convert(warn_error_options_str, None, None)

WRITE_JSON = get_flag_value("WRITE_JSON", args, user_config)

_check_mutually_exclusive(["WARN_ERROR", "WARN_ERROR_OPTIONS"], args, user_config)
Expand Down
16 changes: 0 additions & 16 deletions core/dbt/helper_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,6 @@ def _validate_items(self, items: List[str]):


class WarnErrorOptions(IncludeExclude):
# TODO: this method can be removed once the click CLI is in use
@classmethod
def from_yaml_string(cls, warn_error_options_str: Optional[str]):

# TODO: resolve circular import
from dbt.config.utils import parse_cli_yaml_string

warn_error_options_str = (
str(warn_error_options_str) if warn_error_options_str is not None else "{}"
)
warn_error_options = parse_cli_yaml_string(warn_error_options_str, "warn-error-options")
return cls(
include=warn_error_options.get("include", []),
exclude=warn_error_options.get("exclude", []),
)

def _validate_items(self, items: List[str]):
valid_exception_names = set(
[name for name, cls in dbt_event_types.__dict__.items() if isinstance(cls, type)]
Expand Down
13 changes: 7 additions & 6 deletions test/unit/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dbt import flags
from dbt.contracts.project import UserConfig
from dbt.graph.selector_spec import IndirectSelection
from dbt.helper_types import WarnErrorOptions

class TestFlags(TestCase):

Expand Down Expand Up @@ -66,13 +67,13 @@ def test__flags(self):
# warn_error_options
self.user_config.warn_error_options = '{"include": "all"}'
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR_OPTIONS, '{"include": "all"}')
self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include="all"))
os.environ['DBT_WARN_ERROR_OPTIONS'] = '{"include": []}'
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR_OPTIONS, '{"include": []}')
self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include=[]))
setattr(self.args, 'warn_error_options', '{"include": "all"}')
flags.set_from_args(self.args, self.user_config)
self.assertEqual(flags.WARN_ERROR_OPTIONS, '{"include": "all"}')
self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include="all"))
# cleanup
os.environ.pop('DBT_WARN_ERROR_OPTIONS')
delattr(self.args, 'warn_error_options')
Expand Down Expand Up @@ -283,7 +284,7 @@ def test__flags(self):
def test__flags_are_mutually_exclusive(self):
# options from user config
self.user_config.warn_error = False
self.user_config.warn_error_options = '{"include":"all}'
self.user_config.warn_error_options = '{"include":"all"}'
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
#cleanup
Expand All @@ -292,7 +293,7 @@ def test__flags_are_mutually_exclusive(self):

# options from args
setattr(self.args, 'warn_error', False)
setattr(self.args, 'warn_error_options', '{"include":"all}')
setattr(self.args, 'warn_error_options', '{"include":"all"}')
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
# cleanup
Expand All @@ -310,7 +311,7 @@ def test__flags_are_mutually_exclusive(self):

# options from user config + args
self.user_config.warn_error = False
setattr(self.args, 'warn_error_options', '{"include":"all}')
setattr(self.args, 'warn_error_options', '{"include":"all"}')
with pytest.raises(ValueError):
flags.set_from_args(self.args, self.user_config)
# cleanup
Expand Down
88 changes: 81 additions & 7 deletions tests/unit/test_cli_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt.cli.main import cli
from dbt.contracts.project import UserConfig
from dbt.cli.flags import Flags
from dbt.helper_types import WarnErrorOptions


class TestFlags:
Expand All @@ -18,6 +19,10 @@ def make_dbt_context(self, context_name: str, args: List[str]) -> click.Context:
def run_context(self) -> click.Context:
return self.make_dbt_context("run", ["run"])

@pytest.fixture
def user_config(self) -> UserConfig:
return UserConfig()

def test_which(self, run_context):
flags = Flags(run_context)
assert flags.WHICH == "run"
Expand Down Expand Up @@ -55,18 +60,16 @@ def test_anonymous_usage_state(
flags = Flags(run_context)
assert flags.ANONYMOUS_USAGE_STATS == expected_anonymous_usage_stats

def test_empty_user_config_uses_default(self, run_context):
user_config = UserConfig()

def test_empty_user_config_uses_default(self, run_context, user_config):
flags = Flags(run_context, user_config)
assert flags.USE_COLORS == run_context.params["use_colors"]

def test_none_user_config_uses_default(self, run_context):
flags = Flags(run_context, None)
assert flags.USE_COLORS == run_context.params["use_colors"]

def test_prefer_user_config_to_default(self, run_context):
user_config = UserConfig(use_colors=False)
def test_prefer_user_config_to_default(self, run_context, user_config):
user_config.use_colors = False
# ensure default value is not the same as user config
assert run_context.params["use_colors"] is not user_config.use_colors

Expand All @@ -80,10 +83,81 @@ def test_prefer_param_value_to_user_config(self):
flags = Flags(context, user_config)
assert flags.USE_COLORS

def test_prefer_env_to_user_config(self, monkeypatch):
user_config = UserConfig(use_colors=False)
def test_prefer_env_to_user_config(self, monkeypatch, user_config):
user_config.use_colors = False
monkeypatch.setenv("DBT_USE_COLORS", "True")
context = self.make_dbt_context("run", ["run"])

flags = Flags(context, user_config)
assert flags.USE_COLORS

def test_mutually_exclusive_options_passed_separately(self):
"""Assert options that are mutually exclusive can be passed separately without error"""
warn_error_context = self.make_dbt_context("run", ["--warn-error", "run"])

flags = Flags(warn_error_context)
assert flags.WARN_ERROR

warn_error_options_context = self.make_dbt_context(
"run", ["--warn-error-options", '{"include": "all"}', "run"]
)
flags = Flags(warn_error_options_context)
assert flags.WARN_ERROR_OPTIONS == WarnErrorOptions(include="all")

def test_mutually_exclusive_options_from_cli(self):
context = self.make_dbt_context(
"run", ["--warn-error", "--warn-error-options", '{"include": "all"}', "run"]
)

with pytest.raises(click.BadOptionUsage):
Flags(context)

@pytest.mark.parametrize("warn_error", [True, False])
def test_mutually_exclusive_options_from_user_config(self, warn_error, user_config):
user_config.warn_error = warn_error
context = self.make_dbt_context(
"run", ["--warn-error-options", '{"include": "all"}', "run"]
)

with pytest.raises(click.BadOptionUsage):
Flags(context, user_config)

@pytest.mark.parametrize("warn_error", ["True", "False"])
def test_mutually_exclusive_options_from_envvar(self, warn_error, monkeypatch):
monkeypatch.setenv("DBT_WARN_ERROR", warn_error)
monkeypatch.setenv("DBT_WARN_ERROR_OPTIONS", '{"include":"all"}')
context = self.make_dbt_context("run", ["run"])

with pytest.raises(click.BadOptionUsage):
Flags(context)

@pytest.mark.parametrize("warn_error", [True, False])
def test_mutually_exclusive_options_from_cli_and_user_config(self, warn_error, user_config):
user_config.warn_error = warn_error
context = self.make_dbt_context(
"run", ["--warn-error-options", '{"include": "all"}', "run"]
)

with pytest.raises(click.BadOptionUsage):
Flags(context, user_config)

@pytest.mark.parametrize("warn_error", ["True", "False"])
def test_mutually_exclusive_options_from_cli_and_envvar(self, warn_error, monkeypatch):
monkeypatch.setenv("DBT_WARN_ERROR", warn_error)
context = self.make_dbt_context(
"run", ["--warn-error-options", '{"include": "all"}', "run"]
)

with pytest.raises(click.BadOptionUsage):
Flags(context)

@pytest.mark.parametrize("warn_error", ["True", "False"])
def test_mutually_exclusive_options_from_user_config_and_envvar(
self, user_config, warn_error, monkeypatch
):
user_config.warn_error = warn_error
monkeypatch.setenv("DBT_WARN_ERROR_OPTIONS", '{"include": "all"}')
context = self.make_dbt_context("run", ["run"])

with pytest.raises(click.BadOptionUsage):
Flags(context, user_config)
4 changes: 4 additions & 0 deletions tests/unit/test_dbt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def test_command_invalid_option(self, dbt: dbtRunner) -> None:
with pytest.raises(dbtUsageException):
dbt.invoke(["deps", "--invalid-option"])

def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None:
with pytest.raises(dbtUsageException):
dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"])

def test_invalid_command(self, dbt: dbtRunner) -> None:
with pytest.raises(dbtUsageException):
dbt.invoke(["invalid-command"])
Expand Down

0 comments on commit 2ac3d38

Please sign in to comment.