Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend notebook source replacement code to other objects apart from ZenML steps #2919

Merged
merged 7 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/zenml/config/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,20 +234,16 @@ class NotebookSource(Source):
"""Source representing an object defined in a notebook.

Attributes:
code_path: Path where the notebook cell code for this source is
uploaded.
replacement_module: Name of the module from which this source should
be loaded in case the code is not running in a notebook.
artifact_store_id: ID of the artifact store in which the replacement
module code is stored.
"""

code_path: Optional[str] = None
replacement_module: Optional[str] = None
artifact_store_id: Optional[UUID] = None
type: SourceType = SourceType.NOTEBOOK

# Private attribute that is used to store the code but should not be
# serialized
_cell_code: Optional[str] = None

@field_validator("type")
@classmethod
def _validate_type(cls, value: SourceType) -> SourceType:
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/materializers/base_materializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __new__(
associated_type, cls
)

from zenml.utils import notebook_utils

notebook_utils.try_to_save_notebook_cell_code(cls)

return cls


Expand Down
77 changes: 56 additions & 21 deletions src/zenml/new/pipelines/run_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""Utility functions for running pipelines."""

import hashlib
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from uuid import UUID

from pydantic import BaseModel

from zenml import constants
from zenml.client import Client
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
from zenml.config.source import SourceType
from zenml.config.source import Source, SourceType
from zenml.config.step_configurations import StepConfigurationUpdate
from zenml.enums import ExecutionStatus
from zenml.logger import get_logger
Expand All @@ -23,7 +24,7 @@
)
from zenml.orchestrators.utils import get_run_name
from zenml.stack import Flavor, Stack
from zenml.utils import code_utils, notebook_utils
from zenml.utils import code_utils, notebook_utils, source_utils
from zenml.zen_stores.base_zen_store import BaseZenStore

if TYPE_CHECKING:
Expand Down Expand Up @@ -269,9 +270,8 @@ def upload_notebook_cell_code_if_necessary(
RuntimeError: If the code for one of the steps that will run out of
process cannot be extracted into a python file.
"""
code_archive = code_utils.CodeArchive(root=None)
should_upload = False
sources_that_require_upload = []
resolved_notebook_sources = source_utils.get_resolved_notebook_sources()

for step in deployment.step_configurations.values():
source = step.spec.source
Expand All @@ -282,7 +282,9 @@ def upload_notebook_cell_code_if_necessary(
or step.config.step_operator
):
should_upload = True
cell_code = getattr(step.spec.source, "_cell_code", None)
cell_code = resolved_notebook_sources.get(
source.import_path, None
)

# Code does not run in-process, which means we need to
# extract the step code into a python file
Expand All @@ -296,20 +298,53 @@ def upload_notebook_cell_code_if_necessary(
"of a notebook."
)

notebook_utils.warn_about_notebook_cell_magic_commands(
cell_code=cell_code
)
if should_upload:
logger.info("Uploading notebook code...")

code_hash = hashlib.sha1(cell_code.encode()).hexdigest() # nosec
module_name = f"extracted_notebook_code_{code_hash}"
file_name = f"{module_name}.py"
code_archive.add_file(source=cell_code, destination=file_name)
for _, cell_code in resolved_notebook_sources.items():
notebook_utils.warn_about_notebook_cell_magic_commands(
cell_code=cell_code
)
module_name = notebook_utils.compute_cell_replacement_module_name(
cell_code=cell_code
)
file_name = f"{module_name}.py"

code_utils.upload_notebook_code(
artifact_store=stack.artifact_store,
cell_code=cell_code,
file_name=file_name,
)

setattr(step.spec.source, "replacement_module", module_name)
sources_that_require_upload.append(source)
all_deployment_sources = get_all_sources_from_value(deployment)

if should_upload:
logger.info("Archiving notebook code...")
code_path = code_utils.upload_code_if_necessary(code_archive)
for source in sources_that_require_upload:
setattr(source, "code_path", code_path)
for source in all_deployment_sources:
if source.type == SourceType.NOTEBOOK:
setattr(source, "artifact_store_id", stack.artifact_store.id)

logger.info("Upload finished.")


