Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support requesting on-demand inputs when using valohai input utilities #137

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def vte(tmpdir):
def use_test_config_dir(vte, monkeypatch):
vte.build()
monkeypatch.setenv("VH_CONFIG_DIR", str(vte.config_path))
monkeypatch.setenv("VH_INPUT_DIR", str(vte.inputs_path))
monkeypatch.setenv("VH_OUTPUT_DIR", str(vte.outputs_path))
monkeypatch.setenv("VH_INPUTS_DIR", str(vte.inputs_path))
monkeypatch.setenv("VH_OUTPUTS_DIR", str(vte.outputs_path))
Comment on lines +19 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 😃


# pytest carries global state between tests if we don't flush it
global_state.flush_global_state()
Expand Down
87 changes: 87 additions & 0 deletions tests/test_download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import sys
import uuid
import json

import pytest

Expand Down Expand Up @@ -150,3 +152,88 @@ def test_datum_url_download(tmpdir, monkeypatch, requests_mock, datum_name):

assert os.path.isfile(os.path.join(inputs_dir, "example", filename))
assert requests_mock.call_count == 4


def test_download_by_input_id(vte, use_test_config_dir, requests_mock):
filename = "t10k-images-idx3-ubyte.gz"
input_id = str(uuid.uuid4())
input_request_url = "http://example.com/input-request/"
download_url = f"https://valohai-mnist.s3.amazonaws.com/{filename}"

# Setup:
# ---

# Write a config file that contains an on-demand input
inputs_config = {
"on-demand": {
"input_id": input_id,
"files": [
{
"path": f"{vte.inputs_path}/{filename}",
"name": filename,
"uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}",
"size": 0,
"input_id": input_id,
"storage_uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}",
"download_intent": "on-demand",
},
],
}
}
with open(os.path.join(vte.config_path, "inputs.json"), "w") as inputs_f:
json.dump(inputs_config, inputs_f)

# Write a config file that contains an input_request API endpoint
api_config = {
"input_request": {
"url": input_request_url,
"method": "POST",
},
}
with open(os.path.join(vte.config_path, "api.json"), "w") as api_f:
json.dump(api_config, api_f)

# Set up requests_mock
requests_mock.post(
input_request_url,
json=[
{
"name": "on-demand",
"files": [
{
"input_id": input_id,
"url": download_url,
"original_uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}",
"filename": filename,
"download_intent": "on-demand",
},
],
},
],
)
requests_mock.get(download_url, text="I was downloaded by valohai-utils")

# Assumptions
# ---
# The file does not exist before it is accessed by valohai-utils
local_filename = os.path.join(vte.inputs_path, "on-demand", filename)
assert not os.path.isfile(local_filename)

# Trigger the download
# ---
get_input_vfs("on-demand")

# Assertions
# ---

# We can tell it was downloaded the way we expected it to be
assert requests_mock.call_count == 2
first_rq, second_rq = requests_mock.request_history
assert first_rq.url == f"{input_request_url}?inputs={input_id}"
assert second_rq.url == download_url

# The file now exists and contains the downloaded data
assert os.path.isfile(local_filename)
with open(local_filename, "r") as local_file:
file_contents = local_file.read()
assert file_contents == "I was downloaded by valohai-utils"
38 changes: 37 additions & 1 deletion valohai/internals/download.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import os
import tempfile
import shutil
from typing import Any, Dict, Union

from requests import Response
from valohai.internals.utils import uri_to_filename, get_sha256_hash
from valohai.internals.api_calls import send_api_request


def resolve_datum(datum_id: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -122,4 +124,38 @@ def _do_download(url: str, path: str) -> None:
if prog:
prog.update(len(chunk))
f.write(chunk)
os.replace(tmp_path, path)
try:
os.replace(tmp_path, path)
except OSError:
# different filesystems, for example a tmp filesystem, Docker volume, etc
if os.path.isfile(path):
os.remove(path)
shutil.copy(tmp_path, path)


def request_download_urls(input_id: str) -> Dict[str, str]:
"""Request download URLs for the input from Valohai.

Returns a dict of filename -> download URL for the given input.
"""
try:
import requests
except ImportError as ie:
raise RuntimeError("Can't download on demand without requests") from ie

try:
response = send_api_request(
endpoint="input_request", params={"inputs": [input_id]}
)
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError("Could not get new input download URLs") from e

# While we should only get the single input we request in the response, this does handle the case
# that we also get unrelated inputs.
return dict(
(input_file["filename"], input_file["url"])
for input_request in response.json()
for input_file in input_request["files"]
if input_file["input_id"] == input_id
)
19 changes: 14 additions & 5 deletions valohai/internals/input_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from valohai_yaml.utils import listify

from valohai.internals.download import download_url
from valohai.internals.download import download_url, request_download_urls
from valohai.internals.download_type import DownloadType
from valohai.internals.utils import uri_to_filename
from valohai.paths import get_inputs_path
Expand All @@ -24,6 +24,7 @@ def __init__(
) -> None:
self.name = str(name)
self.uri = str(uri) if uri else None
self.download_url = self.uri
self.checksums = dict(checksums) if checksums else {}
self.path = str(path) if path else None
self.size = int(size) if size else None
Expand All @@ -34,10 +35,10 @@ def is_downloaded(self) -> Optional[bool]:
return bool(self.path and os.path.isfile(self.path))

def download(self, path: str, force_download: bool = False) -> None:
if not self.uri:
if not self.download_url:
raise ValueError("Can not download file with no URI")
self.path = download_url(
self.uri, os.path.join(path, self.name), force_download
self.download_url, os.path.join(path, self.name), force_download
)
# TODO: Store size & checksums if they become useful

Expand All @@ -55,8 +56,9 @@ def from_json_data(cls, json_data: Dict[str, Any]) -> "FileInfo":


class InputInfo:
def __init__(self, files: Iterable[FileInfo]):
def __init__(self, files: Iterable[FileInfo], input_id: Optional[str] = None):
self.files = list(files)
self.input_id = input_id

def is_downloaded(self) -> bool:
if not self.files:
Expand All @@ -74,13 +76,20 @@ def download_if_necessary(
):
path = get_inputs_path(name)
os.makedirs(path, exist_ok=True)
if self.input_id:
# Resolve download URLs from Valohai before downloading
filenames_to_urls = request_download_urls(self.input_id)
for file in self.files:
if not file.is_downloaded():
file.download_url = filenames_to_urls[file.name]
for f in self.files:
f.download(path, force_download=(download == DownloadType.ALWAYS))

@classmethod
def from_json_data(cls, json_data: Dict[str, Any]) -> "InputInfo":
return cls(
files=[FileInfo.from_json_data(d) for d in json_data.get("files", ())]
input_id=json_data.get("input_id"),
files=[FileInfo.from_json_data(d) for d in json_data.get("files", ())],
)

@classmethod
Expand Down
Loading