Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move unit testing to test and build commands #9108

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231116-144006.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Move unit testing to test command
time: 2023-11-16T14:40:06.121336-05:00
custom:
Author: gshank
Issue: "8979"
1 change: 0 additions & 1 deletion core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ def command_args(command: CliCommand) -> ArgsList:
CliCommand.SOURCE_FRESHNESS: cli.freshness,
CliCommand.TEST: cli.test,
CliCommand.RETRY: cli.retry,
CliCommand.UNIT_TEST: cli.unit_test,
}
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
if click_cmd is None:
Expand Down
46 changes: 1 addition & 45 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask


@dataclass
Expand Down Expand Up @@ -870,6 +869,7 @@ def freshness(ctx, **kwargs):
@p.project_dir
@p.select
@p.selector
@p.show_output_format
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
@p.state
@p.defer_state
@p.deprecated_state
Expand Down Expand Up @@ -897,50 +897,6 @@ def test(ctx, **kwargs):
return results, success


# dbt unit-test
@cli.command("unit-test")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.exclude
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection
@p.show_output_format
@p.profile
@p.profiles_dir
@p.project_dir
@p.select
@p.selector
@p.state
@p.defer_state
@p.deprecated_state
@p.store_failures
@p.target
@p.target_path
@p.threads
@p.vars
@p.version_check
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def unit_test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
task = UnitTestTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
ctx.obj["manifest"],
)

results = task.run()
success = task.interpret_results(results)
return results, success


# Support running as a module
if __name__ == "__main__":
cli()
1 change: 0 additions & 1 deletion core/dbt/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class Command(Enum):
SOURCE_FRESHNESS = "freshness"
TEST = "test"
RETRY = "retry"
UNIT_TEST = "unit-test"

