diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 21162bfc8b..703dc5b690 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -2235,6 +2235,8 @@ def get_execution_status_emoji(status: "ExecutionStatus") -> str: """ from zenml.enums import ExecutionStatus + if status == ExecutionStatus.INITIALIZING: + return ":hourglass_flowing_sand:" if status == ExecutionStatus.FAILED: return ":x:" if status == ExecutionStatus.RUNNING: diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 0b981764fa..2a3d43173f 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -56,11 +56,25 @@ class VisualizationType(StrEnum): class ExecutionStatus(StrEnum): """Enum that represents the current status of a step or pipeline run.""" + INITIALIZING = "initializing" FAILED = "failed" COMPLETED = "completed" RUNNING = "running" CACHED = "cached" + @property + def is_finished(self) -> bool: + """Whether the execution status refers to a finished execution. + + Returns: + Whether the execution status refers to a finished execution. + """ + return self in { + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ExecutionStatus.CACHED, + } + class LoggingLevels(Enum): """Enum for logging levels.""" diff --git a/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py b/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py index 2b1fe896b9..48990bada1 100644 --- a/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py @@ -21,7 +21,6 @@ import sky -from zenml.client import Client from zenml.entrypoints import PipelineEntrypointConfiguration from zenml.enums import StackComponentType from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import ( @@ -31,7 +30,6 @@ from zenml.orchestrators import ( ContainerizedOrchestrator, ) -from zenml.orchestrators import utils as orchestrator_utils from zenml.stack import StackValidator from zenml.utils import string_utils @@ -264,12 +262,7 @@ def prepare_or_run_pipeline( self.prepare_environment_variable(set=False) run_duration = time.time() - start_time - run_id = orchestrator_utils.get_run_id_for_orchestrator_run_id( - orchestrator=self, orchestrator_run_id=orchestrator_run_id - ) - run_model = Client().zen_store.get_run(run_id) logger.info( - "Pipeline run `%s` has finished in `%s`.\n", - run_model.name, + "Pipeline run has finished in `%s`.", string_utils.get_human_readable_time(run_duration), ) diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index e22015b0c5..f6ed2390f8 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -62,7 +62,6 @@ class PipelineRunRequest(WorkspaceScopedRequest): """Request model for pipeline runs.""" - id: UUID name: str = Field( title="The name of the pipeline run.", max_length=STR_FIELD_MAX_LENGTH, diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 3e878e583e..5b9ab13235 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -53,7 +53,7 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.config.schedule import Schedule from zenml.config.step_configurations import StepConfigurationUpdate -from zenml.enums import StackComponentType +from zenml.enums import ExecutionStatus, StackComponentType from zenml.hooks.hook_validators import 
resolve_and_validate_hook from zenml.logger import get_logger from zenml.models import ( @@ -65,11 +65,13 @@ PipelineDeploymentResponse, PipelineRequest, PipelineResponse, + PipelineRunRequest, PipelineRunResponse, ScheduleRequest, ) from zenml.new.pipelines import build_utils from zenml.new.pipelines.model_utils import NewModelVersionRequest +from zenml.orchestrators.utils import get_run_name from zenml.stack import Stack from zenml.steps import BaseStep from zenml.steps.entrypoint_function_utils import ( @@ -569,7 +571,7 @@ def _run( config_path: Optional[str] = None, unlisted: bool = False, prevent_build_reuse: bool = False, - ) -> None: + ) -> Optional[PipelineRunResponse]: """Runs the pipeline on the active stack. Args: @@ -597,6 +599,10 @@ def _run( Raises: Exception: bypass any exception from pipeline up. + + Returns: + Model of the pipeline run if running without a schedule, `None` if + running with a schedule. """ if constants.SHOULD_PREVENT_PIPELINE_EXECUTION: # An environment variable was set to stop the execution of @@ -609,7 +615,7 @@ def _run( self.name, constants.ENV_ZENML_PREVENT_PIPELINE_EXECUTION, ) - return + return None logger.info(f"Initiating a new run for the pipeline: `{self.name}`.") @@ -734,24 +740,52 @@ def _run( self.log_pipeline_deployment_metadata(deployment_model) + run = None + if not schedule: + run_request = PipelineRunRequest( + name=get_run_name( + run_name_template=deployment_model.run_name_template + ), + # We set the start time on the placeholder run already to + # make it consistent with the {time} placeholder in the + # run name. This means the placeholder run will usually + # have longer durations than scheduled runs, as for them + # the start_time is only set once the first step starts + # running. + start_time=datetime.utcnow(), + orchestrator_run_id=None, + user=Client().active_user.id, + workspace=deployment_model.workspace.id, + deployment=deployment_model.id, + pipeline=deployment_model.pipeline.id + if deployment_model.pipeline + else None, + status=ExecutionStatus.INITIALIZING, + ) + run = Client().zen_store.create_run(run_request) + # Prevent execution of nested pipelines which might lead to # unexpected behavior constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True try: stack.deploy_pipeline(deployment=deployment_model) except Exception as e: + if ( + run + and Client().get_pipeline_run(run.id).status + == ExecutionStatus.INITIALIZING + ): + # The run hasn't actually started yet, which means that we + # failed during initialization -> We don't want the + # placeholder run to stay in the database + Client().delete_pipeline_run(run.id) + raise e finally: constants.SHOULD_PREVENT_PIPELINE_EXECUTION = False - runs = Client().list_pipeline_runs( - deployment_id=deployment_model.id, - sort_by="desc:start_time", - size=1, - ) - - if runs.items: - run_url = dashboard_utils.get_run_url(runs[0]) + if run: + run_url = dashboard_utils.get_run_url(run) if run_url: logger.info(f"Dashboard URL: {run_url}") else: @@ -760,14 +794,8 @@ def _run( "Dashboard`. In order to try it locally, please run " "`zenml up`." ) - else: - logger.warning( - f"Your orchestrator '{stack.orchestrator.name}' is " - f"running remotely. 
Note that the pipeline run will " - f"only show up on the ZenML dashboard once the first " - f"step has started executing on the remote " - f"infrastructure.", - ) + + return run @staticmethod def log_pipeline_deployment_metadata( diff --git a/src/zenml/orchestrators/local/local_orchestrator.py b/src/zenml/orchestrators/local/local_orchestrator.py index 217475ea3b..823daa93fa 100644 --- a/src/zenml/orchestrators/local/local_orchestrator.py +++ b/src/zenml/orchestrators/local/local_orchestrator.py @@ -16,10 +16,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Type from uuid import uuid4 -from zenml.client import Client from zenml.logger import get_logger from zenml.orchestrators import BaseOrchestrator -from zenml.orchestrators import utils as orchestrator_utils from zenml.orchestrators.base_orchestrator import ( BaseOrchestratorConfig, BaseOrchestratorFlavor, @@ -81,13 +79,8 @@ def prepare_or_run_pipeline( ) run_duration = time.time() - start_time - run_id = orchestrator_utils.get_run_id_for_orchestrator_run_id( - orchestrator=self, orchestrator_run_id=self._orchestrator_run_id - ) - run_model = Client().zen_store.get_run(run_id) logger.info( - "Run `%s` has finished in `%s`.", - run_model.name, + "Pipeline run has finished in `%s`.", string_utils.get_human_readable_time(run_duration), ) self._orchestrator_run_id = None diff --git a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py index 58e7d3b282..9b53ff4e38 100644 --- a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py +++ b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py @@ -24,7 +24,6 @@ from docker.errors import ContainerError from pydantic import validator -from zenml.client import Client from zenml.config.base_settings import BaseSettings from zenml.config.global_config import GlobalConfiguration from zenml.constants import ( @@ -38,7 +37,6 @@ BaseOrchestratorFlavor, ContainerizedOrchestrator, ) -from zenml.orchestrators import utils as orchestrator_utils from zenml.stack import Stack, StackValidator from zenml.utils import string_utils @@ -193,13 +191,8 @@ def prepare_or_run_pipeline( raise RuntimeError(error_message) run_duration = time.time() - start_time - run_id = orchestrator_utils.get_run_id_for_orchestrator_run_id( - orchestrator=self, orchestrator_run_id=orchestrator_run_id - ) - run_model = Client().zen_store.get_run(run_id) logger.info( - "Pipeline run `%s` has finished in `%s`.\n", - run_model.name, + "Pipeline run has finished in `%s`.", string_utils.get_human_readable_time(run_duration), ) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 526f372391..6da14aa48d 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -31,7 +31,6 @@ from zenml.environment import get_run_environment_dict from zenml.logger import get_logger from zenml.logging import step_logging -from zenml.logging.step_logging import StepLogsStorageContext from zenml.model.utils import link_artifact_config_to_model_version from zenml.models import ( ArtifactVersionResponse, @@ -50,7 +49,6 @@ ) from zenml.orchestrators import utils as orchestrator_utils from zenml.orchestrators.step_runner import StepRunner -from zenml.orchestrators.utils import is_setting_enabled from zenml.stack import Stack from zenml.utils import string_utils @@ -152,7 +150,7 @@ def launch(self) -> None: if 
handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False): step_logging_enabled = False else: - step_logging_enabled = is_setting_enabled( + step_logging_enabled = orchestrator_utils.is_setting_enabled( is_enabled_on_step=self._step.config.enable_step_logs, is_enabled_on_pipeline=self._deployment.pipeline_configuration.enable_step_logs, ) @@ -167,7 +165,9 @@ def launch(self) -> None: self._step.config.name, ) - logs_context = StepLogsStorageContext(logs_uri=logs_uri) # type: ignore[assignment] + logs_context = step_logging.StepLogsStorageContext( + logs_uri=logs_uri + ) # type: ignore[assignment] logs_model = LogsRequest( uri=logs_uri, @@ -275,24 +275,14 @@ def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: The created or existing pipeline run, and a boolean indicating whether the run was created or reused. """ - run_id = orchestrator_utils.get_run_id_for_orchestrator_run_id( - orchestrator=self._stack.orchestrator, - orchestrator_run_id=self._orchestrator_run_id, - ) - - date = datetime.utcnow().strftime("%Y_%m_%d") - time = datetime.utcnow().strftime("%H_%M_%S_%f") - run_name = self._deployment.run_name_template.format( - date=date, time=time + run_name = orchestrator_utils.get_run_name( + run_name_template=self._deployment.run_name_template ) - logger.debug( - "Creating pipeline run with ID: %s, name: %s", run_id, run_name - ) + logger.debug("Creating pipeline run %s", run_name) client = Client() pipeline_run = PipelineRunRequest( - id=run_id, name=run_name, orchestrator_run_id=self._orchestrator_run_id, user=client.active_user.id, @@ -346,7 +336,7 @@ def _prepare( step_run.parent_step_ids = parent_step_ids step_run.cache_key = cache_key - cache_enabled = is_setting_enabled( + cache_enabled = orchestrator_utils.is_setting_enabled( is_enabled_on_step=self._step.config.enable_cache, is_enabled_on_pipeline=self._deployment.pipeline_configuration.enable_cache, ) diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index f7a47e20f7..06a0271ed6 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -14,6 +14,7 @@ """Utility functions for the orchestrator.""" import random +from datetime import datetime from typing import TYPE_CHECKING, Dict, Optional from uuid import UUID @@ -30,14 +31,9 @@ PIPELINE_API_TOKEN_EXPIRES_MINUTES, ) from zenml.enums import StoreType -from zenml.logger import get_logger -from zenml.utils import uuid_utils if TYPE_CHECKING: from zenml.models import PipelineDeploymentResponse - from zenml.orchestrators import BaseOrchestrator - -logger = get_logger(__name__) def get_orchestrator_run_name(pipeline_name: str) -> str: @@ -55,22 +51,6 @@ def get_orchestrator_run_name(pipeline_name: str) -> str: return f"{pipeline_name}_{random.Random().getrandbits(128):032x}" -def get_run_id_for_orchestrator_run_id( - orchestrator: "BaseOrchestrator", orchestrator_run_id: str -) -> UUID: - """Generates a run ID from an orchestrator run id. - - Args: - orchestrator: The orchestrator of the run. - orchestrator_run_id: The orchestrator run id. - - Returns: - The run id generated from the orchestrator run id. 
- """ - run_id_seed = f"{orchestrator.id}-{orchestrator_run_id}" - return uuid_utils.generate_uuid_from_string(run_id_seed) - - def is_setting_enabled( is_enabled_on_step: Optional[bool], is_enabled_on_pipeline: Optional[bool], @@ -172,3 +152,26 @@ def get_config_environment_vars( ) return environment_vars + + +def get_run_name(run_name_template: str) -> str: + """Fill out the run name template to get a complete run name. + + Args: + run_name_template: The run name template to fill out. + + Raises: + ValueError: If the run name is empty. + + Returns: + The run name derived from the template. + """ + date = datetime.utcnow().strftime("%Y_%m_%d") + time = datetime.utcnow().strftime("%H_%M_%S_%f") + + run_name = run_name_template.format(date=date, time=time) + + if run_name == "": + raise ValueError("Empty run names are not allowed.") + + return run_name diff --git a/src/zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py b/src/zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py new file mode 100644 index 0000000000..d9d3ba3e3e --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py @@ -0,0 +1,93 @@ +"""Add pipeline run unique constraint [6917bce75069]. + +Revision ID: 6917bce75069 +Revises: 5cc3f41cf048 +Create Date: 2023-11-15 16:07:24.343126 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "6917bce75069" +down_revision = "5cc3f41cf048" +branch_labels = None +depends_on = None + + +def add_orchestrator_run_id_for_old_runs() -> None: + """Add orchestrator_run_id for old runs. + + In order to add the unique constraint on deployment_id and + orchestrator_run_id, we first need to make sure all existing runs fulfill + this constraint. This is not the case for old pipeline runs which existed + before we had deployments, so we add a dummy value for those runs. + """ + meta = sa.MetaData(bind=op.get_bind()) + meta.reflect(only=("pipeline_run",)) + run_table = sa.Table("pipeline_run", meta) + connection = op.get_bind() + + query = ( + sa.update(run_table) + .where(run_table.c.deployment_id.is_(None)) + .where(run_table.c.orchestrator_run_id.is_(None)) + .values( + orchestrator_run_id="dummy_" + + sa.func.cast(run_table.c.id, sa.types.Text) + ) + ) + connection.execute(query) + + +def verify_unique_constraint_satisfied() -> None: + """Verifies that the unique constraint will be satisfied. + + Raises: + RuntimeError: If there are rows which have identical values for the + `deployment_id` and `orchestrator_run_id` columns. + """ + meta = sa.MetaData(bind=op.get_bind()) + meta.reflect(only=("pipeline_run",)) + run_table = sa.Table("pipeline_run", meta) + connection = op.get_bind() + + query = ( + sa.select(run_table.c.deployment_id, run_table.c.orchestrator_run_id) + .group_by(run_table.c.deployment_id, run_table.c.orchestrator_run_id) + .having(sa.func.count() > 1) + ) + result = connection.execute(query).fetchall() + + if result: + raise RuntimeError( + "Unable to migrate database because the `pipeline_run` table " + "contains rows with identical values for both the " + "`orchestrator_run_id` and the `deployment_id` columns." + ) + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + add_orchestrator_run_id_for_old_runs() + verify_unique_constraint_satisfied() + + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("pipeline_run", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique_orchestrator_run_id_for_deployment_id", + ["deployment_id", "orchestrator_run_id"], + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("pipeline_run", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_orchestrator_run_id_for_deployment_id", type_="unique" + ) + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index c5cba8ec08..b8b0b59869 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, List, Optional from uuid import UUID +from sqlalchemy import UniqueConstraint from sqlmodel import TEXT, Column, Field, Relationship from zenml.config.pipeline_configurations import PipelineConfiguration @@ -54,6 +55,13 @@ class PipelineRunSchema(NamedSchema, table=True): """SQL Model for pipeline runs.""" __tablename__ = "pipeline_run" + __table_args__ = ( + UniqueConstraint( + "deployment_id", + "orchestrator_run_id", + name="unique_orchestrator_run_id_for_deployment_id", + ), + ) # Fields orchestrator_run_id: Optional[str] = Field(nullable=True) @@ -65,7 +73,7 @@ class PipelineRunSchema(NamedSchema, table=True): ) # Foreign keys - deployment_id: UUID = build_foreign_key_field( + deployment_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, target=PipelineDeploymentSchema.__tablename__, source_column="deployment_id", @@ -179,7 +187,6 @@ def from_request( orchestrator_environment = json.dumps(request.orchestrator_environment) return cls( - id=request.id, workspace_id=request.workspace, user_id=request.user, name=request.name, @@ -300,3 +307,48 @@ def update(self, run_update: "PipelineRunUpdate") -> "PipelineRunSchema": self.updated = datetime.utcnow() return self + + def update_placeholder( + self, request: "PipelineRunRequest" + ) -> "PipelineRunSchema": + """Update a placeholder run. + + Args: + request: The pipeline run request which should replace the + placeholder. + + Raises: + RuntimeError: If the DB entry does not represent a placeholder run. + ValueError: If the run request does not match the deployment or + pipeline ID of the placeholder run. + + Returns: + The updated `PipelineRunSchema`. + """ + if ( + self.orchestrator_run_id + or self.status != ExecutionStatus.INITIALIZING + ): + raise RuntimeError( + f"Unable to replace pipeline run {self.id} which is not a " + "placeholder run." + ) + + if ( + self.deployment_id != request.deployment + or self.pipeline_id != request.pipeline + ): + raise ValueError( + "Deployment or pipeline ID of placeholder run does not " + "match the IDs of the run request."
+ ) + + orchestrator_environment = json.dumps(request.orchestrator_environment) + + self.orchestrator_run_id = request.orchestrator_run_id + self.orchestrator_environment = orchestrator_environment + self.status = request.status + + self.updated = datetime.utcnow() + + return self diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 2ee7bbf5fd..c566e75ad1 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -3156,7 +3156,7 @@ def create_run( The created pipeline run. Raises: - EntityExistsError: If an identical pipeline run already exists. + EntityExistsError: If a run with the same name already exists. """ with Session(self.engine) as session: # Check if pipeline run with same name already exists. @@ -3171,18 +3171,6 @@ def create_run( f"'{pipeline_run.name}' already exists." ) - # Check if pipeline run with same ID already exists. - existing_id_run = session.exec( - select(PipelineRunSchema).where( - PipelineRunSchema.id == pipeline_run.id - ) - ).first() - if existing_id_run is not None: - raise EntityExistsError( - f"Unable to create pipeline run: A pipeline run with ID " - f"'{pipeline_run.id}' already exists." - ) - # Create the pipeline run new_run = PipelineRunSchema.from_request(pipeline_run) session.add(new_run) @@ -3208,6 +3196,184 @@ def get_run( run_name_or_id, session=session ).to_model(hydrate=hydrate) + def _replace_placeholder_run( + self, pipeline_run: PipelineRunRequest + ) -> PipelineRunResponse: + """Replace a placeholder run with the requested pipeline run. + + Args: + pipeline_run: Pipeline run request. + + Raises: + KeyError: If no placeholder run exists. + + Returns: + The run model. + """ + with Session(self.engine) as session: + run_schema = session.exec( + select(PipelineRunSchema) + # The following line locks the row in the DB, so anyone else + # calling `SELECT ... FOR UPDATE` will wait until the first + # transaction to do so finishes. After the first transaction + # finishes, the subsequent queries will not be able to find a + # placeholder run anymore, as we already updated the + # orchestrator_run_id. + # Note: This only locks a single row if the where clause of + # the query is indexed (we have a unique index due to the + # unique constraint on those columns). Otherwise this will lock + # multiple rows or even the complete table which we want to + # avoid. + .with_for_update() + .where( + PipelineRunSchema.deployment_id == pipeline_run.deployment + ) + .where( + PipelineRunSchema.orchestrator_run_id.is_(None) # type: ignore[union-attr] + ) + ).first() + + if not run_schema: + raise KeyError("No placeholder run found.") + + run_schema.update_placeholder(pipeline_run) + session.add(run_schema) + session.commit() + + return run_schema.to_model(hydrate=True) + + def _get_run_by_orchestrator_run_id( + self, orchestrator_run_id: str, deployment_id: UUID + ) -> PipelineRunResponse: + """Get a pipeline run based on deployment and orchestrator run ID. + + Args: + orchestrator_run_id: The orchestrator run ID. + deployment_id: The deployment ID. + + Raises: + KeyError: If no run exists for the deployment and orchestrator run + ID. + + Returns: + The pipeline run. 
+ """ + with Session(self.engine) as session: + run_schema = session.exec( + select(PipelineRunSchema) + .where(PipelineRunSchema.deployment_id == deployment_id) + .where( + PipelineRunSchema.orchestrator_run_id + == orchestrator_run_id + ) + ).first() + + if not run_schema: + raise KeyError( + f"Unable to get run for orchestrator run ID " + f"{orchestrator_run_id} and deployment ID {deployment_id}." + ) + + return run_schema.to_model(hydrate=True) + + def get_or_create_run( + self, pipeline_run: PipelineRunRequest + ) -> Tuple[PipelineRunResponse, bool]: + """Gets or creates a pipeline run. + + If a run with the same ID or name already exists, it is returned. + Otherwise, a new run is created. + + Args: + pipeline_run: The pipeline run to get or create. + + # noqa: DAR401 + Raises: + ValueError: If the request does not contain an orchestrator run ID. + EntityExistsError: If a run with the same name already exists. + RuntimeError: If the run fetching failed unexpectedly. + + Returns: + The pipeline run, and a boolean indicating whether the run was + created or not. + """ + if not pipeline_run.orchestrator_run_id: + raise ValueError( + "Unable to get or create run for request with missing " + "orchestrator run ID." + ) + + try: + return ( + self._replace_placeholder_run(pipeline_run=pipeline_run), + True, + ) + except KeyError: + # We were not able to find/replace a placeholder run. This could be + # due to one of the following three reasons: + # (1) There never was a placeholder run for the deployment. This is + # the case if the user ran the pipeline on a schedule. + # (2) There was a placeholder run, but a previous pipeline run + # already used it. This is the case if users rerun a pipeline + # run e.g. from the orchestrator UI, as they will use the same + # deployment_id with a new orchestrator_run_id. + # (3) A step of the same pipeline run already replaced the + # placeholder run. + pass + + try: + # We now try to create a new run. The following will happen in the + # three cases described above: + # (1) The behavior depends on whether we're the first step of the + # pipeline run that's trying to create the run. If yes, the + # `self.create_run(...)` will succeed. If no, a run with the + # same deployment_id and orchestrator_run_id already exists and + # the `self.create_run(...)` call will fail due to the unique + # constraint on those columns. + # (2) Same as (1). + # (3) A step of the same pipeline run replaced the placeholder + # run, which now contains the deployment_id and + # orchestrator_run_id of the run that we're trying to create. + # -> The `self.create_run(...) call will fail due to the unique + # constraint on those columns. + return self.create_run(pipeline_run), True + except (EntityExistsError, IntegrityError) as create_error: + # Creating the run failed with an + # - IntegrityError: This happens when we violated a unique + # constraint, which in turn means a run with the same + # deployment_id and orchestrator_run_id exists. We now fetch and + # return that run. + # - EntityExistsError: This happens when a run with the same name + # already exists. This could be either a different run (in which + # case we want to fail) or a run created by a step of the same + # pipeline run (in which case we want to return it). + # Note: The IntegrityError might also be raised when other unique + # constraints get violated. The only other such constraint is the + # primary key constraint on the run ID, which means we randomly + # generated an existing UUID. 
In this case the call below will fail, + # but the chance of that happening is so low we don't handle it. + try: + return ( + self._get_run_by_orchestrator_run_id( + orchestrator_run_id=pipeline_run.orchestrator_run_id, + deployment_id=pipeline_run.deployment, + ), + False, + ) + except KeyError: + if isinstance(create_error, EntityExistsError): + # There was a run with the same name which does not share + # the deployment_id and orchestrator_run_id -> We fail with + # the error that run names must be unique. + raise create_error from None + + # This should never happen as the run creation failed with an + # IntegrityError which means a run with the deployment_id and + # orchestrator_run_id exists. + raise RuntimeError( + f"Failed to get or create run: {create_error}" + ) + def list_runs( self, runs_filter_model: PipelineRunFilter, @@ -3292,34 +3458,6 @@ def delete_run(self, run_id: UUID) -> None: session.delete(existing_run) session.commit() - def get_or_create_run( - self, pipeline_run: PipelineRunRequest - ) -> Tuple[PipelineRunResponse, bool]: - """Gets or creates a pipeline run. - - If a run with the same ID or name already exists, it is returned. - Otherwise, a new run is created. - - Args: - pipeline_run: The pipeline run to get or create. - - Returns: - The pipeline run, and a boolean indicating whether the run was - created or not. - """ - # We want to have the 'create' statement in the try block since running - # it first will reduce concurrency issues. - try: - return self.create_run(pipeline_run), True - except (EntityExistsError, IntegrityError): - # Catch both `EntityExistsError`` and `IntegrityError`` exceptions - # since either one can be raised by the database when trying - # to create a new pipeline run with duplicate ID or name. - try: - return self.get_run(pipeline_run.id), False - except KeyError: - return self.get_run(pipeline_run.name), False - def count_runs(self, filter_model: Optional[PipelineRunFilter]) -> int: """Count all pipeline runs. 
diff --git a/tests/integration/examples/utils.py b/tests/integration/examples/utils.py index 4a92c4cd9d..558dd6ebb9 100644 --- a/tests/integration/examples/utils.py +++ b/tests/integration/examples/utils.py @@ -274,7 +274,7 @@ def wait_and_validate_pipeline_run( if older_than is not None: runs = [r for r in runs if r.created >= older_than] - runs = [r for r in runs if r.status != ExecutionStatus.RUNNING] + runs = [r for r in runs if r.status.is_finished] if len(runs) >= run_no: # We have at least `run_no` runs completed or failed diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index bface6dca0..cbc1613cee 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2576,8 +2576,9 @@ def test_logs_are_recorded_properly(clean_client): client = Client() store = client.zen_store - with PipelineRunContext(2): - steps = store.list_run_steps(StepRunFilter()) + run_context = PipelineRunContext(1) + with run_context: + steps = run_context.steps step1_logs = steps[0].logs step2_logs = steps[1].logs artifact_store = _load_artifact_store( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7fb0b08313..b84ec5d4b2 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -463,7 +463,6 @@ def sample_pipeline_deployment_request_model() -> PipelineDeploymentRequest: def sample_pipeline_run_request_model() -> PipelineRunRequest: """Return sample pipeline run view for testing purposes.""" return PipelineRunRequest( - id=uuid4(), name="sample_run_name", config=PipelineConfiguration(name="aria_pipeline"), num_steps=1, diff --git a/tests/unit/orchestrators/test_cache_utils.py b/tests/unit/orchestrators/test_cache_utils.py index fc0a28e661..7585fe865e 100644 --- a/tests/unit/orchestrators/test_cache_utils.py +++ b/tests/unit/orchestrators/test_cache_utils.py @@ -245,21 +245,18 @@ def test_fetching_cached_step_run_uses_latest_candidate( sample_pipeline_run_request_model.deployment = deployment_response.id sample_step_request_model.deployment = deployment_response.id - clean_client.zen_store.create_run(sample_pipeline_run_request_model) - sample_step_request_model.pipeline_run_id = ( - sample_pipeline_run_request_model.id - ) + run = clean_client.zen_store.create_run(sample_pipeline_run_request_model) + sample_step_request_model.pipeline_run_id = run.id response_1 = clean_client.zen_store.create_run_step( sample_step_request_model ) # Create another pipeline run and step run, with the same cache key - sample_pipeline_run_request_model.id = uuid4() sample_pipeline_run_request_model.name = "new_run_name" - clean_client.zen_store.create_run(sample_pipeline_run_request_model) - sample_step_request_model.pipeline_run_id = ( - sample_pipeline_run_request_model.id + new_run = clean_client.zen_store.create_run( + sample_pipeline_run_request_model ) + sample_step_request_model.pipeline_run_id = new_run.id response_2 = clean_client.zen_store.create_run_step( sample_step_request_model ) diff --git a/tests/unit/pipelines/test_base_pipeline.py b/tests/unit/pipelines/test_base_pipeline.py index 85ed5477b6..67904a55d1 100644 --- a/tests/unit/pipelines/test_base_pipeline.py +++ b/tests/unit/pipelines/test_base_pipeline.py @@ -23,6 +23,7 @@ from zenml.config.compiler import Compiler from zenml.config.pipeline_run_configuration import PipelineRunConfiguration from zenml.config.pipeline_spec import PipelineSpec +from zenml.enums import 
ExecutionStatus from zenml.exceptions import ( PipelineInterfaceError, StackValidationError, @@ -935,3 +936,157 @@ def test_building_a_pipeline_registers_it( clean_client.get_pipeline(name_id_or_prefix=pipeline_instance.name) is not None ) + + +def is_placeholder_request(run_request) -> bool: + """Checks whether a pipeline run request refers to a placeholder run.""" + return ( + run_request.status == ExecutionStatus.INITIALIZING + and run_request.orchestrator_environment == {} + and run_request.orchestrator_run_id is None + ) + + +def test_running_pipeline_creates_and_uses_placeholder_run( + mocker, + clean_client, + empty_pipeline, # noqa: F811 +): + """Tests that running a pipeline creates a placeholder run and later + replaces it with the actual run.""" + mock_create_run = mocker.patch.object( + type(clean_client.zen_store), + "create_run", + wraps=clean_client.zen_store.create_run, + ) + mock_get_or_create_run = mocker.patch.object( + type(clean_client.zen_store), + "get_or_create_run", + wraps=clean_client.zen_store.get_or_create_run, + ) + + pipeline_instance = empty_pipeline + assert clean_client.list_pipeline_runs().total == 0 + + pipeline_instance() + + mock_create_run.assert_called_once() + mock_get_or_create_run.assert_called_once() + + placeholder_run_request = mock_create_run.call_args[0][0] # First arg + assert is_placeholder_request(placeholder_run_request) + + replace_request = mock_get_or_create_run.call_args[0][0] # First arg + assert not is_placeholder_request(replace_request) + + runs = clean_client.list_pipeline_runs() + assert runs.total == 1 + + run = runs[0] + assert run.status == ExecutionStatus.COMPLETED + assert ( + run.orchestrator_environment + == replace_request.orchestrator_environment + ) + assert run.orchestrator_run_id == replace_request.orchestrator_run_id + # assert run.deployment_id + + +def test_rerunning_deployment_does_not_fail( + mocker, + clean_client, + empty_pipeline, # noqa: F811 +): + """Tests that a deployment can be re-run without issues.""" + mock_create_run = mocker.patch.object( + type(clean_client.zen_store), + "create_run", + wraps=clean_client.zen_store.create_run, + ) + mock_get_or_create_run = mocker.patch.object( + type(clean_client.zen_store), + "get_or_create_run", + wraps=clean_client.zen_store.get_or_create_run, + ) + + pipeline_instance = empty_pipeline + pipeline_instance() + + deployments = clean_client.list_deployments() + assert deployments.total == 1 + deployment = deployments[0] + + stack = clean_client.active_stack + + # Simulate re-running the deployment + stack.deploy_pipeline(deployment) + + assert mock_create_run.call_count == 2 + assert mock_get_or_create_run.call_count == 2 + + placeholder_request = mock_create_run.call_args_list[0][0][0] + assert is_placeholder_request(placeholder_request) + + run_request = mock_create_run.call_args_list[1][0][0] + assert not is_placeholder_request(run_request) + + runs = clean_client.list_pipeline_runs(deployment_id=deployment.id) + assert runs.total == 2 + + +def test_failure_during_initialization_deletes_placeholder_run( + clean_client, + empty_pipeline, # noqa: F811 + mocker, +): + """Tests that when a pipeline run fails during initialization, the + placeholder run that was created for it is deleted.""" + mock_create_run = mocker.patch.object( + type(clean_client.zen_store), + "create_run", + wraps=clean_client.zen_store.create_run, + ) + mock_delete_run = mocker.patch.object( + type(clean_client.zen_store), + "delete_run", + wraps=clean_client.zen_store.delete_run, + ) + +
pipeline_instance = empty_pipeline + assert clean_client.list_pipeline_runs().total == 0 + + mocker.patch( + "zenml.stack.stack.Stack.deploy_pipeline", side_effect=RuntimeError + ) + + with pytest.raises(RuntimeError): + pipeline_instance() + + mock_create_run.assert_called_once() + mock_delete_run.assert_called_once() + + assert clean_client.list_pipeline_runs().total == 0 + + +def test_running_scheduled_pipeline_does_not_create_placeholder_run( + mocker, + clean_client, + empty_pipeline, # noqa: F811 +): + """Tests that running a scheduled pipeline does not create a placeholder run + in the database.""" + mock_create_run = mocker.patch.object( + type(clean_client.zen_store), + "create_run", + wraps=clean_client.zen_store.create_run, + ) + pipeline_instance = empty_pipeline + + scheduled_pipeline_instance = pipeline_instance.with_options( + schedule=Schedule(cron_expression="*/5 * * * *") + ) + scheduled_pipeline_instance() + + mock_create_run.assert_called_once() + run_request = mock_create_run.call_args[0][0] # First arg + assert not is_placeholder_request(run_request)
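A minimal usage sketch (not part of the patch) of the two small helpers this change introduces, ExecutionStatus.is_finished in src/zenml/enums.py and get_run_name in src/zenml/orchestrators/utils.py; the run name template and the printed timestamp below are illustrative examples only.

# Sketch: exercises the helpers added in this change.
from zenml.enums import ExecutionStatus
from zenml.orchestrators.utils import get_run_name

# `is_finished` is True only for terminal statuses.
assert not ExecutionStatus.INITIALIZING.is_finished
assert not ExecutionStatus.RUNNING.is_finished
assert ExecutionStatus.COMPLETED.is_finished
assert ExecutionStatus.FAILED.is_finished
assert ExecutionStatus.CACHED.is_finished

# `get_run_name` fills the {date} and {time} placeholders of a run name
# template (the template here is a hypothetical example) and raises a
# ValueError if the result is empty. Prints something like
# "my_pipeline-2023_11_15-16_07_24_343126".
print(get_run_name("my_pipeline-{date}-{time}"))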