diff --git a/monai/apps/utils.py b/monai/apps/utils.py index a36caf2e66..35322be2a8 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -12,6 +12,7 @@ from __future__ import annotations import hashlib +import json import logging import os import shutil @@ -24,7 +25,7 @@ from typing import TYPE_CHECKING, Any from urllib.error import ContentTooShortError, HTTPError, URLError from urllib.parse import urlparse -from urllib.request import urlretrieve +from urllib.request import urlopen, urlretrieve from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import @@ -203,6 +204,17 @@ def download_url( if not has_gdown: raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") gdown.download(url, f"{tmp_name}", quiet=not progress, **gdown_kwargs) + elif urlparse(url).netloc == "cloud-api.yandex.net": + with urlopen(url) as response: + code = response.getcode() + if code == 200: + download_url = json.loads(response.read())["href"] + _download_with_progress(download_url, tmp_name, progress=progress) + else: + raise RuntimeError( + f"Download of file from {download_url}, received from {url} " + + f" to {filepath} failed due to network issue or denied permission." + ) else: _download_with_progress(url, tmp_name, progress=progress) if not tmp_name.exists(): diff --git a/tests/test_download_url_yandex.py b/tests/test_download_url_yandex.py new file mode 100644 index 0000000000..d0946f9f70 --- /dev/null +++ b/tests/test_download_url_yandex.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest +from urllib.error import HTTPError + +from monai.apps.utils import download_url + +YANDEX_MODEL_URL = ( + "https://cloud-api.yandex.net/v1/disk/public/resources/download?" + "public_key=https%3A%2F%2Fdisk.yandex.ru%2Fd%2Fxs0gzlj2_irgWA" +) +YANDEX_MODEL_FLAWED_URL = ( + "https://cloud-api.yandex.net/v1/disk/public/resources/download?" + "public_key=https%3A%2F%2Fdisk.yandex.ru%2Fd%2Fxs0gzlj2_irgWA-url-with-error" +) + + +class TestDownloadUrlYandex(unittest.TestCase): + def test_verify(self): + with tempfile.TemporaryDirectory() as tempdir: + download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, "model.pt")) + + def test_verify_error(self): + with tempfile.TemporaryDirectory() as tempdir: + with self.assertRaises(HTTPError): + download_url(url=YANDEX_MODEL_FLAWED_URL, filepath=os.path.join(tempdir, "model.pt")) + + +if __name__ == "__main__": + unittest.main()