@classmethod
def from_str(cls, s: str) -> "Command":
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
InjectedCTE,
SeedNode,
UnitTestNode,
UnitTestDefinition,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
Expand Down Expand Up @@ -539,6 +540,9 @@ def compile_node(
the node's raw_code into compiled_code, and then calls the
recursive method to "prepend" the ctes.
"""
if isinstance(node, UnitTestDefinition):
return node

# Make sure Lexer for sqlparse 0.4.4 is initialized
from sqlparse.lexer import Lexer # type: ignore

Expand Down
1 change: 0 additions & 1 deletion core/dbt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@
MANIFEST_FILE_NAME = "manifest.json"
SEMANTIC_MANIFEST_FILE_NAME = "semantic_manifest.json"
PARTIAL_PARSE_FILE_NAME = "partial_parse.msgpack"
UNIT_TEST_MANIFEST_FILE_NAME = "unit_test_manifest.json"
PACKAGE_LOCK_HASH_KEY = "sha1_hash"
16 changes: 15 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,16 +1082,30 @@ class UnitTestNode(CompiledNode):


@dataclass
class UnitTestDefinition(GraphNode):
class UnitTestDefinitionMandatory:
model: str
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture


@dataclass
class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory):
description: str = ""
overrides: Optional[UnitTestOverrides] = None
depends_on: DependsOn = field(default_factory=DependsOn)
config: UnitTestConfig = field(default_factory=UnitTestConfig)
checksum: Optional[str] = None

@property
def build_path(self):
# TODO: is this actually necessary?
return self.original_file_path

@property
def compiled_path(self):
# TODO: is this actually necessary?
return self.original_file_path

@property
def depends_on_nodes(self):
return self.depends_on.nodes
Expand Down
9 changes: 7 additions & 2 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import threading

from dbt.contracts.graph.unparsed import FreshnessThreshold
from dbt.contracts.graph.nodes import CompiledNode, SourceDefinition, ResultNode
from dbt.contracts.graph.nodes import (
CompiledNode,
SourceDefinition,
ResultNode,
UnitTestDefinition,
)
from dbt.contracts.util import (
BaseArtifactMetadata,
ArtifactMixin,
Expand Down Expand Up @@ -153,7 +158,7 @@ def to_msg_dict(self):

@dataclass
class NodeResult(BaseResult):
node: ResultNode
node: Union[ResultNode, UnitTestDefinition]


@dataclass
Expand Down
20 changes: 12 additions & 8 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,20 +552,24 @@ class TestTypeSelectorMethod(SelectorMethod):
__test__ = False

def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
search_type: Type
search_types: List[Any]
# continue supporting 'schema' + 'data' for backwards compatibility
if selector in ("generic", "schema"):
search_type = GenericTestNode
elif selector in ("singular", "data"):
search_type = SingularTestNode
search_types = [GenericTestNode]
elif selector in ("data"):
search_types = [GenericTestNode, SingularTestNode]
elif selector in ("singular"):
search_types = [SingularTestNode]
elif selector in ("unit"):
search_types = [UnitTestDefinition]
else:
raise DbtRuntimeError(
f'Invalid test type selector {selector}: expected "generic" or ' '"singular"'
f'Invalid test type selector {selector}: expected "generic", "singular", "unit", or "data"'
)

for node, real_node in self.parsed_nodes(included_nodes):
if isinstance(real_node, search_type):
yield node
for unique_id, node in self.parsed_and_unit_nodes(included_nodes):
if isinstance(node, tuple(search_types)):
yield unique_id


class StateSelectorMethod(SelectorMethod):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
{%- endif -%}

{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("columns not available for " ~ model.name) }}
{{ exceptions.raise_compiler_error("Not able to get columns for unit test '" ~ model.name ~ "' from relation " ~ this) }}
{%- endif -%}

{%- for column_name, column_type in column_name_to_data_types.items() -%}
Expand Down
7 changes: 1 addition & 6 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
MANIFEST_FILE_NAME,
PARTIAL_PARSE_FILE_NAME,
SEMANTIC_MANIFEST_FILE_NAME,
UNIT_TEST_MANIFEST_FILE_NAME,
)
from dbt.helper_types import PathSet
from dbt.events.functions import fire_event, get_invocation_id, warn_or_error
Expand Down Expand Up @@ -1767,11 +1766,7 @@ def write_semantic_manifest(manifest: Manifest, target_path: str) -> None:


def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] = None):
if which and which == "unit-test":
file_name = UNIT_TEST_MANIFEST_FILE_NAME
else:
file_name = MANIFEST_FILE_NAME

file_name = MANIFEST_FILE_NAME
path = os.path.join(target_path, file_name)
manifest.write(path)

Expand Down
29 changes: 23 additions & 6 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from csv import DictReader
from pathlib import Path
from typing import List, Set, Dict, Any
import os

from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore

from dbt import utils
from dbt.config import RuntimeConfig
from dbt.context.context_config import ContextConfig
from dbt.context.providers import generate_parse_exposure, get_rendered
Expand Down Expand Up @@ -44,9 +46,9 @@ def __init__(self, manifest, root_project, selected) -> None:

def load(self) -> Manifest:
for unique_id in self.selected:
unit_test_case = self.manifest.unit_tests[unique_id]
self.parse_unit_test_case(unit_test_case)

if unique_id in self.manifest.unit_tests:
unit_test_case = self.manifest.unit_tests[unique_id]
self.parse_unit_test_case(unit_test_case)
return self.unit_test_manifest

def parse_unit_test_case(self, test_case: UnitTestDefinition):
Expand Down Expand Up @@ -86,7 +88,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
overrides=test_case.overrides,
)

# TODO: generalize this method
ctx = generate_parse_exposure(
unit_test_node, # type: ignore
self.root_project,
Expand Down Expand Up @@ -254,12 +255,16 @@ def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
model_name_split = unit_test.model.split()
tested_model_node = self._find_tested_model_node(unit_test)
unit_test_case_unique_id = (
f"{NodeType.Unit}.{self.project.project_name}.{unit_test.model}.{unit_test.name}"
)
unit_test_fqn = [self.project.project_name] + model_name_split + [unit_test.name]
unit_test_fqn = self._build_fqn(
self.project.project_name,
self.yaml.path.original_file_path,
unit_test.model,
unit_test.name,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took this for a spin locally and it worked as expected! 🎩

--- dbt_project.yml
...
unit-tests:
  jaffle_shop:
    unit_testing:
      my_model:
        +tags: project

schema.yml in models/unit_testing/schema.yml

dbt build -s +my_model,test_type:unit --exclude tag:project

=> finds no unit tests. moving the config around at various levels also worked as expected.

Would be good to add some functional testing around this either in this PR or in a follow-on issue if it balloons the scope too much. We do have tests on the fqn in tests/unit/test_unit_test_parser.py so for the scope of this refactor those are good!

unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)

# Check that format and type of rows matches for each given input
Expand Down Expand Up @@ -329,3 +334,15 @@ def _build_unit_test_config(
unit_test_config_dict = self.render_entry(unit_test_config_dict)

return UnitTestConfig.from_dict(unit_test_config_dict)

def _build_fqn(self, package_name, original_file_path, model_name, test_name):
# This code comes from "get_fqn" and "get_fqn_prefix" in the base parser.
# We need to get the directories underneath the model-path.
path = Path(original_file_path)
relative_path = str(path.relative_to(*path.parts[:1]))
no_ext = os.path.splitext(relative_path)[0]
fqn = [package_name]
fqn.extend(utils.split_path(no_ext)[:-1])
fqn.append(model_name)
fqn.append(test_name)
return fqn
2 changes: 1 addition & 1 deletion core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def compile_and_execute(self, manifest, ctx):
with collect_timing_info("compile", ctx.timing.append):
# if we fail here, we still have a compiled node to return
# this has the benefit of showing a build path for the errant
# model
# model. This calls the 'compile' method in CompileTask
ctx.node = self.compile(manifest)

# for ephemeral nodes, we only want to compile, not run
Expand Down
1 change: 1 addition & 0 deletions core/dbt/task/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class BuildTask(RunTask):
NodeType.Snapshot: snapshot_model_runner,
NodeType.Seed: seed_runner,
NodeType.Test: test_runner,
NodeType.Unit: test_runner,
}
ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()})

Expand Down
3 changes: 0 additions & 3 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
OVERRIDE_PARENT_FLAGS = {
Expand All @@ -41,7 +40,6 @@
"test": TestTask,
"run": RunTask,
"run-operation": RunOperationTask,
"unit-test": UnitTestTask,
}

CMD_DICT = {
Expand All @@ -54,7 +52,6 @@
"test": CliCommand.TEST,
"run": CliCommand.RUN,
"run-operation": CliCommand.RUN_OPERATION,
"unit-test": CliCommand.UNIT_TEST,
}


Expand Down
13 changes: 4 additions & 9 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,20 +141,13 @@ def get_graph_queue(self) -> GraphQueue:
spec = self.get_selection_spec()
return selector.get_graph_queue(spec)

# A callback for unit testing
def reset_job_queue_and_manifest(self):
pass

def _runtime_initialize(self):
self.compile_manifest()
if self.manifest is None or self.graph is None:
raise DbtInternalError("_runtime_initialize never loaded the graph!")

self.job_queue = self.get_graph_queue()

# for unit testing
self.reset_job_queue_and_manifest()

# we use this a couple of times. order does not matter.
self._flattened_nodes = []
for uid in self.job_queue.get_selected_nodes():
Expand All @@ -164,9 +157,11 @@ def _runtime_initialize(self):
self._flattened_nodes.append(self.manifest.sources[uid])
elif uid in self.manifest.saved_queries:
self._flattened_nodes.append(self.manifest.saved_queries[uid])
elif uid in self.manifest.unit_tests:
self._flattened_nodes.append(self.manifest.unit_tests[uid])
else:
raise DbtInternalError(
f"Node selection returned {uid}, expected a node or a source"
f"Node selection returned {uid}, expected a node, a source, or a unit test"
)

self.num_nodes = len([n for n in self._flattened_nodes if not n.is_ephemeral_model])
Expand Down Expand Up @@ -496,7 +491,7 @@ def run(self):

if self.args.write_json:
# args.which used to determine file name for unit test manifest
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nit: comment can be deleted now too

write_manifest(self.manifest, self.config.project_target_path, self.args.which)
write_manifest(self.manifest, self.config.project_target_path)
if hasattr(result, "write"):
result.write(self.result_path())

Expand Down
Loading
Loading