Skip to content

Commit

Permalink
Add filter option for templatable runs (#3000)
Browse files Browse the repository at this point in the history
* Add filter option for templatable runs

* Apply some feedback

* Fix mypy

* Fix unit test

* Formatting
  • Loading branch information
schustmi authored Sep 16, 2024
1 parent 3ffc2d5 commit e9c1d2c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3748,6 +3748,7 @@ def list_pipeline_runs(
end_time: Optional[Union[datetime, str]] = None,
num_steps: Optional[Union[int, str]] = None,
unlisted: Optional[bool] = None,
templatable: Optional[bool] = None,
tag: Optional[str] = None,
hydrate: bool = False,
) -> Page[PipelineRunResponse]:
Expand Down Expand Up @@ -3778,6 +3779,7 @@ def list_pipeline_runs(
end_time: The end_time for the pipeline run
num_steps: The number of steps for the pipeline run
unlisted: If the runs should be unlisted or not.
templatable: If the runs should be templatable or not.
tag: Tag to filter by.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Expand Down Expand Up @@ -3811,6 +3813,7 @@ def list_pipeline_runs(
num_steps=num_steps,
tag=tag,
unlisted=unlisted,
templatable=templatable,
)
runs_filter_model.set_scope_workspace(self.active_workspace.id)
return self.zen_store.list_runs(
Expand Down
52 changes: 51 additions & 1 deletion src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
default=None,
description="Template used for the pipeline run.",
)
is_templatable: bool = Field(
description="Whether a template can be created from this run.",
)


class PipelineRunResponseResources(WorkspaceScopedResponseResources):
Expand Down Expand Up @@ -477,6 +480,15 @@ def template_id(self) -> Optional[UUID]:
"""
return self.get_metadata().template_id

@property
def is_templatable(self) -> bool:
"""The `is_templatable` property.
Returns:
the value of the property.
"""
return self.get_metadata().is_templatable

@property
def model_version(self) -> Optional[ModelVersionResponse]:
"""The `model_version` property.
Expand Down Expand Up @@ -511,6 +523,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
"stack_id",
"template_id",
"pipeline_name",
"templatable",
]
name: Optional[str] = Field(
default=None,
Expand Down Expand Up @@ -584,6 +597,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
union_mode="left_to_right",
)
unlisted: Optional[bool] = None
templatable: Optional[bool] = None

def get_custom_filters(
self,
Expand All @@ -595,7 +609,7 @@ def get_custom_filters(
"""
custom_filters = super().get_custom_filters()

from sqlmodel import and_
from sqlmodel import and_, col, or_

from zenml.zen_stores.schemas import (
CodeReferenceSchema,
Expand Down Expand Up @@ -668,4 +682,40 @@ def get_custom_filters(
)
custom_filters.append(run_template_filter)

if self.templatable is not None:
if self.templatable is True:
templatable_filter = and_(
# The following condition is not perfect as it does not
# consider stacks with custom flavor components or local
# components, but the best we can do currently with our
# table columns.
PipelineRunSchema.deployment_id
== PipelineDeploymentSchema.id,
PipelineDeploymentSchema.build_id
== PipelineBuildSchema.id,
col(PipelineBuildSchema.is_local).is_(False),
col(PipelineBuildSchema.stack_id).is_not(None),
)
else:
templatable_filter = or_(
col(PipelineRunSchema.deployment_id).is_(None),
and_(
PipelineRunSchema.deployment_id
== PipelineDeploymentSchema.id,
col(PipelineDeploymentSchema.build_id).is_(None),
),
and_(
PipelineRunSchema.deployment_id
== PipelineDeploymentSchema.id,
PipelineDeploymentSchema.build_id
== PipelineBuildSchema.id,
or_(
col(PipelineBuildSchema.is_local).is_(True),
col(PipelineBuildSchema.stack_id).is_(None),
),
),
)

custom_filters.append(templatable_filter)

return custom_filters
10 changes: 10 additions & 0 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,15 @@ def to_model(
)
metadata = None
if include_metadata:
is_templatable = False
if (
self.deployment
and self.deployment.build
and not self.deployment.build.is_local
and self.deployment.build.stack
):
is_templatable = True

steps = {step.name: step.to_model() for step in self.step_runs}

metadata = PipelineRunResponseMetadata(
Expand All @@ -346,6 +355,7 @@ def to_model(
template_id=self.deployment.template_id
if self.deployment
else None,
is_templatable=is_templatable,
)

resources = None
Expand Down
1 change: 1 addition & 0 deletions src/zenml/zen_stores/schemas/run_template_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def to_model(
if (
self.source_deployment
and self.source_deployment.build
and not self.source_deployment.build.is_local
and self.source_deployment.build.stack
):
runnable = True
Expand Down
1 change: 1 addition & 0 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,7 @@ def filter_and_paginate(
RuntimeError: if the schema does not have a `to_model` method.
"""
query = filter_model.apply_filter(query=query, table=table)
query = query.distinct()

# Get the total amount of items in the database for a given query
custom_fetch_result: Optional[Sequence[Any]] = None
Expand Down
1 change: 1 addition & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def sample_pipeline_run(
metadata=PipelineRunResponseMetadata(
workspace=sample_workspace_model,
config=PipelineConfiguration(name="aria_pipeline"),
is_templatable=False,
),
)

Expand Down

0 comments on commit e9c1d2c

Please sign in to comment.