Skip to content

Commit

Permalink
MVP support for Kaggle Packages
Browse files Browse the repository at this point in the history
Support tracking accessed datasources, writing them to and reading them from a
requirements.yaml file.
Support a "Package Scope" using that file which applies those datasource
versions at runtime if the user requests a datasource without an
explicit version.
Support importing a Package.  This uses a handle equivalent to a
Notebook handle, and like Notebooks is currently limited to the latest
version, with version support coming soon.
Support Package Asset files whose path honors the current Package Scope.
  • Loading branch information
dster2 committed Jan 28, 2025
1 parent 7e4f45f commit 10782d2
Show file tree
Hide file tree
Showing 26 changed files with 911 additions and 99 deletions.
47 changes: 47 additions & 0 deletions integration_tests/test_package_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import sys
import unittest

from kagglehub import package_import

from .utils import create_test_cache, unauthenticated

UNVERSIONED_HANDLE = "dster/package-test"
VERSIONED_HANDLE = "dster/package-test/versions/1"


class TestPackageImport(unittest.TestCase):
    """Integration tests for importing Kaggle Packages through ``package_import``."""

    def tearDown(self) -> None:
        # Forget every package module a test imported so the next test
        # has to import it again instead of reusing sys.modules.
        stale = [name for name in tuple(sys.modules) if name.startswith("kagglehub_package")]
        for name in stale:
            del sys.modules[name]

    def _assert_foo_returns(self, package, expected: str) -> None:  # noqa: ANN001
        """Check that ``package`` exposes ``foo`` and that calling it yields ``expected``."""
        self.assertIn("foo", dir(package))
        self.assertEqual(expected, package.foo())

    def test_package_versioned_succeeds(self) -> None:
        with create_test_cache():
            imported = package_import(VERSIONED_HANDLE)

            self._assert_foo_returns(imported, "bar")

    def test_package_unversioned_succeeds(self) -> None:
        with create_test_cache():
            imported = package_import(UNVERSIONED_HANDLE)

            self._assert_foo_returns(imported, "baz")

    def test_download_private_package_succeeds(self) -> None:
        with create_test_cache():
            imported = package_import("integrationtester/kagglehub-test-private-package")

            self._assert_foo_returns(imported, "bar")

    def test_public_package_with_unauthenticated_succeeds(self) -> None:
        # Same contexts as the nested form: cache first, then unauthenticated.
        with create_test_cache(), unauthenticated():
            imported = package_import(UNVERSIONED_HANDLE)

            self._assert_foo_returns(imported, "baz")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"tqdm",
"packaging",
"model_signing",
"pyyaml",
]

[project.urls]
Expand Down
3 changes: 2 additions & 1 deletion src/kagglehub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.6"
__version__ = "0.3.7"

import kagglehub.logger # configures the library logger.
from kagglehub import colab_cache_resolver, http_resolver, kaggle_cache_resolver, registry
Expand All @@ -7,6 +7,7 @@
from kagglehub.datasets import KaggleDatasetAdapter, dataset_download, dataset_upload, load_dataset
from kagglehub.models import model_download, model_upload
from kagglehub.notebooks import notebook_output_download
from kagglehub.packages import get_package_asset_path, package_import
from kagglehub.utility_scripts import utility_script_install

registry.model_resolver.add_implementation(http_resolver.ModelHttpResolver())
Expand Down
2 changes: 1 addition & 1 deletion src/kagglehub/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class ColabClient:
MOUNT_PATH = "/kagglehub/models/mount"
MODEL_MOUNT_PATH = "/kagglehub/models/mount"
DATASET_MOUNT_PATH = "/kagglehub/datasets/mount"
# TBE_RUNTIME_ADDR serves requests made from `is_supported` and `__call__`
# TBE_RUNTIME_ADDR serves requests made from `is_supported` and `_resolve`
# of ModelColabCacheResolver.
TBE_RUNTIME_ADDR_ENV_VAR_NAME = "TBE_RUNTIME_ADDR"

