Skip to content

Commit

Permalink
Enable cache precomputation for run templates (#3156)
Browse files Browse the repository at this point in the history
* Enable cache precomputation for run templates

* Fix test

* Never delete placeholder run

* Fix tests
  • Loading branch information
schustmi authored Oct 31, 2024
1 parent 9f83323 commit 4004706
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 44 deletions.
7 changes: 3 additions & 4 deletions src/zenml/orchestrators/base_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,16 @@ def prepare_or_run_pipeline(
environment: Environment variables to set in the orchestration
environment. These don't need to be set if running locally.
Returns:
The optional return value from this method will be returned by the
`pipeline_instance.run()` call when someone is running a pipeline.
Yields:
Metadata for the pipeline run.
"""

def run(
self,
deployment: "PipelineDeploymentResponse",
stack: "Stack",
placeholder_run: Optional["PipelineRunResponse"] = None,
) -> Any:
) -> None:
"""Runs a pipeline on a stack.
Args:
Expand Down
4 changes: 3 additions & 1 deletion src/zenml/pipelines/pipeline_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,9 @@ def _run(
)

deploy_pipeline(
deployment=deployment_model, stack=stack, placeholder_run=run
deployment=deployment_model,
stack=stack,
placeholder_run=run,
)
if run:
return Client().get_pipeline_run(run.id)
Expand Down
18 changes: 9 additions & 9 deletions src/zenml/pipelines/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PipelineRunResponse,
StackResponse,
)
from zenml.orchestrators.publish_utils import publish_failed_pipeline_run
from zenml.orchestrators.utils import get_run_name
from zenml.stack import Flavor, Stack
from zenml.utils import code_utils, notebook_utils, source_utils
Expand Down Expand Up @@ -125,34 +126,33 @@ def deploy_pipeline(
Args:
deployment: The deployment to run.
stack: The stack on which to run the deployment.
placeholder_run: An optional placeholder run for the deployment. This
will be deleted in case the pipeline deployment failed.
placeholder_run: An optional placeholder run for the deployment.
Raises:
Exception: Any exception that happened while deploying or running
(in case it happens synchronously) the pipeline.
"""
stack.prepare_pipeline_deployment(deployment=deployment)

# Prevent execution of nested pipelines which might lead to
# unexpected behavior
previous_value = constants.SHOULD_PREVENT_PIPELINE_EXECUTION
constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True
try:
stack.prepare_pipeline_deployment(deployment=deployment)
stack.deploy_pipeline(
deployment=deployment,
placeholder_run=placeholder_run,
)
except Exception as e:
if (
placeholder_run
and Client().get_pipeline_run(placeholder_run.id).status
and Client()
.get_pipeline_run(placeholder_run.id, hydrate=False)
.status
== ExecutionStatus.INITIALIZING
):
# The run hasn't actually started yet, which means that we
# failed during initialization -> We don't want the
# placeholder run to stay in the database
Client().delete_pipeline_run(placeholder_run.id)
# The run failed during the initialization phase -> We change it's
# status to `Failed`
publish_failed_pipeline_run(placeholder_run.id)

raise e
finally:
Expand Down
8 changes: 2 additions & 6 deletions src/zenml/stack/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,18 +810,14 @@ def deploy_pipeline(
self,
deployment: "PipelineDeploymentResponse",
placeholder_run: Optional["PipelineRunResponse"] = None,
) -> Any:
) -> None:
"""Deploys a pipeline on this stack.
Args:
deployment: The pipeline deployment.
placeholder_run: An optional placeholder run for the deployment.
This will be deleted in case the pipeline deployment failed.
Returns:
The return value of the call to `orchestrator.run_pipeline(...)`.
"""
return self.orchestrator.run(
self.orchestrator.run(
deployment=deployment, stack=self, placeholder_run=placeholder_run
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
from zenml.entrypoints.base_entrypoint_configuration import (
BaseEntrypointConfiguration,
)
from zenml.pipelines.run_utils import (
deploy_pipeline,
)
from zenml.pipelines.run_utils import deploy_pipeline, get_placeholder_run


class RunnerEntrypointConfiguration(BaseEntrypointConfiguration):
Expand All @@ -36,4 +34,9 @@ def run(self) -> None:
stack = Client().active_stack
assert deployment.stack and stack.id == deployment.stack.id

deploy_pipeline(deployment=deployment, stack=stack)
placeholder_run = get_placeholder_run(deployment_id=deployment.id)
deploy_pipeline(
deployment=deployment,
stack=stack,
placeholder_run=placeholder_run,
)
14 changes: 5 additions & 9 deletions tests/unit/pipelines/test_base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,23 +605,18 @@ def test_rerunning_deloyment_does_not_fail(
assert runs.total == 2


def test_failure_during_initialization_deletes_placeholder_run(
def test_failure_during_initialization_marks_placeholder_run_as_failed(
clean_client,
empty_pipeline, # noqa: F811
mocker,
):
"""Tests that when a pipeline run fails during initialization, the
placeholder run that was created for it is deleted."""
placeholder run is marked as failed."""
mock_create_run = mocker.patch.object(
type(clean_client.zen_store),
"create_run",
wraps=clean_client.zen_store.create_run,
)
mock_delete_run = mocker.patch.object(
type(clean_client.zen_store),
"delete_run",
wraps=clean_client.zen_store.delete_run,
)

pipeline_instance = empty_pipeline
assert clean_client.list_pipeline_runs().total == 0
Expand All @@ -634,9 +629,10 @@ def test_failure_during_initialization_deletes_placeholder_run(
pipeline_instance()

mock_create_run.assert_called_once()
mock_delete_run.assert_called_once()

assert clean_client.list_pipeline_runs().total == 0
runs = clean_client.list_pipeline_runs()
assert len(runs) == 1
assert runs[0].status == ExecutionStatus.FAILED


def test_running_scheduled_pipeline_does_not_create_placeholder_run(
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/stack/test_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,32 +160,22 @@ def test_stack_deployment(
components."""
# Mock the pipeline run registering which tries (and fails) to serialize
# our mock objects

pipeline_run_return_value = object()
stack_with_mock_components.orchestrator.run.return_value = (
pipeline_run_return_value
)

with empty_pipeline:
empty_pipeline.entrypoint()
deployment = Compiler().compile(
pipeline=empty_pipeline,
stack=stack_with_mock_components,
run_configuration=PipelineRunConfiguration(),
)
return_value = stack_with_mock_components.deploy_pipeline(
stack_with_mock_components.deploy_pipeline(
deployment=deployment,
)

# for component in stack_with_mock_components.components.values():
# component.prepare_step_run.assert_called_once()

stack_with_mock_components.orchestrator.run.assert_called_once_with(
deployment=deployment,
stack=stack_with_mock_components,
placeholder_run=None,
)
assert return_value is pipeline_run_return_value


def test_requires_remote_server(stack_with_mock_components, mocker):
Expand Down

0 comments on commit 4004706

Please sign in to comment.