def get_all_sources_from_value(value: Any) -> List[Source]:
    """Recursively collect every source object contained in a value.

    A `Source` instance is returned as-is (its own attributes are not
    searched). Pydantic models are searched through their instance
    attribute values, dictionaries through their values, and lists, sets
    and tuples through their elements. Any other value contains no sources.

    Args:
        value: The value from which to get all the source objects.

    Returns:
        List of source objects for the given value.
    """
    # A source is a leaf: report it without descending into its fields.
    if isinstance(value, Source):
        return [value]

    # Pick the nested values to descend into, depending on the container.
    if isinstance(value, BaseModel):
        nested_values = list(value.__dict__.values())
    elif isinstance(value, dict):
        nested_values = list(value.values())
    elif isinstance(value, (list, set, tuple)):
        nested_values = list(value)
    else:
        return []

    found: List[Source] = []
    for nested in nested_values:
        found.extend(get_all_sources_from_value(nested))

    return found
60 changes: 60 additions & 0 deletions src/zenml/utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
if TYPE_CHECKING:
from git.repo.base import Repo

from zenml.artifact_stores import BaseArtifactStore


logger = get_logger(__name__)

Expand Down Expand Up @@ -242,3 +244,61 @@ def download_and_extract_code(code_path: str, extract_dir: str) -> None:

shutil.unpack_archive(filename=download_path, extract_dir=extract_dir)
os.remove(download_path)


def _get_notebook_upload_dir(artifact_store: "BaseArtifactStore") -> str:
    """Build the directory path used for code extracted from notebook cells.

    Args:
        artifact_store: The artifact store in which the directory should be.

    Returns:
        The `notebook_code` directory below the artifact store path.
    """
    store_root = artifact_store.path
    return os.path.join(store_root, "notebook_code")


def upload_notebook_code(
    artifact_store: "BaseArtifactStore", cell_code: str, file_name: str
) -> None:
    """Store the code of a notebook cell as a file in the artifact store.

    The upload directory is created if needed. If a file with the given
    name already exists there, it is left untouched and nothing is written.

    Args:
        artifact_store: The artifact store in which to upload the code.
        cell_code: The notebook cell code.
        file_name: The filename to use for storing the cell code.
    """
    target_dir = _get_notebook_upload_dir(artifact_store=artifact_store)
    fileio.makedirs(target_dir)
    target_path = os.path.join(target_dir, file_name)

    # Nothing to do if a file with this name was uploaded previously.
    if fileio.exists(target_path):
        return

    with fileio.open(target_path, "wb") as code_file:
        code_file.write(cell_code.encode())

    logger.info("Uploaded notebook cell code to %s.", target_path)


def download_notebook_code(
    artifact_store: "BaseArtifactStore", file_name: str, download_path: str
) -> None:
    """Copy a previously uploaded notebook cell code file to a local path.

    Args:
        artifact_store: The artifact store from which to download the code.
        file_name: The name of the code file.
        download_path: The local path where the file should be downloaded to.

    Raises:
        FileNotFoundError: If no file with the given filename exists in this
            artifact store.
    """
    code_path = os.path.join(
        _get_notebook_upload_dir(artifact_store=artifact_store), file_name
    )

    if fileio.exists(code_path):
        fileio.copy(code_path, download_path)
        return

    raise FileNotFoundError(f"Notebook code at path {code_path} not found.")