Expand Down
61 changes: 47 additions & 14 deletions src/kagglehub/colab_cache_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from kagglehub.exceptions import BackendError, NotFoundError
from kagglehub.handle import DatasetHandle, ModelHandle
from kagglehub.logger import EXTRA_CONSOLE_BLOCK
from kagglehub.packages import PackageScope
from kagglehub.resolver import Resolver

COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "COLAB_CACHE_MOUNT_FOLDER"
Expand All @@ -29,17 +30,20 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002,
"variation": handle.variation,
}

if handle.is_versioned():
version = _get_model_version(handle)
if version:
# Colab treats version as int in the request
data["version"] = handle.version # type: ignore
data["version"] = version # type: ignore

try:
api_client.post(data, ColabClient.IS_MODEL_SUPPORTED_PATH, handle)
except NotFoundError:
return False
return True

def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
Expand All @@ -53,9 +57,11 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
"framework": h.framework,
"variation": h.variation,
}
if h.is_versioned():

version = _get_model_version(h)
if version:
# Colab treats version as int in the request
data["version"] = h.version # type: ignore
data["version"] = version # type: ignore

response = api_client.post(data, ColabClient.MODEL_MOUNT_PATH, h)

Expand Down Expand Up @@ -85,8 +91,8 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
f"You can access the other files of the attached model at '{cached_path}'"
)
raise ValueError(msg)
return cached_filepath
return cached_path
return cached_filepath, version
return cached_path, version


class DatasetColabCacheResolver(Resolver[DatasetHandle]):
Expand All @@ -100,17 +106,20 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool: # noqa: ANN002
"dataset": handle.dataset,
}

if handle.is_versioned():
version = _get_dataset_version(handle)
if version:
# Colab treats version as int in the request
data["version"] = handle.version # type: ignore
data["version"] = version # type: ignore

try:
api_client.post(data, ColabClient.IS_DATASET_SUPPORTED_PATH, handle)
except NotFoundError:
return False
return True

def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
Expand All @@ -122,9 +131,11 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo
"owner": h.owner,
"dataset": h.dataset,
}
if h.is_versioned():

version = _get_dataset_version(h)
if version:
# Colab treats version as int in the request
data["version"] = h.version # type: ignore
data["version"] = version # type: ignore

response = api_client.post(data, ColabClient.DATASET_MOUNT_PATH, h)

Expand Down Expand Up @@ -154,5 +165,27 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo
f"You can access the other files of the attached dataset at '{cached_path}'"
)
raise ValueError(msg)
return cached_filepath
return cached_path
return cached_filepath, version
return cached_path, version


def _get_model_version(h: ModelHandle) -> Optional[int]:
    """Return the version to use for model handle ``h``, or ``None`` if there is none.

    A version carried by the handle itself takes precedence. For an
    unversioned handle the result is whatever ``PackageScope.get_version``
    reports for it, which may be ``None``.
    """
    if not h.is_versioned():
        # No explicit version on the handle: defer to the package scope.
        return PackageScope.get_version(h)
    return h.version


def _get_dataset_version(h: DatasetHandle) -> Optional[int]:
    """Return the version to use for dataset handle ``h``, or ``None`` if there is none.

    A version carried by the handle itself takes precedence. For an
    unversioned handle the result is whatever ``PackageScope.get_version``
    reports for it, which may be ``None``.
    """
    if not h.is_versioned():
        # No explicit version on the handle: defer to the package scope.
        return PackageScope.get_version(h)
    return h.version
