Skip to content

Commit

Permalink
Support requesting on-demand inputs when using valohai input utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
hylje committed Dec 12, 2024
1 parent 25872d3 commit 544ebc3
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 6 deletions.
91 changes: 91 additions & 0 deletions tests/test_download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import sys
import uuid
import json

import pytest

Expand Down Expand Up @@ -150,3 +152,92 @@ def test_datum_url_download(tmpdir, monkeypatch, requests_mock, datum_name):

assert os.path.isfile(os.path.join(inputs_dir, "example", filename))
assert requests_mock.call_count == 4


def test_download_by_input_id(tmpdir, monkeypatch, requests_mock):
inputs_dir = str(tmpdir.mkdir("inputs"))
config_dir = str(tmpdir.mkdir("config"))
monkeypatch.setenv("VH_INPUTS_DIR", inputs_dir)
monkeypatch.setenv("VH_CONFIG_DIR", config_dir)
filename = "t10k-images-idx3-ubyte.gz"
input_id = str(uuid.uuid4())
input_request_url = "http://example.com/input-request/"
download_url = f"https://valohai-mnist.s3.amazonaws.com/{filename}"

# Setup:
# ---

# Write a config file that contains an on-demand input
inputs_config = {
"on-demand": {
"input_id": input_id,
"files": [
{
"path": f"{inputs_dir}/{filename}",
"name": filename,
"uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}",
"size": 0,
"input_id": input_id,
"storage_uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}",
"download_intent": "on-demand",
},
],
}
}
with open(os.path.join(config_dir, "inputs.json"), "w") as inputs_f:
json.dump(inputs_config, inputs_f)

# Write a config file that contains an input_request API endpoint
api_config = {
"input_request": {
"url": input_request_url,
"method": "POST",
},
}
with open(os.path.join(config_dir, "api.json"), "w") as api_f:
json.dump(api_config, api_f)

# Set up requests_mock
requests_mock.post(
input_request_url,
json=[
{
"name": "on-demand",
"files": [
{
"input_id": input_id,
"url": download_url,
"original_uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}",
"filename": filename,
"download_intent": "on-demand",
},
],
},
],
)
requests_mock.get(download_url, text="I was downloaded by valohai-utils")

# Assumptions
# ---
# The file does not exist before it is accessed by valohai-utils
local_filename = os.path.join(inputs_dir, "on-demand", filename)
assert not os.path.isfile(local_filename)

# Trigger the download
# ---
get_input_vfs("on-demand")

# Assertions
# ---

# We can tell it was downloaded the way we expected it to be
assert requests_mock.call_count == 2
first_rq, second_rq = requests_mock.request_history
assert first_rq.url == f"{input_request_url}?inputs={input_id}"
assert second_rq.url == download_url

# The file now exists and contains the downloaded data
assert os.path.isfile(local_filename)
with open(local_filename, "r") as local_file:
file_contents = local_file.read()
assert file_contents == "I was downloaded by valohai-utils"
38 changes: 37 additions & 1 deletion valohai/internals/download.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import os
import tempfile
import shutil
from typing import Any, Dict, Union

from requests import Response
from valohai.internals.utils import uri_to_filename, get_sha256_hash
from valohai.internals.api_calls import send_api_request


def resolve_datum(datum_id: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -122,4 +124,38 @@ def _do_download(url: str, path: str) -> None:
if prog:
prog.update(len(chunk))
f.write(chunk)
os.replace(tmp_path, path)
try:
os.replace(tmp_path, path)
except OSError:
# different filesystems, for example a tmp filesystem, Docker volume, etc
if os.path.isfile(path):
os.remove(path)
shutil.copy(tmp_path, path)


def request_download_urls(input_id: str) -> Dict[str, str]:
"""Request download URLs for the input from Valohai.
Returns a dict of filename -> download URL for the given input.
"""
try:
import requests
except ImportError as ie:
raise RuntimeError("Can't download on demand without requests") from ie

try:
response = send_api_request(
endpoint="input_request", params={"inputs": [input_id]}
)
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError("Could not get new input download URLs") from e

# While we should only get the single input we request in the response, this does handle the case
# that we also get unrelated inputs.
return dict(
(input_file["filename"], input_file["url"])
for input_request in response.json()
for input_file in input_request["files"]
if input_file["input_id"] == input_id
)
19 changes: 14 additions & 5 deletions valohai/internals/input_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from valohai_yaml.utils import listify

from valohai.internals.download import download_url
from valohai.internals.download import download_url, request_download_urls
from valohai.internals.download_type import DownloadType
from valohai.internals.utils import uri_to_filename
from valohai.paths import get_inputs_path
Expand All @@ -24,6 +24,7 @@ def __init__(
) -> None:
self.name = str(name)
self.uri = str(uri) if uri else None
self.download_url = self.uri
self.checksums = dict(checksums) if checksums else {}
self.path = str(path) if path else None
self.size = int(size) if size else None
Expand All @@ -34,10 +35,10 @@ def is_downloaded(self) -> Optional[bool]:
return bool(self.path and os.path.isfile(self.path))

def download(self, path: str, force_download: bool = False) -> None:
if not self.uri:
if not self.download_url:
raise ValueError("Can not download file with no URI")
self.path = download_url(
self.uri, os.path.join(path, self.name), force_download
self.download_url, os.path.join(path, self.name), force_download
)
# TODO: Store size & checksums if they become useful

Expand All @@ -55,8 +56,9 @@ def from_json_data(cls, json_data: Dict[str, Any]) -> "FileInfo":


class InputInfo:
def __init__(self, files: Iterable[FileInfo]):
def __init__(self, files: Iterable[FileInfo], input_id: Optional[str] = None):
self.files = list(files)
self.input_id = input_id

def is_downloaded(self) -> bool:
if not self.files:
Expand All @@ -74,13 +76,20 @@ def download_if_necessary(
):
path = get_inputs_path(name)
os.makedirs(path, exist_ok=True)
if self.input_id:
# Resolve download URLs from Valohai before downloading
filenames_to_urls = request_download_urls(self.input_id)
for file in self.files:
if not file.is_downloaded():
file.download_url = filenames_to_urls[file.name]
for f in self.files:
f.download(path, force_download=(download == DownloadType.ALWAYS))

@classmethod
def from_json_data(cls, json_data: Dict[str, Any]) -> "InputInfo":
return cls(
files=[FileInfo.from_json_data(d) for d in json_data.get("files", ())]
input_id=json_data.get("input_id"),
files=[FileInfo.from_json_data(d) for d in json_data.get("files", ())],
)

@classmethod
Expand Down

0 comments on commit 544ebc3

Please sign in to comment.