Skip to content

Commit

Permalink
Fix generic test not null and unique custom configs (#11208)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Jan 15, 2025
1 parent 8a8857a commit 3de3b82
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 202 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20250110-155824.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix for custom fields in generic test config for not_null and unique tests
time: 2025-01-10T15:58:24.479245-05:00
custom:
Author: gshank
Issue: "11208"
86 changes: 0 additions & 86 deletions core/dbt/parser/generic_test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,92 +229,6 @@ def extract_test_args(data_test, name=None) -> Tuple[str, Dict[str, Any]]:
test_args["column_name"] = name
return test_name, test_args

@property
def enabled(self) -> Optional[bool]:
return self.config.get("enabled")

@property
def alias(self) -> Optional[str]:
return self.config.get("alias")

@property
def severity(self) -> Optional[str]:
sev = self.config.get("severity")
if sev:
return sev.upper()
else:
return None

@property
def store_failures(self) -> Optional[bool]:
return self.config.get("store_failures")

@property
def store_failures_as(self) -> Optional[bool]:
return self.config.get("store_failures_as")

@property
def where(self) -> Optional[str]:
return self.config.get("where")

@property
def limit(self) -> Optional[int]:
return self.config.get("limit")

@property
def warn_if(self) -> Optional[str]:
return self.config.get("warn_if")

@property
def error_if(self) -> Optional[str]:
return self.config.get("error_if")

@property
def fail_calc(self) -> Optional[str]:
return self.config.get("fail_calc")

@property
def meta(self) -> Optional[dict]:
return self.config.get("meta")

@property
def database(self) -> Optional[str]:
return self.config.get("database")

@property
def schema(self) -> Optional[str]:
return self.config.get("schema")

def get_static_config(self):
config = {}
if self.alias is not None:
config["alias"] = self.alias
if self.severity is not None:
config["severity"] = self.severity
if self.enabled is not None:
config["enabled"] = self.enabled
if self.where is not None:
config["where"] = self.where
if self.limit is not None:
config["limit"] = self.limit
if self.warn_if is not None:
config["warn_if"] = self.warn_if
if self.error_if is not None:
config["error_if"] = self.error_if
if self.fail_calc is not None:
config["fail_calc"] = self.fail_calc
if self.store_failures is not None:
config["store_failures"] = self.store_failures
if self.store_failures_as is not None:
config["store_failures_as"] = self.store_failures_as
if self.meta is not None:
config["meta"] = self.meta
if self.database is not None:
config["database"] = self.database
if self.schema is not None:
config["schema"] = self.schema
return config

def tags(self) -> List[str]:
tags = self.config.get("tags", [])
if isinstance(tags, str):
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/schema_generic_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def render_test_update(self, node, config, builder, schema_file_id):
# to the context in rendering processing
node.depends_on.add_macro(macro_unique_id)
if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
config_call_dict = builder.get_static_config()
config_call_dict = builder.config
config._config_call_dict = config_call_dict
# This sets the config from dbt_project
self.update_parsed_node_config(node, config)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/partial_parsing/test_pp_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_env_vars_models(self, project):
schema_file = manifest.files[source.file_id]
test_id = "test.test.source_not_null_seed_sources_raw_customers_id.e39ee7bf0d"
test_node = manifest.nodes[test_id]
assert test_node.config.severity == "WARN"
assert test_node.config.severity == "warn"

# Change severity env var
os.environ["ENV_VAR_SEVERITY"] = "error"
Expand All @@ -125,7 +125,7 @@ def test_env_vars_models(self, project):
}
assert expected_schema_file_env_vars == schema_file.env_vars
test_node = manifest.nodes[test_id]
assert test_node.config.severity == "ERROR"
assert test_node.config.severity == "error"