3 changes: 2 additions & 1 deletion src/kagglehub/competition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ def competition_download(handle: str, path: Optional[str] = None, *, force_downl

h = parse_competition_handle(handle)
logger.info(f"Downloading competition: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
return registry.competition_resolver(h, path, force_download=force_download)
path, _ = registry.competition_resolver(h, path, force_download=force_download)
return path
3 changes: 2 additions & 1 deletion src/kagglehub/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def dataset_download(handle: str, path: Optional[str] = None, *, force_download:

h = parse_dataset_handle(handle)
logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
return registry.dataset_resolver(h, path, force_download=force_download)
path, _ = registry.dataset_resolver(h, path, force_download=force_download)
return path


def dataset_upload(
Expand Down
31 changes: 26 additions & 5 deletions src/kagglehub/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
NUM_VERSIONED_MODEL_PARTS = 5 # e.g.: <owner>/<model>/<framework>/<variation>/<version>
NUM_UNVERSIONED_MODEL_PARTS = 4 # e.g.: <owner>/<model>/<framework>/<variation>

NUM_VERSIONED_NOTEBOOK_PARTS = 4 # e.g.: <owner>/<notebook>/versions/<version>
NUM_UNVERSIONED_NOTEBOOK_PARTS = 2 # e.g.: <owner>/<notebook>
NUM_VERSIONED_NOTEBOOK_PARTS = 4 # e.g.: <owner>/<notebook>/versions/<version>


@dataclass
@dataclass(frozen=True)
class ResourceHandle:
@abc.abstractmethod
def to_url(self) -> str:
"""Returns URL to the resource detail page."""
pass


@dataclass
@dataclass(frozen=True)
class ModelHandle(ResourceHandle):
owner: str
model: str
Expand All @@ -35,6 +36,11 @@ class ModelHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> "ModelHandle":
return ModelHandle(
owner=self.owner, model=self.model, framework=self.framework, variation=self.variation, version=version
)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.model}/{self.framework}/{self.variation}"
if self.is_versioned():
Expand All @@ -49,7 +55,7 @@ def to_url(self) -> str:
return f"{endpoint}/models/{self.owner}/{self.model}/{self.framework}/{self.variation}"


@dataclass
@dataclass(frozen=True)
class DatasetHandle(ResourceHandle):
owner: str
dataset: str
Expand All @@ -58,6 +64,9 @@ class DatasetHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> "DatasetHandle":
return DatasetHandle(owner=self.owner, dataset=self.dataset, version=version)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.dataset}"
if self.is_versioned():
Expand All @@ -72,7 +81,7 @@ def to_url(self) -> str:
return base_url


@dataclass
@dataclass(frozen=True)
class CompetitionHandle(ResourceHandle):
competition: str

Expand All @@ -86,7 +95,7 @@ def to_url(self) -> str:
return base_url


@dataclass
@dataclass(frozen=True)
class NotebookHandle(ResourceHandle):
owner: str
notebook: str
Expand All @@ -95,6 +104,9 @@ class NotebookHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> "NotebookHandle":
return NotebookHandle(owner=self.owner, notebook=self.notebook, version=version)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.notebook}"
if self.is_versioned():
Expand All @@ -113,6 +125,10 @@ class UtilityScriptHandle(NotebookHandle):
pass


class PackageHandle(NotebookHandle):
    """Handle for a Kaggle Package; adds nothing on top of ``NotebookHandle``."""


def parse_dataset_handle(handle: str) -> DatasetHandle:
parts = handle.split("/")

Expand Down Expand Up @@ -217,3 +233,8 @@ def parse_notebook_handle(handle: str) -> NotebookHandle:
def parse_utility_script_handle(handle: str) -> UtilityScriptHandle:
    """Parse ``handle`` with the notebook-handle parser and return it as a ``UtilityScriptHandle``."""
    # Copy the parsed notebook handle's fields into the more specific type.
    parsed_fields = asdict(parse_notebook_handle(handle))
    return UtilityScriptHandle(**parsed_fields)


def parse_package_handle(handle: str) -> PackageHandle:
    """Parse ``handle`` with the notebook-handle parser and return it as a ``PackageHandle``."""
    # Copy the parsed notebook handle's fields into the more specific type.
    parsed_fields = asdict(parse_notebook_handle(handle))
    return PackageHandle(**parsed_fields)
Loading

0 comments on commit 10782d2

Please sign in to comment.