Skip to content

Commit

Permalink
Fix #8652: Use seed file from disk for unit testing if rows not speci…
Browse files Browse the repository at this point in the history
…fied in YAML config (#9064)

Co-authored-by: Michelle Ark <MichelleArk@users.noreply.github.com>
Fix #8652: Use seed value if rows not specified
  • Loading branch information
aranke authored Nov 16, 2023
1 parent 35f579e commit 3432436
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231113-154535.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Use seed file from disk for unit testing if rows not specified in YAML config
time: 2023-11-13T15:45:35.008565Z
custom:
Author: aranke
Issue: "8652"
38 changes: 36 additions & 2 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from csv import DictReader
from pathlib import Path
from typing import List, Set, Dict, Any

from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore

from dbt.config import RuntimeConfig
from dbt.context.context_config import ContextConfig
from dbt.context.providers import generate_parse_exposure, get_rendered
Expand Down Expand Up @@ -28,7 +32,6 @@
ParseResult,
)
from dbt.utils import get_pseudo_test_path
from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore


class UnitTestManifestLoader:
Expand Down Expand Up @@ -130,7 +133,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
),
}

if original_input_node.resource_type == NodeType.Model:
if original_input_node.resource_type in (NodeType.Model, NodeType.Seed):
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
Expand Down Expand Up @@ -219,6 +222,35 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None:
self.schema_parser = schema_parser
self.yaml = yaml

def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
"""Read rows from seed file on disk if not specified in YAML config. If seed file doesn't exist, return empty list."""
ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]

rows: List[Dict[str, Any]] = []

seed_name = ref["name"]
package_name = ref.get("package", self.project.project_name)

seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)

if not seed_node or seed_node.resource_type != NodeType.Seed:
# Seed not found in custom package specified
if package_name != self.project.project_name:
raise ParsingError(

Check warning on line 239 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L239

Added line #L239 was not covered by tests
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
)
else:
raise ParsingError(
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
)

seed_path = Path(seed_node.root_path) / seed_node.original_file_path
with open(seed_path, "r") as f:
for row in DictReader(f):
rows.append(row)

return rows

def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
Expand All @@ -232,6 +264,8 @@ def parse(self) -> ParseResult:

# Check that format and type of rows matches for each given input
for input in unit_test.given:
if input.rows is None and input.fixture is None:
input.rows = self._load_rows_from_seed(input.input)
input.validate_fixture("input", unit_test.name)
unit_test.expect.validate_fixture("expected", unit_test.name)

Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
black==23.3.0
bumpversion
ddtrace
ddtrace==2.1.7
docutils
flake8
flaky
Expand Down
134 changes: 133 additions & 1 deletion tests/functional/unit_testing/test_unit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
get_manifest,
get_artifact,
)
from dbt.exceptions import DuplicateResourceNameError
from dbt.exceptions import DuplicateResourceNameError, ParsingError
from fixtures import (
my_model_vars_sql,
my_model_a_sql,
Expand Down Expand Up @@ -105,3 +105,135 @@ def test_basic(self, project):
# Select by model name
results = run_dbt(["unit-test", "--select", "my_incremental_model"], expect_pass=True)
assert len(results) == 2


my_new_model = """
select
my_favorite_seed.id,
a + b as c
from {{ ref('my_favorite_seed') }} as my_favorite_seed
inner join {{ ref('my_favorite_model') }} as my_favorite_model
on my_favorite_seed.id = my_favorite_model.id
"""

my_favorite_model = """
select
2 as id,
3 as b
"""

seed_my_favorite_seed = """id,a
1,5
2,4
3,3
4,2
5,1
"""

schema_yml_explicit_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_favorite_seed')
rows:
- {id: 1, a: 10}
- input: ref('my_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 12}
"""

schema_yml_implicit_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_favorite_seed')
- input: ref('my_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 7}
"""

schema_yml_nonexistent_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_second_favorite_seed')
- input: ref('my_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 7}
"""


class TestUnitTestExplicitSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}

@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_favorite_model.sql": my_favorite_model,
"schema.yml": schema_yml_explicit_seed,
}

def test_explicit_seed(self, project):
run_dbt(["seed"])
run_dbt(["run"])

# Select by model name
results = run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=True)
assert len(results) == 1


class TestUnitTestImplicitSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}

@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_favorite_model.sql": my_favorite_model,
"schema.yml": schema_yml_implicit_seed,
}

def test_implicit_seed(self, project):
run_dbt(["seed"])
run_dbt(["run"])

# Select by model name
results = run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=True)
assert len(results) == 1


class TestUnitTestNonexistentSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}

@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_favorite_model.sql": my_favorite_model,
"schema.yml": schema_yml_nonexistent_seed,
}

def test_nonexistent_seed(self, project):
with pytest.raises(
ParsingError, match="Unable to find seed 'test.my_second_favorite_seed' for unit tests"
):
run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=False)

0 comments on commit 3432436

Please sign in to comment.