# Change database env var
os.environ["ENV_VAR_DATABASE"] = "test_dbt"
Expand Down
77 changes: 0 additions & 77 deletions tests/functional/schema_tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,80 +1273,3 @@
data_tests:
- my_custom_test
"""

custom_config_yml = """
version: 2
models:
- name: table
columns:
- name: color
data_tests:
- accepted_values:
values: ['blue', 'red']
config:
custom_config_key: some_value
- custom_color_from_config:
severity: error
config:
test_color: orange
store_failures: true
unlogged: True
"""

mixed_config_yml = """
version: 2
models:
- name: table
columns:
- name: color
data_tests:
- accepted_values:
values: ['blue', 'red']
config:
custom_config_key: some_value
severity: warn
- custom_color_from_config:
severity: error
config:
test_color: blue
"""

same_key_error_yml = """
version: 2
models:
- name: table
columns:
- name: color
data_tests:
- accepted_values:
values: ['blue', 'red']
severity: warn
config:
severity: error
"""

seed_csv = """
id,color,value
1,blue,10
2,red,20
3,green,30
4,yellow,40
5,blue,50
6,red,60
7,blue,70
8,green,80
9,yellow,90
10,blue,100
""".strip()

table_sql = """
-- content of the table.sql
select * from {{ ref('seed') }}
"""

test_custom_color_from_config = """
{% test custom_color_from_config(model, column_name) %}
select * from {{ model }}
where color = '{{ config.get('test_color') }}'
{% endtest %}
"""
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,84 @@
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import TestNode
from dbt.exceptions import CompilationError
from dbt.tests.util import get_manifest, run_dbt
from tests.functional.schema_tests.fixtures import (
custom_config_yml,
mixed_config_yml,
same_key_error_yml,
seed_csv,
table_sql,
test_custom_color_from_config,
)
from dbt.tests.util import get_manifest, run_dbt, update_config_file

custom_config_yml = """
models:
- name: table
columns:
- name: color
data_tests:
- accepted_values:
values: ['blue', 'red']
config:
custom_config_key: some_value
- custom_color_from_config:
severity: error
config:
test_color: orange
store_failures: true
unlogged: True
- not_null:
config:
not_null_key: abc
"""

mixed_config_yml = """
models:
- name: table
columns:
- name: color
data_tests:
- accepted_values:
values: ['blue', 'red']
config:
custom_config_key: some_value
severity: warn
- custom_color_from_config:
severity: error
config:
test_color: blue
"""

same_key_error_yml = """
models:
- name: table
columns:
- name: color
data_tests:
- accepted_values:
values: ['blue', 'red']
severity: warn
config:
severity: error
"""

seed_csv = """
id,color,value
1,blue,10
2,red,20
3,green,30
4,yellow,40
5,blue,50
6,red,60
7,blue,70
8,green,80
9,yellow,90
10,blue,100
""".strip()

table_sql = """
-- content of the table.sql
select * from {{ ref('seed') }}
"""

test_custom_color_from_config = """
{% test custom_color_from_config(model, column_name) %}
select * from {{ model }}
where color = '{{ config.get('test_color') }}'
{% endtest %}
"""


def _select_test_node(manifest: Manifest, pattern: re.Pattern[str]):
Expand Down Expand Up @@ -50,12 +119,6 @@ def seeds(self):
def macros(self):
return {"custom_color_from_config.sql": test_custom_color_from_config}

@pytest.fixture(scope="class")
def project_config_update(self):
return {
"config-version": 2,
}

@pytest.fixture(scope="class", autouse=True)
def setUp(self, project):
run_dbt(["seed"])
Expand All @@ -68,7 +131,7 @@ def models(self):

def test_custom_config(self, project):
run_dbt(["run"])
run_dbt(["test", "--log-level", "debug"], expect_pass=False)
run_dbt(["test"], expect_pass=False)

manifest = get_manifest(project.project_root)
# Pattern to match the test_id without the specific suffix
Expand All @@ -79,12 +142,29 @@ def test_custom_config(self, project):
assert "custom_config_key" in test_node.config
assert test_node.config["custom_config_key"] == "some_value"

# pattern = re.compile(r"test\.test\.custom_color_from_config.*")
# test_node = _select_test_node(manifest, pattern)
custom_color_pattern = re.compile(r"test\.test\.custom_color_from_config.*")
custom_color_test_node = _select_test_node(manifest, custom_color_pattern)
assert custom_color_test_node.config.get("test_color") == "orange"
assert custom_color_test_node.config.get("unlogged") is True
persistence = get_table_persistence(project, "custom_color_from_config_table_color")

assert persistence == "u"

not_null_pattern = re.compile(r"test\.test\.not_null.*")
not_null_test_node = _select_test_node(manifest, not_null_pattern)
assert not_null_test_node.config.get("not_null_key") == "abc"

# set dbt_project.yml config and ensure that schema configs override project configs
config_patch = {
"data_tests": {"test_color": "blue", "some_key": "strange", "not_null_key": "def"}
}
update_config_file(config_patch, project.project_root, "dbt_project.yml")
manifest = run_dbt(["parse"])
custom_color_test_node = _select_test_node(manifest, custom_color_pattern)
assert custom_color_test_node.config.get("test_color") == "orange"
assert custom_color_test_node.config.get("some_key") == "strange"
not_null_test_node = _select_test_node(manifest, not_null_pattern)
assert not_null_test_node.config.get("not_null_key") == "abc"


class TestMixedDataTestConfig(BaseDataTestsConfig):
@pytest.fixture(scope="class")
Expand All @@ -97,21 +177,8 @@ def test_mixed_config(self, project):

# Pattern to match the test_id without the specific suffix
pattern = re.compile(r"test\.test\.accepted_values_table_color__blue__red\.\d+")
test_node = _select_test_node(manifest, pattern)

# Find the test_id dynamically
test_id = None
for node_id in manifest.nodes:
if pattern.match(node_id):
test_id = node_id
break

# Ensure the test_id was found
assert (
test_id is not None
), "Test ID matching the pattern was not found in the manifest nodes"

# Proceed with the assertions
test_node = manifest.nodes[test_id]
assert "custom_config_key" in test_node.config
assert test_node.config["custom_config_key"] == "some_value"
assert "severity" in test_node.config
Expand All @@ -137,6 +204,3 @@ def test_same_key_error(self, project):

# Assert that the error message contains the expected text
assert "Test cannot have the same key at the top-level and in config" in exception_message

# Assert that the error message contains the context of the error
assert "models/same_key_error.yml" in exception_message

0 comments on commit 3de3b82

Please sign in to comment.