Add tests for auto-(de)compression (#195)
Notes:
- This is follow-on work from #194
- Also added some integration tests (the sketch below shows the behavior they exercise)
- Note that not all stubs have been moved to this redirect path. Only the
paths that could have compressed URLs in prod have been affected (Comps and
Datasets). We can move all download paths to use this if we'd like.

http://b/379756505
goeffthomas authored Dec 16, 2024
1 parent 1bc9814 commit 5fdb159
Showing 8 changed files with 131 additions and 96 deletions.
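The behavior these changes test, sketched with the handle and columns used in the integration tests below. This is an illustration, not library code; the decompression itself happens inside kagglehub, not in the caller:

import kagglehub

# sample_submission.csv is stored auto-compressed on GCS; the download
# call still returns a path to the decompressed CSV on disk.
path = kagglehub.competition_download("ieee-fraud-detection", path="sample_submission.csv")
with open(path) as f:
    assert sorted(f.readline().strip().split(",")) == ["TransactionID", "isFraud"]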
12 changes: 10 additions & 2 deletions integration_tests/test_competition_download.py
@@ -4,9 +4,10 @@

from kagglehub import competition_download

from .utils import assert_files, create_test_cache
from .utils import assert_columns, assert_files, create_test_cache

HANDLE = "titanic"
IEEE_FRAUD_DETECTION_HANDLE = "ieee-fraud-detection"


class TestCompetitionDownload(unittest.TestCase):
@@ -32,7 +33,7 @@ def test_competition_competition_rules_accepted_succeeds(self) -> None:
"train_transaction.csv",
]

actual_path = competition_download("ieee-fraud-detection")
actual_path = competition_download(IEEE_FRAUD_DETECTION_HANDLE)

assert_files(self, actual_path, expected_files)

@@ -53,6 +54,13 @@ def test_competition_multiple_files(self) -> None:
actual_path = competition_download(HANDLE, path=p)
assert_files(self, actual_path, [p])

def test_auto_decompress_file(self) -> None:
with create_test_cache():
# sample_submission.csv is an auto-compressed CSV with the following columns
expected_columns = ["TransactionID", "isFraud"]
actual_path = competition_download(IEEE_FRAUD_DETECTION_HANDLE, path="sample_submission.csv")
assert_columns(self, actual_path, expected_columns)

def test_competition_with_incorrect_file_path(self) -> None:
incorrect_path = "nonxisten/Test"
with self.assertRaises(HTTPError) as e:
9 changes: 8 additions & 1 deletion integration_tests/test_dataset_download.py
@@ -4,7 +4,7 @@

from kagglehub import dataset_download

from .utils import assert_files, create_test_cache, unauthenticated
from .utils import assert_columns, assert_files, create_test_cache, unauthenticated

UNVERSIONED_HANDLE = "ryanholbrook/dl-course-data"
HANDLE = "ryanholbrook/dl-course-data/versions/5"
@@ -70,6 +70,13 @@ def test_download_multiple_files(self) -> None:
actual_path = dataset_download(HANDLE, path=p)
assert_files(self, actual_path, [p])

def test_auto_decompress_file(self) -> None:
with create_test_cache():
# diamonds.csv is an auto-compressed CSV with the following columns
expected_columns = ["carat", "cut", "color", "clarity", "depth", "table", "price", "x", "y", "z"]
actual_path = dataset_download(HANDLE, path="diamonds.csv")
assert_columns(self, actual_path, expected_columns)

def test_download_with_incorrect_file_path(self) -> None:
incorrect_path = "nonexistent/file/path"
with self.assertRaises(HTTPError):
15 changes: 14 additions & 1 deletion integration_tests/utils.py
@@ -37,7 +37,14 @@ def list_files_recursively(path: str) -> list[str]:
return sorted(files)


def assert_files(test_case: unittest.TestCase, path: str, expected_files: list[str]) -> bool:
def list_columns(path: str) -> list[str]:
"""Assuming the path is a CSV, list all columns sorted lexicographically"""
with open(path) as file:
first_line = file.readline().strip()
return sorted(first_line.split(","))
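A quick usage sketch (hypothetical file): for an example.csv whose header line is "b,a,c":

columns = list_columns("example.csv")
assert columns == ["a", "b", "c"]  # sorted lexicographically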


def assert_files(test_case: unittest.TestCase, path: str, expected_files: list[str]) -> None:
"""Assert that all expected files exist and are non-empty."""
files = list_files_recursively(path)
expected_files_sorted = sorted(expected_files)
@@ -53,6 +60,12 @@ def assert_files(test_case: unittest.TestCase, path: str, expected_files: list[s
test_case.assertGreater(os.path.getsize(file_path), 0, f"File {file} is empty")


def assert_columns(test_case: unittest.TestCase, path: str, expected_columns: list[str]) -> None:
"""Assert that the given path to a CSV has the expected columns."""
columns = list_columns(path)
test_case.assertEqual(columns, sorted(expected_columns))


@contextmanager
def unauthenticated() -> Generator[None, None, None]:
with mock.patch.dict(
34 changes: 6 additions & 28 deletions tests/server_stubs/competition_download_stub.py
@@ -1,15 +1,14 @@
import hashlib
import os
from collections.abc import Generator
from typing import Any

from flask import Flask, Response, jsonify
from flask.typing import ResponseReturnValue

from kagglehub.integrity import to_b64_digest
from tests.utils import get_test_file_path
from tests.utils import AUTO_COMPRESSED_FILE_NAME, add_mock_gcs_route, get_gcs_redirect_response, get_test_file_path

app = Flask(__name__)
add_mock_gcs_route(app)

TARGZ_ARCHIVE_HANDLE = "competition-targz"

@@ -51,31 +50,10 @@ def competition_download(competition_slug: str) -> ResponseReturnValue:
@app.route("/api/v1/competitions/data/download/<competition_slug>/<file_name>", methods=["GET"])
def competition_download_file(competition_slug: str, file_name: str) -> ResponseReturnValue:
_ = f"{competition_slug}"
test_file_path = get_test_file_path(file_name)

def generate_file_content() -> Generator[bytes, Any, None]:
with open(test_file_path, "rb") as f:
while True:
chunk = f.read(4096) # Read file in chunks
if not chunk:
break
yield chunk

with open(test_file_path, "rb") as f:
content = f.read()
file_hash = hashlib.md5()
file_hash.update(content)
return (
Response(
generate_file_content(),
headers={
GCS_HASH_HEADER: f"md5={to_b64_digest(file_hash)}",
"Content-Length": str(len(content)),
LAST_MODIFIED: LAST_MODIFIED_DATE,
},
),
200,
)
# This mimics behavior for our file downloads, where users request a file, but
# receive a zipped version of the file from GCS.
test_file = f"{file_name}.zip" if file_name == AUTO_COMPRESSED_FILE_NAME else file_name
return get_gcs_redirect_response(test_file)
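For context, a rough sketch of the client side this stub mimics, assuming requests and zipfile. The helper name and the ".zip"-suffix check are illustrative; kagglehub's real handling lives in download_file (clients.py) and may differ:

import os
import zipfile
from urllib.parse import urlparse

import requests

def fetch_maybe_compressed(url: str, out_path: str) -> str:
    resp = requests.get(url, allow_redirects=True)
    resp.raise_for_status()
    # resp.url reflects the followed redirect; if it names a .zip object
    # while the caller asked for a plain file, decompress before returning.
    if urlparse(resp.url).path.endswith(".zip") and not url.endswith(".zip"):
        archive_path = out_path + ".zip"
        with open(archive_path, "wb") as f:
            f.write(resp.content)
        with zipfile.ZipFile(archive_path) as zf:
            zf.extractall(os.path.dirname(out_path) or ".")
    else:
        with open(out_path, "wb") as f:
            f.write(resp.content)
    return out_path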


@app.errorhandler(404)
61 changes: 4 additions & 57 deletions tests/server_stubs/dataset_download_stub.py
@@ -1,25 +1,13 @@
import hashlib
import mimetypes
import os
from collections.abc import Generator
from typing import Any

from flask import Flask, Response, jsonify, request
from flask import Flask, jsonify, request
from flask.typing import ResponseReturnValue

from kagglehub.http_resolver import DATASET_CURRENT_VERSION_FIELD
from kagglehub.integrity import to_b64_digest
from tests.utils import MOCK_GCS_BUCKET_BASE_PATH, get_mocked_gcs_signed_url, get_test_file_path
from tests.utils import AUTO_COMPRESSED_FILE_NAME, add_mock_gcs_route, get_gcs_redirect_response

app = Flask(__name__)
add_mock_gcs_route(app)

TARGZ_ARCHIVE_HANDLE = "testuser/zip-dataset/versions/1"
AUTO_COMPRESSED_FILE_NAME = "dummy.csv"

# See https://cloud.google.com/storage/docs/xml-api/reference-headers#xgooghash
GCS_HASH_HEADER = "x-goog-hash"
LOCATION_HEADER = "Location"
CONTENT_LENGTH_HEADER = "Content-Length"


@app.route("/", methods=["HEAD"])
@@ -58,48 +46,7 @@ def dataset_download(owner_slug: str, dataset_slug: str) -> ResponseReturnValue:
else:
test_file_name = "foo.txt.zip"

# All downloads, regardless of archive or file, happen via GCS signed URLs. We mock the 302 and handle
# the redirect not only to be thorough--without this, the response.url in download_file (clients.py)
# will not pick up on followed redirect URL being different from the originally requested URL.
return (
Response(
headers={
LOCATION_HEADER: get_mocked_gcs_signed_url(os.path.basename(test_file_name)),
CONTENT_LENGTH_HEADER: "0",
}
),
302,
)


# Route to handle the mocked GCS redirects
@app.route(f"{MOCK_GCS_BUCKET_BASE_PATH}/<file_name>", methods=["GET"])
def handle_mock_gcs_redirect(file_name: str) -> ResponseReturnValue:
test_file_path = get_test_file_path(file_name)

def generate_file_content() -> Generator[bytes, Any, None]:
with open(test_file_path, "rb") as f:
while True:
chunk = f.read(4096) # Read file in chunks
if not chunk:
break
yield chunk

with open(test_file_path, "rb") as f:
content = f.read()
file_hash = hashlib.md5()
file_hash.update(content)
return (
Response(
generate_file_content(),
headers={
GCS_HASH_HEADER: f"md5={to_b64_digest(file_hash)}",
"Content-Length": str(os.path.getsize(test_file_path)),
"Content-Type": mimetypes.guess_type(test_file_path)[0] or "application/octet-stream",
},
),
200,
)
return get_gcs_redirect_response(test_file_name)


@app.errorhandler(404)
26 changes: 22 additions & 4 deletions tests/test_http_competition_download.py
@@ -11,18 +11,21 @@

from .server_stubs import competition_download_stub as stub
from .server_stubs import serv
from .utils import create_test_cache
from .utils import AUTO_COMPRESSED_FILE_NAME, create_test_cache

INVALID_ARCHIVE_COMPETITION_HANDLE = "invalid/invalid"
COMPETITION_HANDLE = "titanic"
TEST_FILEPATH = "foo.txt"
TEST_CONTENTS = "foo"
AUTO_COMPRESSED_CONTENTS = """column_1,column2
1,a
2,b
3,c"""

EXPECTED_COMPETITION_SUBDIR = os.path.join(COMPETITIONS_CACHE_SUBFOLDER, "titanic")
EXPECTED_COMPETITION_SUBPATH = os.path.join(
COMPETITIONS_CACHE_SUBFOLDER,
"titanic",
TEST_FILEPATH,
)


@@ -65,10 +68,21 @@ def _download_test_file_and_assert_downloaded(
) -> None:
competition_path = kagglehub.competition_download(competition_handle, path=TEST_FILEPATH, **kwargs)

self.assertEqual(os.path.join(d, EXPECTED_COMPETITION_SUBPATH), competition_path)
self.assertEqual(os.path.join(d, EXPECTED_COMPETITION_SUBPATH, TEST_FILEPATH), competition_path)
with open(competition_path) as competition_file:
self.assertEqual(TEST_CONTENTS, competition_file.readline())

def _download_test_file_and_assert_downloaded_auto_compressed(
self,
d: str,
competition_handle: str,
**kwargs, # noqa: ANN003
) -> None:
competition_path = kagglehub.competition_download(competition_handle, path=AUTO_COMPRESSED_FILE_NAME, **kwargs)
self.assertEqual(os.path.join(d, EXPECTED_COMPETITION_SUBPATH, AUTO_COMPRESSED_FILE_NAME), competition_path)
with open(competition_path) as competition_file:
self.assertEqual(AUTO_COMPRESSED_CONTENTS, competition_file.read())

def test_competition_download(self) -> None:
with create_test_cache() as d:
self._download_competition_and_assert_downloaded(d, COMPETITION_HANDLE, EXPECTED_COMPETITION_SUBDIR)
@@ -120,6 +134,10 @@ def test_competition_download_with_path(self) -> None:
with create_test_cache() as d:
self._download_test_file_and_assert_downloaded(d, COMPETITION_HANDLE)

def test_competition_download_with_path_auto_compressed(self) -> None:
with create_test_cache() as d:
self._download_test_file_and_assert_downloaded_auto_compressed(d, COMPETITION_HANDLE)


class TestHttpNoInternet(BaseTestCase):
@classmethod
@@ -145,7 +163,7 @@ def test_competition_download_path_already_cached_with_no_internet(self) -> None

path = kagglehub.competition_download(COMPETITION_HANDLE, path=TEST_FILEPATH)

self.assertEqual(os.path.join(d, EXPECTED_COMPETITION_SUBPATH), path)
self.assertEqual(os.path.join(d, EXPECTED_COMPETITION_SUBPATH, TEST_FILEPATH), path)

def test_competition_download_already_cached_with_force_download_no_internet(self) -> None:
with create_test_cache():
6 changes: 3 additions & 3 deletions tests/test_http_dataset_download.py
@@ -8,7 +8,7 @@

from .server_stubs import dataset_download_stub as stub
from .server_stubs import serv
from .utils import create_test_cache
from .utils import AUTO_COMPRESSED_FILE_NAME, create_test_cache

INVALID_ARCHIVE_DATASET_HANDLE = "invalid/invalid/invalid/invalid/invalid"
VERSIONED_DATASET_HANDLE = "sarahjeffreson/featured-spotify-artiststracks-with-metadata/versions/2"
@@ -74,8 +74,8 @@ def _download_test_file_and_assert_downloaded_auto_compressed(
dataset_handle: str,
**kwargs, # noqa: ANN003
) -> None:
dataset_path = kagglehub.dataset_download(dataset_handle, path=stub.AUTO_COMPRESSED_FILE_NAME, **kwargs)
self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH, stub.AUTO_COMPRESSED_FILE_NAME), dataset_path)
dataset_path = kagglehub.dataset_download(dataset_handle, path=AUTO_COMPRESSED_FILE_NAME, **kwargs)
self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH, AUTO_COMPRESSED_FILE_NAME), dataset_path)
with open(dataset_path) as dataset_file:
self.assertEqual(AUTO_COMPRESSED_CONTENTS, dataset_file.read())

64 changes: 64 additions & 0 deletions tests/utils.py
@@ -1,14 +1,24 @@
import hashlib
import mimetypes
import os
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from unittest import mock

from flask import Flask, Response
from flask.typing import ResponseReturnValue

from kagglehub.config import CACHE_FOLDER_ENV_VAR_NAME, get_kaggle_api_endpoint
from kagglehub.handle import ResourceHandle
from kagglehub.integrity import GCS_HASH_HEADER, to_b64_digest

MOCK_GCS_BUCKET_BASE_PATH = "/mock-gcs-bucket/file-path"
AUTO_COMPRESSED_FILE_NAME = "dummy.csv"
LOCATION_HEADER = "Location"
CONTENT_LENGTH_HEADER = "Content-Length"


def get_test_file_path(relative_path: str) -> str:
@@ -25,6 +35,21 @@ def get_mocked_gcs_signed_url(file_name: str) -> str:
return f"{get_kaggle_api_endpoint()}{MOCK_GCS_BUCKET_BASE_PATH}/{file_name}?X-Goog-Headers=all-kinds-of-stuff"


# All downloads, regardless of archive or file, happen via GCS signed URLs. We mock the 302 and handle
# the redirect not only to be thorough--without this, the response.url in download_file (clients.py)
# will not pick up on followed redirect URL being different from the originally requested URL.
def get_gcs_redirect_response(file_name: str) -> ResponseReturnValue:
return (
Response(
headers={
LOCATION_HEADER: get_mocked_gcs_signed_url(file_name),
CONTENT_LENGTH_HEADER: "0",
}
),
302,
)


@contextmanager
def create_test_cache() -> Generator[str, None, None]:
with TemporaryDirectory() as d:
@@ -38,3 +63,42 @@ def __init__(self):

def to_url(self) -> str:
return "invalid"


def add_mock_gcs_route(app: Flask) -> None:
"""Adds the mock GCS route for handling signed URL redirects"""

app.add_url_rule(
f"{MOCK_GCS_BUCKET_BASE_PATH}/<file_name>",
endpoint="handle_mock_gcs_redirect",
view_func=handle_mock_gcs_redirect,
methods=["get"],
)


def handle_mock_gcs_redirect(file_name: str) -> ResponseReturnValue:
test_file_path = get_test_file_path(file_name)

def generate_file_content() -> Generator[bytes, Any, None]:
with open(test_file_path, "rb") as f:
while True:
chunk = f.read(4096) # Read file in chunks
if not chunk:
break
yield chunk

with open(test_file_path, "rb") as f:
content = f.read()
file_hash = hashlib.md5()
file_hash.update(content)
return (
Response(
generate_file_content(),
headers={
GCS_HASH_HEADER: f"md5={to_b64_digest(file_hash)}",
"Content-Length": str(os.path.getsize(test_file_path)),
"Content-Type": mimetypes.guess_type(test_file_path)[0] or "application/octet-stream",
},
),
200,
)
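The x-goog-hash value above follows the GCS convention of a base64-encoded MD5 digest. A minimal sketch of the check a downloader could run against it, assuming the "md5=<digest>" shape emitted here (the function name is illustrative; kagglehub's own digest helper is to_b64_digest in kagglehub.integrity):

import base64
import hashlib

def md5_matches(content: bytes, header_value: str) -> bool:
    # header_value has the shape "md5=<base64-encoded MD5 digest>"
    expected = header_value.split("md5=", 1)[1]
    actual = base64.b64encode(hashlib.md5(content).digest()).decode("utf-8")
    return actual == expected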
