Skip to content

Commit

Permalink
Move unit testing to test and build commands (#9108)
Browse files Browse the repository at this point in the history
* Switch to using 'test' command instead of 'unit-test'

* Remove old unit test

* Put daff changes into task/test.py

* changie

* test_type:unit

* Add unit test to build and make test

* Select test_type:data

* Add unit tests to test_graph_selector_methods.py

* Fix fqn to include path components

* Update build test

* Remove part of message in test_csv_fixtures.py that's different on Windows

* Rename build test directory
  • Loading branch information
gshank authored Nov 27, 2023
1 parent e001991 commit 3f1ed23
Show file tree
Hide file tree
Showing 26 changed files with 381 additions and 431 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231116-144006.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Move unit testing to test command
time: 2023-11-16T14:40:06.121336-05:00
custom:
Author: gshank
Issue: "8979"
1 change: 0 additions & 1 deletion core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ def command_args(command: CliCommand) -> ArgsList:
CliCommand.SOURCE_FRESHNESS: cli.freshness,
CliCommand.TEST: cli.test,
CliCommand.RETRY: cli.retry,
CliCommand.UNIT_TEST: cli.unit_test,
}
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
if click_cmd is None:
Expand Down
45 changes: 0 additions & 45 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask


@dataclass
Expand Down Expand Up @@ -897,50 +896,6 @@ def test(ctx, **kwargs):
return results, success


# dbt unit-test
# Click entry point for the `unit-test` subcommand registered on the `cli` group.
# The `p.*` decorators attach the command-line options accepted by the command.
# The `requires.*` decorators presumably populate ctx.obj["flags"],
# ctx.obj["runtime_config"] and ctx.obj["manifest"] before the body runs
# -- TODO confirm against dbt.cli.requires.
@cli.command("unit-test")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.exclude
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection
@p.show_output_format
@p.profile
@p.profiles_dir
@p.project_dir
@p.select
@p.selector
@p.state
@p.defer_state
@p.deprecated_state
@p.store_failures
@p.target
@p.target_path
@p.threads
@p.vars
@p.version_check
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def unit_test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
# NOTE(review): the help text above does not mention unit tests -- confirm the wording.
# Build the task from the flags, runtime config and manifest stored on the click context.
task = UnitTestTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
ctx.obj["manifest"],
)

# Run the task, then reduce its results to a single success value.
results = task.run()
success = task.interpret_results(results)
# Callers receive both the raw results and the success value.
return results, success


# Support running as a module
if __name__ == "__main__":
cli()
1 change: 0 additions & 1 deletion core/dbt/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class Command(Enum):
SOURCE_FRESHNESS = "freshness"
TEST = "test"
RETRY = "retry"
UNIT_TEST = "unit-test"

