Skip to content

Commit

Permalink
MVP support for Kaggle Packages
Browse files Browse the repository at this point in the history
Support tracking accessed datasources, writing them to and reading them from a
requirements.yaml file.
Support a "Package Scope" using that file which applies those datasource
versions at runtime if the user requests a datasource without an
explicit version.
Support importing a Package. This uses a handle equivalent to a
Notebook handle and, like Notebooks, is currently limited to the latest
version, with version support coming soon.
Support Package Asset files whose path honors the current Package Scope.
  • Loading branch information
dster2 committed Dec 18, 2024
1 parent 5fdb159 commit 31b4c18
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 9 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"requests",
"tqdm",
"packaging",
"pyyaml",
]

[project.urls]
Expand Down
2 changes: 2 additions & 0 deletions src/kagglehub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from kagglehub.datasets import dataset_download, dataset_upload
from kagglehub.models import model_download, model_upload
from kagglehub.notebooks import notebook_output_download
from kagglehub.packages import PackageScope, get_package_asset_path, package_import
from kagglehub.requirements import read_requirements, write_requirements

registry.model_resolver.add_implementation(http_resolver.ModelHttpResolver())
registry.model_resolver.add_implementation(kaggle_cache_resolver.ModelKaggleCacheResolver())
Expand Down
32 changes: 25 additions & 7 deletions src/kagglehub/handle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Functions to parse resource handles."""

import abc
from dataclasses import dataclass
from typing import Optional
from dataclasses import asdict, dataclass
from typing import Optional, Self

from kagglehub.config import get_kaggle_api_endpoint

Expand All @@ -15,15 +15,15 @@
NUM_UNVERSIONED_NOTEBOOK_PARTS = 2 # e.g.: <owner>/<notebook>


@dataclass
@dataclass(frozen=True)
class ResourceHandle:
    """Common base for the resource handle types declared in this module.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``abstractmethod`` marker on ``to_url`` is not enforced when instantiating.
    """

    @abc.abstractmethod
    def to_url(self) -> str:
        """Return the URL of the resource detail page."""


@dataclass
@dataclass(frozen=True)
class ModelHandle(ResourceHandle):
owner: str
model: str
Expand All @@ -34,6 +34,11 @@ class ModelHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> Self:
return ModelHandle(
owner=self.owner, model=self.model, framework=self.framework, variation=self.variation, version=version
)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.model}/{self.framework}/{self.variation}"
if self.is_versioned():
Expand All @@ -48,7 +53,7 @@ def to_url(self) -> str:
return f"{endpoint}/models/{self.owner}/{self.model}/{self.framework}/{self.variation}"


@dataclass
@dataclass(frozen=True)
class DatasetHandle(ResourceHandle):
owner: str
dataset: str
Expand All @@ -57,6 +62,9 @@ class DatasetHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> Self:
return DatasetHandle(owner=self.owner, dataset=self.dataset, version=version)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.dataset}"
if self.is_versioned():
Expand All @@ -71,7 +79,7 @@ def to_url(self) -> str:
return base_url


@dataclass
@dataclass(frozen=True)
class CompetitionHandle(ResourceHandle):
competition: str

Expand All @@ -85,10 +93,11 @@ def to_url(self) -> str:
return base_url


@dataclass
@dataclass(frozen=True)
class NotebookHandle(ResourceHandle):
owner: str
notebook: str
version: Optional[int] = None

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.notebook}"
Expand All @@ -100,6 +109,10 @@ def to_url(self) -> str:
return base_url


class PackageHandle(NotebookHandle):
    """Handle for a Package; adds nothing on top of ``NotebookHandle``."""


def parse_dataset_handle(handle: str) -> DatasetHandle:
parts = handle.split("/")

Expand Down Expand Up @@ -177,3 +190,8 @@ def parse_notebook_handle(handle: str) -> NotebookHandle:
msg = f"Invalid notebook handle: {handle}"
raise ValueError(msg)
return NotebookHandle(owner=parts[0], notebook=parts[1])


def parse_package_handle(handle: str) -> PackageHandle:
    """Parse ``handle`` with the notebook parser and rebuild it as a ``PackageHandle``.

    Any error raised by ``parse_notebook_handle`` propagates unchanged.
    """
    parsed_fields = asdict(parse_notebook_handle(handle))
    return PackageHandle(**parsed_fields)
25 changes: 23 additions & 2 deletions src/kagglehub/http_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from kagglehub.clients import KaggleApiV1Client
from kagglehub.exceptions import UnauthenticatedError
from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle
from kagglehub.packages import get_package_datasource_version_number
from kagglehub.requirements import register_accessed_datasource
from kagglehub.resolver import Resolver

DATASET_CURRENT_VERSION_FIELD = "currentVersionNumber"
Expand Down Expand Up @@ -45,6 +47,7 @@ def __call__(

if not api_client.has_credentials():
if cached_path:
register_accessed_datasource(h, None)
return cached_path
raise UnauthenticatedError()

Expand All @@ -60,10 +63,12 @@ def __call__(
)
except requests.exceptions.ConnectionError:
if cached_path:
register_accessed_datasource(h, None)
return cached_path
raise

if not download_needed and cached_path:
register_accessed_datasource(h, None)
return cached_path
else:
# Download, extract, then delete the archive.
Expand All @@ -77,19 +82,22 @@ def __call__(
if cached_path:
if os.path.exists(archive_path):
os.remove(archive_path)
register_accessed_datasource(h, None)
return cached_path
raise

if not download_needed and cached_path:
if os.path.exists(archive_path):
os.remove(archive_path)
register_accessed_datasource(h, None)
return cached_path

os.makedirs(out_path, exist_ok=True)
_extract_archive(archive_path, out_path)
os.remove(archive_path)

mark_as_complete(h, path)
register_accessed_datasource(h, None)
return out_path


Expand All @@ -101,11 +109,13 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
api_client = KaggleApiV1Client()

original_h = h
if not h.is_versioned():
h.version = _get_current_version(api_client, h)
h = h.with_version(_get_current_version(api_client, h))

dataset_path = load_from_cache(h, path)
if dataset_path and not force_download:
register_accessed_datasource(original_h, h.version)
return dataset_path # Already cached
elif dataset_path and force_download:
delete_from_cache(h, path)
Expand Down Expand Up @@ -136,6 +146,7 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo
os.remove(archive_path)

mark_as_complete(h, path)
register_accessed_datasource(original_h, h.version)
return out_path


Expand All @@ -147,11 +158,13 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
api_client = KaggleApiV1Client()

original_h = h
if not h.is_versioned():
h.version = _get_current_version(api_client, h)
h = h.with_version(_get_current_version(api_client, h))

model_path = load_from_cache(h, path)
if model_path and not force_download:
register_accessed_datasource(original_h, h.version)
return model_path # Already cached
elif model_path and force_download:
delete_from_cache(h, path)
Expand Down Expand Up @@ -199,6 +212,7 @@ def _inner_download_file(file: str) -> None:
)

mark_as_complete(h, path)
register_accessed_datasource(original_h, h.version)
return out_path


Expand All @@ -212,6 +226,7 @@ def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_downl

cached_response = load_from_cache(h, path)
if cached_response and not force_download:
register_accessed_datasource(h, None)
return cached_response # Already cached
elif cached_response and force_download:
delete_from_cache(h, path)
Expand Down Expand Up @@ -249,6 +264,7 @@ def _inner_download_file(filepath: str) -> None:
mark_as_complete(h, path)

# TODO(b/377510971): when notebook is a Kaggle utility script, update sys.path
register_accessed_datasource(h, None)
return str(output_root)

def _list_files(self, api_client: KaggleApiV1Client, h: NotebookHandle) -> tuple[list[str], bool]:
Expand Down Expand Up @@ -278,6 +294,11 @@ def _extract_archive(archive_path: str, out_path: str) -> None:


def _get_current_version(api_client: KaggleApiV1Client, h: ResourceHandle) -> int:
# Check if there's a Package in scope which has stored a version number used when it was created.
version_from_package = get_package_datasource_version_number(h)
if version_from_package is not None:
return version_from_package

if isinstance(h, ModelHandle):
json_response = api_client.get(_build_get_instance_url_path(h), h)
if MODEL_INSTANCE_VERSION_FIELD not in json_response:
Expand Down
9 changes: 9 additions & 0 deletions src/kagglehub/kaggle_cache_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kagglehub.exceptions import BackendError
from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle
from kagglehub.logger import EXTRA_CONSOLE_BLOCK
from kagglehub.requirements import register_accessed_datasource
from kagglehub.resolver import Resolver

KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "KAGGLE_CACHE_MOUNT_FOLDER"
Expand Down Expand Up @@ -83,7 +84,9 @@ def __call__(
f"You can acces the other files othe attached competition at '{cached_path}'"
)
raise ValueError(msg)
register_accessed_datasource(h, None)
return cached_filepath
register_accessed_datasource(h, None)
return cached_path


Expand Down Expand Up @@ -124,6 +127,7 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo

base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER)
cached_path = f"{base_mount_path}/{result['mountSlug']}"
version_number = result.get("versionNumber") # None if missing

if not os.path.exists(cached_path):
# Only print this if the dataset is not already mounted.
Expand All @@ -150,7 +154,9 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo
f"You can acces the other files othe attached dataset at '{cached_path}'"
)
raise ValueError(msg)
register_accessed_datasource(h, version_number)
return cached_filepath
register_accessed_datasource(h, version_number)
return cached_path


Expand Down Expand Up @@ -193,6 +199,7 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download

base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER)
cached_path = f"{base_mount_path}/{result['mountSlug']}"
version_number = result.get("versionNumber") # None if missing

if not os.path.exists(cached_path):
# Only print this if the model is not already mounted.
Expand All @@ -219,5 +226,7 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
f"You can access the other files of the attached model at '{cached_path}'"
)
raise ValueError(msg)
register_accessed_datasource(h, version_number)
return cached_filepath
register_accessed_datasource(h, version_number)
return cached_path
Loading

0 comments on commit 31b4c18

Please sign in to comment.