Skip to content

Commit

Permalink
Update generate_snapshots.py to use the marker to select tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Oct 30, 2023
1 parent f9c195c commit 22400bd
Showing 1 changed file with 10 additions and 33 deletions.
43 changes: 10 additions & 33 deletions metricflow/test/generate_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
from dbt_semantic_interfaces.pretty_print import pformat_big_objects

from metricflow.protocols.sql_client import SqlEngine
from metricflow.test.fixtures.setup_fixtures import SQL_ENGINE_SNAPSHOT_MARKER_NAME

logger = logging.getLogger(__name__)


TEST_DIRECTORY = "metricflow/test"


class MetricFlowTestCredentialSet(FrozenBaseModel): # noqa: D
engine_url: Optional[str]
engine_password: Optional[str]
Expand Down Expand Up @@ -97,41 +100,14 @@ def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]: # noqa: D
)


SNAPSHOT_GENERATING_TESTS = (
"metricflow/test/cli/test_cli.py::test_saved_query",
"metricflow/test/cli/test_cli.py::test_saved_query_with_where",
"metricflow/test/cli/test_cli.py::test_saved_query_with_limit",
"metricflow/test/cli/test_cli.py::test_saved_query_explain",
"metricflow/test/dataflow/builder/test_dataflow_plan_builder.py",
"metricflow/test/dataflow/optimizer/source_scan/test_cm_branch_combiner.py",
"metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py",
"metricflow/test/dataset/test_convert_semantic_model.py",
"metricflow/test/integration/test_rendered_query.py",
"metricflow/test/integration/test_rendered_query.py",
"metricflow/test/model/test_data_warehouse_tasks.py",
"metricflow/test/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py",
"metricflow/test/plan_conversion/test_dataflow_to_execution.py",
"metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py",
"metricflow/test/sql/optimizer/test_column_pruner.py",
"metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py",
"metricflow/test/sql/optimizer/test_sub_query_reducer.py",
"metricflow/test/sql/optimizer/test_sub_query_reducer.py",
"metricflow/test/sql/optimizer/test_table_alias_simplifier.py",
"metricflow/test/sql/test_engine_specific_rendering.py",
"metricflow/test/sql/test_sql_plan_render.py",
"metricflow/test/sql/test_sql_plan_render.py",
)


def run_command(command: str) -> None: # noqa: D
logger.info(f"Running command {command}")
return_code = os.system(command)
if return_code != 0:
raise RuntimeError(f"Error running command: {command}")


def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths: Sequence[str]) -> None: # noqa: D
combined_paths = " ".join(test_file_paths)
def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D
if test_configuration.credential_set.engine_url is None:
if "MF_SQL_ENGINE_URL" in os.environ:
del os.environ["MF_SQL_ENGINE_URL"]
Expand All @@ -146,7 +122,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths:

if test_configuration.engine is SqlEngine.DUCKDB:
# Can't use --use-persistent-source-schema with duckdb since it's in memory.
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots {combined_paths}")
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots -m '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}' {TEST_DIRECTORY}")
elif (
test_configuration.engine is SqlEngine.REDSHIFT
or test_configuration.engine is SqlEngine.SNOWFLAKE
Expand All @@ -162,7 +138,8 @@ def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths:
f"hatch -v run {hatch_env}:pytest -x -vv -n 4 "
f"--overwrite-snapshots"
f"{' --use-persistent-source-schema' if use_persistent_source_schema else ''}"
f" {combined_paths}"
f"-m '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}' "
f"{TEST_DIRECTORY}"
)
else:
assert_values_exhausted(test_configuration.engine)
Expand All @@ -182,13 +159,13 @@ def run_cli() -> None: # noqa: D

credential_sets = MetricFlowTestCredentialSetForAllEngines.parse_raw(credential_sets_json_str)

logger.info(f"Running the following tests to generate snapshots:\n{pformat_big_objects(SNAPSHOT_GENERATING_TESTS)}")
logger.info(f"Running tests in '{TEST_DIRECTORY}' with the marker '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}'")

for test_configuration in credential_sets.as_configurations:
logger.info(
f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}"
)
run_tests(test_configuration, SNAPSHOT_GENERATING_TESTS)
run_tests(test_configuration)


if __name__ == "__main__":
Expand Down

0 comments on commit 22400bd

Please sign in to comment.