@classmethod
def from_str(cls, s: str) -> "Command":
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
InjectedCTE,
SeedNode,
UnitTestNode,
UnitTestDefinition,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
Expand Down Expand Up @@ -539,6 +540,9 @@ def compile_node(
the node's raw_code into compiled_code, and then calls the
recursive method to "prepend" the ctes.
"""
if isinstance(node, UnitTestDefinition):
return node

# Make sure Lexer for sqlparse 0.4.4 is initialized
from sqlparse.lexer import Lexer # type: ignore

Expand Down
1 change: 0 additions & 1 deletion core/dbt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@
MANIFEST_FILE_NAME = "manifest.json"
SEMANTIC_MANIFEST_FILE_NAME = "semantic_manifest.json"
PARTIAL_PARSE_FILE_NAME = "partial_parse.msgpack"
UNIT_TEST_MANIFEST_FILE_NAME = "unit_test_manifest.json"
PACKAGE_LOCK_HASH_KEY = "sha1_hash"
17 changes: 16 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,15 +1082,30 @@ class UnitTestNode(CompiledNode):


@dataclass
class UnitTestDefinition(GraphNode):
class UnitTestDefinitionMandatory:
model: str
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture


@dataclass
class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory):
description: str = ""
overrides: Optional[UnitTestOverrides] = None
depends_on: DependsOn = field(default_factory=DependsOn)
config: UnitTestConfig = field(default_factory=UnitTestConfig)
checksum: Optional[str] = None
schema: Optional[str] = None

@property
def build_path(self):
# TODO: is this actually necessary?
return self.original_file_path

@property
def compiled_path(self):
# TODO: is this actually necessary?
return self.original_file_path

@property
def depends_on_nodes(self):
Expand Down
9 changes: 7 additions & 2 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import threading

from dbt.contracts.graph.unparsed import FreshnessThreshold
from dbt.contracts.graph.nodes import CompiledNode, SourceDefinition, ResultNode
from dbt.contracts.graph.nodes import (
CompiledNode,
SourceDefinition,
ResultNode,
UnitTestDefinition,
)
from dbt.contracts.util import (
BaseArtifactMetadata,
ArtifactMixin,
Expand Down Expand Up @@ -153,7 +158,7 @@ def to_msg_dict(self):

@dataclass
class NodeResult(BaseResult):
node: ResultNode
node: Union[ResultNode, UnitTestDefinition]


@dataclass
Expand Down
20 changes: 12 additions & 8 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,20 +552,24 @@ class TestTypeSelectorMethod(SelectorMethod):
__test__ = False

def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
    """Yield the unique ids of included nodes whose test type matches ``selector``.

    Accepted selector values:
      * ``"generic"`` / ``"schema"`` -- generic tests
      * ``"singular"`` -- singular tests
      * ``"data"`` -- generic and singular tests
      * ``"unit"`` -- unit test definitions

    :param included_nodes: unique ids to search within.
    :param selector: the test type to match, one of the values above.
    :raises DbtRuntimeError: if ``selector`` is not one of the accepted values.
    """
    search_types: List[Any]
    # continue supporting 'schema' + 'data' for backwards compatibility
    if selector in ("generic", "schema"):
        search_types = [GenericTestNode]
    # The single-value branches compare with ``==``.  ``selector in ("data")``
    # is a substring test against the plain string "data" (parentheses without
    # a trailing comma do not create a tuple), so values such as "a", "dat" or
    # "" would have been accepted instead of reaching the error below.
    elif selector == "data":
        search_types = [GenericTestNode, SingularTestNode]
    elif selector == "singular":
        search_types = [SingularTestNode]
    elif selector == "unit":
        search_types = [UnitTestDefinition]
    else:
        raise DbtRuntimeError(
            f'Invalid test type selector {selector}: expected "generic", "singular", "unit", or "data"'
        )

    # isinstance needs a tuple of classes; build it once rather than per node.
    wanted = tuple(search_types)
    for unique_id, node in self.parsed_and_unit_nodes(included_nodes):
        if isinstance(node, wanted):
            yield unique_id


class StateSelectorMethod(SelectorMethod):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
{%- endif -%}

{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("columns not available for " ~ model.name) }}
{{ exceptions.raise_compiler_error("Not able to get columns for unit test '" ~ model.name ~ "' from relation " ~ this) }}
{%- endif -%}

{%- for column_name, column_type in column_name_to_data_types.items() -%}
Expand Down
7 changes: 1 addition & 6 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
MANIFEST_FILE_NAME,
PARTIAL_PARSE_FILE_NAME,
SEMANTIC_MANIFEST_FILE_NAME,
UNIT_TEST_MANIFEST_FILE_NAME,
)
from dbt.helper_types import PathSet
from dbt.events.functions import fire_event, get_invocation_id, warn_or_error
Expand Down Expand Up @@ -1767,11 +1766,7 @@ def write_semantic_manifest(manifest: Manifest, target_path: str) -> None:


def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] = None):
if which and which == "unit-test":
file_name = UNIT_TEST_MANIFEST_FILE_NAME
else:
file_name = MANIFEST_FILE_NAME

file_name = MANIFEST_FILE_NAME
path = os.path.join(target_path, file_name)
manifest.write(path)

Expand Down
36 changes: 29 additions & 7 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from csv import DictReader
from pathlib import Path
from typing import List, Set, Dict, Any
import os

from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore

from dbt import utils
from dbt.config import RuntimeConfig
from dbt.context.context_config import ContextConfig
from dbt.context.providers import generate_parse_exposure, get_rendered
Expand Down Expand Up @@ -44,9 +46,9 @@ def __init__(self, manifest, root_project, selected) -> None:

def load(self) -> Manifest:
    """Parse each selected unit test and return the unit-test manifest.

    Ids in ``self.selected`` that are not keys of ``self.manifest.unit_tests``
    are skipped.

    :returns: ``self.unit_test_manifest`` after all selected cases were parsed.
    """
    unit_tests = self.manifest.unit_tests
    for unique_id in self.selected:
        # Guarded lookup only: an unguarded ``unit_tests[unique_id]`` raises
        # KeyError for a selected id that is not a unit test, and keeping both
        # the guarded and unguarded forms would parse every case twice.
        if unique_id in unit_tests:
            self.parse_unit_test_case(unit_tests[unique_id])
    return self.unit_test_manifest

def parse_unit_test_case(self, test_case: UnitTestDefinition):
Expand Down Expand Up @@ -86,7 +88,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
overrides=test_case.overrides,
)

# TODO: generalize this method
ctx = generate_parse_exposure(
unit_test_node, # type: ignore
self.root_project,
Expand Down Expand Up @@ -133,7 +134,11 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
),
}

if original_input_node.resource_type in (NodeType.Model, NodeType.Seed):
if original_input_node.resource_type in (
NodeType.Model,
NodeType.Seed,
NodeType.Snapshot,
):
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
Expand Down Expand Up @@ -254,12 +259,16 @@ def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
model_name_split = unit_test.model.split()
tested_model_node = self._find_tested_model_node(unit_test)
unit_test_case_unique_id = (
f"{NodeType.Unit}.{self.project.project_name}.{unit_test.model}.{unit_test.name}"
)
unit_test_fqn = [self.project.project_name] + model_name_split + [unit_test.name]
unit_test_fqn = self._build_fqn(
self.project.project_name,
self.yaml.path.original_file_path,
unit_test.model,
unit_test.name,
)
unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)

# Check that format and type of rows matches for each given input
Expand All @@ -284,6 +293,7 @@ def parse(self) -> ParseResult:
depends_on=DependsOn(nodes=[tested_model_node.unique_id]),
fqn=unit_test_fqn,
config=unit_test_config,
schema=tested_model_node.schema,
)
# for calculating state:modified
unit_test_definition.build_unit_test_checksum(
Expand Down Expand Up @@ -329,3 +339,15 @@ def _build_unit_test_config(
unit_test_config_dict = self.render_entry(unit_test_config_dict)

return UnitTestConfig.from_dict(unit_test_config_dict)

def _build_fqn(self, package_name, original_file_path, model_name, test_name):
    """Assemble the fully-qualified name (fqn) list for a unit test.

    The result is ``[package_name, <path components>, model_name, test_name]``.
    The path components come from ``original_file_path`` with its first
    component and its file extension removed, and the last element of the
    split path dropped.
    """
    # This code comes from "get_fqn" and "get_fqn_prefix" in the base parser.
    # We need to get the directories underneath the model-path.
    source_path = Path(original_file_path)
    first_component = source_path.parts[:1]
    below_root = str(source_path.relative_to(*first_component))
    stem, _extension = os.path.splitext(below_root)
    directories = utils.split_path(stem)[:-1]
    return [package_name, *directories, model_name, test_name]
2 changes: 1 addition & 1 deletion core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def compile_and_execute(self, manifest, ctx):
with collect_timing_info("compile", ctx.timing.append):
# if we fail here, we still have a compiled node to return
# this has the benefit of showing a build path for the errant
# model
# model. This calls the 'compile' method in CompileTask
ctx.node = self.compile(manifest)

# for ephemeral nodes, we only want to compile, not run
Expand Down
1 change: 1 addition & 0 deletions core/dbt/task/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class BuildTask(RunTask):
NodeType.Snapshot: snapshot_model_runner,
NodeType.Seed: seed_runner,
NodeType.Test: test_runner,
NodeType.Unit: test_runner,
}
ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()})

Expand Down
3 changes: 0 additions & 3 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
OVERRIDE_PARENT_FLAGS = {
Expand All @@ -41,7 +40,6 @@
"test": TestTask,
"run": RunTask,
"run-operation": RunOperationTask,
"unit-test": UnitTestTask,
}

CMD_DICT = {
Expand All @@ -54,7 +52,6 @@
"test": CliCommand.TEST,
"run": CliCommand.RUN,
"run-operation": CliCommand.RUN_OPERATION,
"unit-test": CliCommand.UNIT_TEST,
}


Expand Down
Loading

0 comments on commit 3f1ed23

Please sign in to comment.