diff --git a/.changes/unreleased/Fixes-20231113-154535.yaml b/.changes/unreleased/Fixes-20231113-154535.yaml
new file mode 100644
index 00000000000..13b900ec2dd
--- /dev/null
+++ b/.changes/unreleased/Fixes-20231113-154535.yaml
@@ -0,0 +1,6 @@
+kind: Fixes
+body: Use seed value if rows not specified
+time: 2023-11-13T15:45:35.008565Z
+custom:
+  Author: aranke
+  Issue: "8652"
diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py
index c93f70b2997..91a977400b8 100644
--- a/core/dbt/parser/unit_tests.py
+++ b/core/dbt/parser/unit_tests.py
@@ -1,5 +1,9 @@
+from csv import DictReader
+from pathlib import Path
 from typing import List, Set, Dict, Any
 
+from dbt_extractor import py_extract_from_source, ExtractionError  # type: ignore
+
 from dbt.config import RuntimeConfig
 from dbt.context.context_config import ContextConfig
 from dbt.context.providers import generate_parse_exposure, get_rendered
@@ -14,7 +18,7 @@
     UnitTestConfig,
 )
 from dbt.contracts.graph.unparsed import UnparsedUnitTest
-from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput
+from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput, DbtInternalError
 from dbt.graph import UniqueId
 from dbt.node_types import NodeType
 from dbt.parser.schemas import (
@@ -27,7 +31,6 @@
     ParseResult,
 )
 from dbt.utils import get_pseudo_test_path
-from dbt_extractor import py_extract_from_source, ExtractionError  # type: ignore
 
 
 class UnitTestManifestLoader:
@@ -203,6 +206,44 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None:
         self.schema_parser = schema_parser
         self.yaml = yaml
 
+    def load_rows_from_seed(self, seed_name: str) -> List[Dict[str, Any]]:
+        """Return the rows of the seed named ``seed_name`` as a list of dicts.
+
+        The seed is looked up in the current project and its CSV file is read
+        with ``csv.DictReader``, so every row maps a column header to the raw
+        string value. An empty list is returned when the lookup raises
+        ``DbtInternalError``; any other error (for example a seed file that
+        cannot be opened) propagates to the caller.
+        """
+        try:
+            seed_node = self.manifest.ref_lookup.perform_lookup(
+                f"seed.{self.project.project_name}.{seed_name}", self.manifest
+            )
+        except DbtInternalError:
+            # Best effort: the lookup failed, so there are no rows to borrow.
+            return []
+
+        seed_path = Path(seed_node.root_path) / seed_node.original_file_path
+        # newline="" is what the csv module asks for; utf-8-sig decodes UTF-8
+        # and drops a leading byte-order mark if one is present.
+        with open(seed_path, "r", newline="", encoding="utf-8-sig") as f:
+            return list(DictReader(f))
+
+    @staticmethod
+    def _first_quoted_arg(expression: str) -> str:
+        """Return the first quoted string in ``expression``, or "" if none.
+
+        ``"ref('my_seed')"`` and ``'ref("my_seed")'`` both give ``"my_seed"``;
+        an input without a complete quoted string, such as ``"this"``, gives
+        ``""`` instead of raising ``IndexError``.
+        """
+        quote_starts = [i for i in (expression.find("'"), expression.find('"')) if i != -1]
+        if not quote_starts:
+            return ""
+        start = min(quote_starts)
+        end = expression.find(expression[start], start + 1)
+        return expression[start + 1 : end] if end != -1 else ""
+
     def parse(self) -> ParseResult:
         for data in self.get_key_dicts():
             unit_test = self._get_unit_test(data)
@@ -214,8 +255,12 @@ def parse(self) -> ParseResult:
             unit_test_fqn = [self.project.project_name] + model_name_split + [unit_test.name]
             unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)
 
             # Check that format and type of rows matches for each given input
             for input in unit_test.given:
+                if input.rows is None:
+                    # No inline rows: fall back to the referenced seed's CSV.
+                    seed_name = self._first_quoted_arg(input.input)
+                    input.rows = self.load_rows_from_seed(seed_name) if seed_name else []
                 input.validate_fixture("input", unit_test.name)
             unit_test.expect.validate_fixture("expected", unit_test.name)
 
diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py
index 2a631c23efe..e475f162e96 100644
--- a/tests/functional/unit_testing/test_unit_testing.py
+++ b/tests/functional/unit_testing/test_unit_testing.py
@@ -100,3 +100,64 @@ def test_basic(self, project):
         # Select by model name
         results = run_dbt(["unit-test", "--select", "my_incremental_model"], expect_pass=True)
         assert len(results) == 2
+
+
+my_new_model = """
+select
+my_favorite_seed.id,
+a + b as c
+from {{ ref('my_favorite_seed') }} as my_favorite_seed
+inner join {{ ref('my_second_favorite_model') }} as my_second_favorite_model
+on my_favorite_seed.id = my_second_favorite_model.id
+"""
+
+my_second_favorite_model = """
+select
+2 as id,
+3 as b
+"""
+
+seed_my_favorite_seed = """id,a
+1,5
+2,4
+3,3
+4,2
+5,1
+"""
+
+test_my_model_implicit_seed = """
+unit_tests:
+  - name: t
+    model: my_new_model
+    given:
+      - input: ref('my_favorite_seed')
+      - input: ref('my_second_favorite_model')
+        rows:
+          - {id: 1, b: 2}
+    expect:
+      rows:
+        - {id: 1, c: 7}
+"""
+
+
+class TestUnitTestImplicitSeed:
+    """A given input without ``rows`` takes its rows from the referenced seed."""
+    @pytest.fixture(scope="class")
+    def seeds(self):
+        return {"my_favorite_seed.csv": seed_my_favorite_seed}
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "my_new_model.sql": my_new_model,
+            "my_second_favorite_model.sql": my_second_favorite_model,
+            "schema.yml": test_my_model_implicit_seed,
+        }
+
+    def test_basic(self, project):
+        run_dbt(["seed"])
+        run_dbt(["run"])
+
+        # Select by model name
+        results = run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=True)
+        assert len(results) == 1