Skip to content

Commit

Permalink
Fix build reuse after stack updates (#3251)
Browse files Browse the repository at this point in the history
* Fix build reuse after stack updates

* Tests

* Linting

* Add missing param to client method
  • Loading branch information
schustmi authored Dec 12, 2024
1 parent b73b567 commit df8d0a8
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2663,11 +2663,13 @@ def list_builds(
user_id: Optional[Union[str, UUID]] = None,
pipeline_id: Optional[Union[str, UUID]] = None,
stack_id: Optional[Union[str, UUID]] = None,
container_registry_id: Optional[Union[UUID, str]] = None,
is_local: Optional[bool] = None,
contains_code: Optional[bool] = None,
zenml_version: Optional[str] = None,
python_version: Optional[str] = None,
checksum: Optional[str] = None,
stack_checksum: Optional[str] = None,
hydrate: bool = False,
) -> Page[PipelineBuildResponse]:
"""List all builds.
Expand All @@ -2684,11 +2686,14 @@ def list_builds(
user_id: The id of the user to filter by.
pipeline_id: The id of the pipeline to filter by.
stack_id: The id of the stack to filter by.
container_registry_id: The id of the container registry to
filter by.
is_local: Use to filter local builds.
contains_code: Use to filter builds that contain code.
zenml_version: The version of ZenML to filter by.
python_version: The Python version to filter by.
checksum: The build checksum to filter by.
stack_checksum: The stack checksum to filter by.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Expand All @@ -2707,11 +2712,13 @@ def list_builds(
user_id=user_id,
pipeline_id=pipeline_id,
stack_id=stack_id,
container_registry_id=container_registry_id,
is_local=is_local,
contains_code=contains_code,
zenml_version=zenml_version,
python_version=python_version,
checksum=checksum,
stack_checksum=stack_checksum,
)
build_filter_model.set_scope_workspace(self.active_workspace.id)
return self.zen_store.list_builds(
Expand Down
52 changes: 50 additions & 2 deletions src/zenml/models/v2/core/pipeline_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Models representing pipeline builds."""

import json
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
from uuid import UUID

from pydantic import Field
Expand All @@ -31,6 +31,8 @@
from zenml.models.v2.misc.build_item import BuildItem

if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement

from zenml.models.v2.core.pipeline import PipelineResponse
from zenml.models.v2.core.stack import StackResponse

Expand Down Expand Up @@ -446,6 +448,11 @@ def contains_code(self) -> bool:
class PipelineBuildFilter(WorkspaceScopedFilter):
"""Model to enable advanced filtering of all pipeline builds."""

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
"container_registry_id",
]

workspace_id: Optional[Union[UUID, str]] = Field(
description="Workspace for this pipeline build.",
default=None,
Expand All @@ -462,7 +469,12 @@ class PipelineBuildFilter(WorkspaceScopedFilter):
union_mode="left_to_right",
)
stack_id: Optional[Union[UUID, str]] = Field(
description="Stack used for the Pipeline Run",
description="Stack associated with the pipeline build.",
default=None,
union_mode="left_to_right",
)
container_registry_id: Optional[Union[UUID, str]] = Field(
description="Container registry associated with the pipeline build.",
default=None,
union_mode="left_to_right",
)
Expand All @@ -484,3 +496,39 @@ class PipelineBuildFilter(WorkspaceScopedFilter):
checksum: Optional[str] = Field(
description="The build checksum.", default=None
)
stack_checksum: Optional[str] = Field(
description="The stack checksum.", default=None
)

def get_custom_filters(
self,
) -> List["ColumnElement[bool]"]:
"""Get custom filters.
Returns:
A list of custom filters.
"""
custom_filters = super().get_custom_filters()

from sqlmodel import and_

from zenml.enums import StackComponentType
from zenml.zen_stores.schemas import (
PipelineBuildSchema,
StackComponentSchema,
StackCompositionSchema,
StackSchema,
)

if self.container_registry_id:
container_registry_filter = and_(
PipelineBuildSchema.stack_id == StackSchema.id,
StackSchema.id == StackCompositionSchema.stack_id,
StackCompositionSchema.component_id == StackComponentSchema.id,
StackComponentSchema.type
== StackComponentType.CONTAINER_REGISTRY.value,
StackComponentSchema.id == self.container_registry_id,
)
custom_filters.append(container_registry_filter)

return custom_filters
12 changes: 12 additions & 0 deletions src/zenml/pipelines/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ def find_existing_build(
client = Client()
stack = client.active_stack

if not stack.container_registry:
# There can be no non-local builds that we can reuse if there is no
# container registry in the stack.
return None

python_version_prefix = ".".join(platform.python_version_tuple()[:2])
required_builds = stack.get_docker_builds(deployment=deployment)

Expand All @@ -263,6 +268,13 @@ def find_existing_build(
sort_by="desc:created",
size=1,
stack_id=stack.id,
# Until we implement stack versioning, users can still update their
# stack to update/remove the container registry. In that case, we might
# try to pull an image from a container registry that we don't have
# access to. This is why we add an additional check for the container
# registry ID here. (This is still not perfect as users can update the
# container registry URI or config, but the best we can do)
container_registry_id=stack.container_registry.id,
# The build is local and it's not clear whether the images
# exist on the current machine or if they've been overwritten.
# TODO: Should we support this by storing the unique Docker ID for
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/pipelines/test_build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,9 @@ def test_local_repo_verification(
assert isinstance(code_repo, StubCodeRepository)


def test_finding_existing_build(mocker, sample_deployment_response_model):
def test_finding_existing_build(
mocker, sample_deployment_response_model, remote_container_registry
):
"""Tests finding an existing build."""
mock_list_builds = mocker.patch(
"zenml.client.Client.list_builds",
Expand Down Expand Up @@ -551,14 +553,30 @@ def test_finding_existing_build(mocker, sample_deployment_response_model):
],
)

build_utils.find_existing_build(
deployment=sample_deployment_response_model,
code_repository=StubCodeRepository(),
)
# No container registry -> no non-local build to pull
mock_list_builds.assert_not_called()

mocker.patch.object(
Stack,
"container_registry",
new_callable=mocker.PropertyMock,
return_value=remote_container_registry,
)

build = build_utils.find_existing_build(
deployment=sample_deployment_response_model,
code_repository=StubCodeRepository(),
)

mock_list_builds.assert_called_once_with(
sort_by="desc:created",
size=1,
stack_id=Client().active_stack.id,
container_registry_id=remote_container_registry.id,
is_local=False,
contains_code=False,
zenml_version=zenml.__version__,
Expand Down

0 comments on commit df8d0a8

Please sign in to comment.