From abee2f33d29a9a1fb28db0a39b4c830c18e1013d Mon Sep 17 00:00:00 2001 From: Leo Honkanen Date: Tue, 10 Dec 2024 15:38:18 +0200 Subject: [PATCH] Support requesting on-demand inputs when using valohai input utilities --- tests/conftest.py | 4 +- tests/test_download.py | 87 +++++++++++++++++++++++++++++++++ valohai/internals/download.py | 38 +++++++++++++- valohai/internals/input_info.py | 19 +++++-- 4 files changed, 140 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7cc3a2c..c179309 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,8 +16,8 @@ def vte(tmpdir): def use_test_config_dir(vte, monkeypatch): vte.build() monkeypatch.setenv("VH_CONFIG_DIR", str(vte.config_path)) - monkeypatch.setenv("VH_INPUT_DIR", str(vte.inputs_path)) - monkeypatch.setenv("VH_OUTPUT_DIR", str(vte.outputs_path)) + monkeypatch.setenv("VH_INPUTS_DIR", str(vte.inputs_path)) + monkeypatch.setenv("VH_OUTPUTS_DIR", str(vte.outputs_path)) # pytest carries global state between tests if we don't flush it global_state.flush_global_state() diff --git a/tests/test_download.py b/tests/test_download.py index 6614a68..fb16d97 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,5 +1,7 @@ import os import sys +import uuid +import json import pytest @@ -150,3 +152,88 @@ def test_datum_url_download(tmpdir, monkeypatch, requests_mock, datum_name): assert os.path.isfile(os.path.join(inputs_dir, "example", filename)) assert requests_mock.call_count == 4 + + +def test_download_by_input_id(vte, use_test_config_dir, requests_mock): + filename = "t10k-images-idx3-ubyte.gz" + input_id = str(uuid.uuid4()) + input_request_url = "http://example.com/input-request/" + download_url = f"https://valohai-mnist.s3.amazonaws.com/{filename}" + + # Setup: + # --- + + # Write a config file that contains an on-demand input + inputs_config = { + "on-demand": { + "input_id": input_id, + "files": [ + { + "path": f"{vte.inputs_path}/{filename}", + "name": filename, + "uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}", + "size": 0, + "input_id": input_id, + "storage_uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}", + "download_intent": "on-demand", + }, + ], + } + } + with open(os.path.join(vte.config_path, "inputs.json"), "w") as inputs_f: + json.dump(inputs_config, inputs_f) + + # Write a config file that contains an input_request API endpoint + api_config = { + "input_request": { + "url": input_request_url, + "method": "POST", + }, + } + with open(os.path.join(vte.config_path, "api.json"), "w") as api_f: + json.dump(api_config, api_f) + + # Set up requests_mock + requests_mock.post( + input_request_url, + json=[ + { + "name": "on-demand", + "files": [ + { + "input_id": input_id, + "url": download_url, + "original_uri": f"s3://valohai-mnist.s3.amazonaws.com/{filename}", + "filename": filename, + "download_intent": "on-demand", + }, + ], + }, + ], + ) + requests_mock.get(download_url, text="I was downloaded by valohai-utils") + + # Assumptions + # --- + # The file does not exist before it is accessed by valohai-utils + local_filename = os.path.join(vte.inputs_path, "on-demand", filename) + assert not os.path.isfile(local_filename) + + # Trigger the download + # --- + get_input_vfs("on-demand") + + # Assertions + # --- + + # We can tell it was downloaded the way we expected it to be + assert requests_mock.call_count == 2 + first_rq, second_rq = requests_mock.request_history + assert first_rq.url == f"{input_request_url}?inputs={input_id}" + assert second_rq.url == download_url + + # The file now exists and contains the downloaded data + assert os.path.isfile(local_filename) + with open(local_filename, "r") as local_file: + file_contents = local_file.read() + assert file_contents == "I was downloaded by valohai-utils" diff --git a/valohai/internals/download.py b/valohai/internals/download.py index 2e1ba51..916dc8c 100644 --- a/valohai/internals/download.py +++ b/valohai/internals/download.py @@ -1,10 +1,12 @@ import contextlib import os import tempfile +import shutil from typing import Any, Dict, Union from requests import Response from valohai.internals.utils import uri_to_filename, get_sha256_hash +from valohai.internals.api_calls import send_api_request def resolve_datum(datum_id: str) -> Dict[str, Any]: @@ -122,4 +124,38 @@ def _do_download(url: str, path: str) -> None: if prog: prog.update(len(chunk)) f.write(chunk) - os.replace(tmp_path, path) + try: + os.replace(tmp_path, path) + except OSError: + # different filesystems, for example a tmp filesystem, Docker volume, etc + if os.path.isfile(path): + os.remove(path) + shutil.copy(tmp_path, path) + + +def request_download_urls(input_id: str) -> Dict[str, str]: + """Request download URLs for the input from Valohai. + + Returns a dict of filename -> download URL for the given input. + """ + try: + import requests + except ImportError as ie: + raise RuntimeError("Can't download on demand without requests") from ie + + try: + response = send_api_request( + endpoint="input_request", params={"inputs": [input_id]} + ) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError("Could not get new input download URLs") from e + + # While we should only get the single input we request in the response, this does handle the case + # that we also get unrelated inputs. + return dict( + (input_file["filename"], input_file["url"]) + for input_request in response.json() + for input_file in input_request["files"] + if input_file["input_id"] == input_id + ) diff --git a/valohai/internals/input_info.py b/valohai/internals/input_info.py index 511c2e9..3b16cc0 100644 --- a/valohai/internals/input_info.py +++ b/valohai/internals/input_info.py @@ -4,7 +4,7 @@ from valohai_yaml.utils import listify -from valohai.internals.download import download_url +from valohai.internals.download import download_url, request_download_urls from valohai.internals.download_type import DownloadType from valohai.internals.utils import uri_to_filename from valohai.paths import get_inputs_path @@ -24,6 +24,7 @@ def __init__( ) -> None: self.name = str(name) self.uri = str(uri) if uri else None + self.download_url = self.uri self.checksums = dict(checksums) if checksums else {} self.path = str(path) if path else None self.size = int(size) if size else None @@ -34,10 +35,10 @@ def is_downloaded(self) -> Optional[bool]: return bool(self.path and os.path.isfile(self.path)) def download(self, path: str, force_download: bool = False) -> None: - if not self.uri: + if not self.download_url: raise ValueError("Can not download file with no URI") self.path = download_url( - self.uri, os.path.join(path, self.name), force_download + self.download_url, os.path.join(path, self.name), force_download ) # TODO: Store size & checksums if they become useful @@ -55,8 +56,9 @@ def from_json_data(cls, json_data: Dict[str, Any]) -> "FileInfo": class InputInfo: - def __init__(self, files: Iterable[FileInfo]): + def __init__(self, files: Iterable[FileInfo], input_id: Optional[str] = None): self.files = list(files) + self.input_id = input_id def is_downloaded(self) -> bool: if not self.files: @@ -74,13 +76,20 @@ def download_if_necessary( ): path = get_inputs_path(name) os.makedirs(path, exist_ok=True) + if self.input_id: + # Resolve download URLs from Valohai before downloading + filenames_to_urls = request_download_urls(self.input_id) + for file in self.files: + if not file.is_downloaded(): + file.download_url = filenames_to_urls[file.name] for f in self.files: f.download(path, force_download=(download == DownloadType.ALWAYS)) @classmethod def from_json_data(cls, json_data: Dict[str, Any]) -> "InputInfo": return cls( - files=[FileInfo.from_json_data(d) for d in json_data.get("files", ())] + input_id=json_data.get("input_id"), + files=[FileInfo.from_json_data(d) for d in json_data.get("files", ())], ) @classmethod