14 changes: 14 additions & 0 deletions src/zenml/utils/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Notebook utilities."""

import hashlib
from typing import Any, Callable, Optional, TypeVar, Union

from zenml.environment import Environment
Expand Down Expand Up @@ -120,3 +121,16 @@ def warn_about_notebook_cell_magic_commands(cell_code: str) -> None:
"of these lines contain Jupyter notebook magic commands, "
"remove them and try again."
)


def compute_cell_replacement_module_name(cell_code: str) -> str:
    """Derive the replacement module name from the code of a notebook cell.

    The name is a fixed prefix followed by the SHA-1 hex digest of the
    encoded cell code, so identical code always maps to the same name.

    Args:
        cell_code: The code of the notebook cell.

    Returns:
        The replacement module name.
    """
    # SHA-1 is only used to build a content-based identifier here.
    digest = hashlib.sha1(cell_code.encode())  # nosec
    return "extracted_notebook_code_" + digest.hexdigest()
88 changes: 67 additions & 21 deletions src/zenml/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,8 @@
from distutils.sysconfig import get_python_lib
from pathlib import Path, PurePath
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import (
Any,
Callable,
Iterator,
Optional,
Type,
Union,
cast,
)
from typing import Any, Callable, Dict, Iterator, Optional, Type, Union, cast
from uuid import UUID

from zenml.config.source import (
CodeRepositorySource,
Expand Down Expand Up @@ -69,6 +62,8 @@
)

_SHARED_TEMPDIR: Optional[str] = None
_resolved_notebook_sources: Dict[str, str] = {}
_notebook_modules: Dict[str, UUID] = {}
schustmi marked this conversation as resolved.
Show resolved Hide resolved


def load(source: Union[Source, str]) -> Any:
Expand Down Expand Up @@ -237,13 +232,23 @@ def resolve(
source_type = SourceType.UNKNOWN
elif source_type == SourceType.NOTEBOOK:
source = NotebookSource(
module=module_name,
module="__main__",
attribute=attribute_name,
type=source_type,
)
# Private attributes are ignored by pydantic if passed in the __init__
# method, so we set this afterwards
source._cell_code = notebook_utils.load_notebook_cell_code(obj)

if module_name in _notebook_modules:
source.replacement_module = module_name
source.artifact_store_id = _notebook_modules[module_name]
elif cell_code := notebook_utils.load_notebook_cell_code(obj):
replacement_module = (
notebook_utils.compute_cell_replacement_module_name(
cell_code=cell_code
)
)
source.replacement_module = replacement_module
_resolved_notebook_sources[source.import_path] = cell_code

return source

return Source(
Expand Down Expand Up @@ -387,6 +392,9 @@ def get_source_type(module: ModuleType) -> SourceType:
Returns:
The source type.
"""
if module.__name__ in _notebook_modules:
return SourceType.NOTEBOOK

try:
file_path = inspect.getfile(module)
except (TypeError, OSError):
Expand Down Expand Up @@ -582,33 +590,61 @@ def _try_to_load_notebook_source(source: NotebookSource) -> Any:

Raises:
RuntimeError: If the source can't be loaded.
FileNotFoundError: If the file containing the notebook cell code can't
be found.

Returns:
The loaded object.
"""
if not source.code_path or not source.replacement_module:
if not source.replacement_module:
raise RuntimeError(
f"Failed to load {source.import_path}. This object was defined in "
"a notebook and you're trying to load it outside of a notebook. "
"This is currently only enabled for ZenML steps."
"This is currently only enabled for ZenML steps and materializers. "
"To enable this for your custom classes or functions, use the "
"`zenml.utils.notebook_utils.enable_notebook_code_extraction` "
"decorator."
)

extract_dir = _get_shared_temp_dir()
file_path = os.path.join(extract_dir, f"{source.replacement_module}.py")
file_name = f"{source.replacement_module}.py"
file_path = os.path.join(extract_dir, file_name)

if not os.path.exists(file_path):
from zenml.client import Client
from zenml.utils import code_utils

artifact_store = Client().active_stack.artifact_store

if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the notebook cell content is stored in a different artifact store than the active one, we fail here. Similar to what we did with the pipeline artifacts, could we not try to initialize this other artifact store and use it to load the code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, seems like the old bug we had

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use the active artifact store only if source.artifact_store_id is None; otherwise I would work with the artifact store that the ID refers to. Is this what you meant @bcdurak or something else?

source.artifact_store_id
and source.artifact_store_id != artifact_store.id
):
raise RuntimeError(
"Notebook cell code not stored in active artifact store."
)

logger.info(
"Downloading notebook cell content from `%s` to load `%s`.",
source.code_path,
"Downloading notebook cell content to load `%s`.",
source.import_path,
)

code_utils.download_and_extract_code(
code_path=source.code_path, extract_dir=extract_dir
)
try:
code_utils.download_notebook_code(
artifact_store=artifact_store,
file_name=file_name,
download_path=file_path,
)
except FileNotFoundError:
if not source.artifact_store_id:
raise FileNotFoundError(
"Unable to find notebook code file. This might be because "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the error message above, this one is now a bit misleading as it will fail if the file is stored in a different artifact store.

"the file is stored in a different artifact store."
)

raise
else:
_notebook_modules[source.replacement_module] = artifact_store.id
try:
module = _load_module(
module_name=source.replacement_module, import_root=extract_dir
Expand Down Expand Up @@ -734,3 +770,13 @@ def validate_source_class(
return True
else:
return False


def get_resolved_notebook_sources() -> Dict[str, str]:
    """Return a snapshot of the notebook sources resolved in this process.

    A shallow copy is returned, so adding or removing keys on the result
    does not change the module-level registry.

    Returns:
        Dictionary mapping the import path of notebook sources to the code
        of their notebook cell.
    """
    return dict(_resolved_notebook_sources)
Loading