Skip to content

Commit

Permalink
Improve logging when erroring from calling urlretrieve (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt authored Nov 15, 2024
1 parent e12b53c commit 4024603
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 16 deletions.
62 changes: 46 additions & 16 deletions src/pystow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import shutil
import tarfile
import tempfile
import urllib.error
import zipfile
from collections import namedtuple
from functools import partial
Expand Down Expand Up @@ -342,6 +343,7 @@ def download(
:raises KeyboardInterrupt: If a keyboard interrupt is thrown during download
:raises UnexpectedDirectory: If a directory is given for the ``path`` argument
:raises ValueError: If an invalid backend is chosen
:raises DownloadError: If an error occurs during download
"""
path = Path(path).resolve()

Expand Down Expand Up @@ -373,24 +375,34 @@ def download(
if backend == "urllib":
logger.info("downloading with urllib from %s to %s", url, path)
with TqdmReportHook(**_tqdm_kwargs) as t:
urlretrieve(url, path, reporthook=t.update_to, **kwargs) # noqa:S310
try:
urlretrieve(url, path, reporthook=t.update_to, **kwargs) # noqa:S310
except urllib.error.URLError as e:
raise DownloadError(backend, url, path) from e
elif backend == "requests":
kwargs.setdefault("stream", True)
# see https://requests.readthedocs.io/en/master/user/quickstart/#raw-response-content
# pattern from https://stackoverflow.com/a/39217788/5775947
with requests.get(url, **kwargs) as response, path.open("wb") as file: # noqa:S113
logger.info(
"downloading (stream=%s) with requests from %s to %s",
kwargs["stream"],
url,
path,
)
# Solution for progress bar from https://stackoverflow.com/a/63831344/5775947
total_size = int(response.headers.get("Content-Length", 0))
# Decompress if needed
response.raw.read = partial(response.raw.read, decode_content=True) # type:ignore
with tqdm.wrapattr(response.raw, "read", total=total_size, **_tqdm_kwargs) as fsrc:
shutil.copyfileobj(fsrc, file)
try:
# see https://requests.readthedocs.io/en/master/user/quickstart/#raw-response-content
# pattern from https://stackoverflow.com/a/39217788/5775947
with requests.get(url, **kwargs) as response, path.open("wb") as file: # noqa:S113
logger.info(
"downloading (stream=%s) with requests from %s to %s",
kwargs["stream"],
url,
path,
)
# Solution for progress bar from https://stackoverflow.com/a/63831344/5775947
total_size = int(response.headers.get("Content-Length", 0))
# Decompress if needed
response.raw.read = partial( # type:ignore[method-assign]
response.raw.read, decode_content=True
)
with tqdm.wrapattr(
response.raw, "read", total=total_size, **_tqdm_kwargs
) as fsrc:
shutil.copyfileobj(fsrc, file)
except requests.exceptions.ConnectionError as e:
raise DownloadError(backend, url, path) from e
else:
raise ValueError(f'Invalid backend: {backend}. Use "requests" or "urllib".')
except (Exception, KeyboardInterrupt):
Expand All @@ -406,6 +418,24 @@ def download(
)


class DownloadError(OSError):
"""An error that wraps information from a requests or urllib download failure."""

def __init__(self, backend: str, url: str, path: Path) -> None:
"""Initialize the error.
:param backend: The backend used
:param url: The url that failed to download
:param path: The path that was supposed to be downloaded to
"""
self.backend = backend
self.url = url
self.path = path

def __str__(self) -> str:
return f"Failed with {self.backend} to download {self.url} to {self.path}"


def name_from_url(url: str) -> str:
"""Get the filename from the end of the URL.
Expand Down
48 changes: 48 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from requests_file import FileAdapter

from pystow.utils import (
DownloadError,
HexDigestError,
download,
get_hexdigests_remote,
Expand Down Expand Up @@ -194,6 +195,53 @@ def test_numpy_io(self):
self.assertTrue(np.array_equal(arr, reloaded_arr))


class TestDownload(unittest.TestCase):
"""Tests for downloading."""

def setUp(self) -> None:
"""Set up a test."""
self.directory_obj = tempfile.TemporaryDirectory()
self.directory = Path(self.directory_obj.name)
self.bad_url = "https://nope.nope/nope.tsv"
self.path_for_bad_url = self.directory.joinpath("nope.tsv")

def tearDown(self) -> None:
"""Tear down a test."""
self.directory_obj.cleanup()

def test_bad_file_error(self):
"""Test that urllib errors are handled properly."""
with self.assertRaises(DownloadError):
download(
url=self.bad_url,
path=self.path_for_bad_url,
backend="urllib",
)
self.assertFalse(self.path_for_bad_url.is_file())

def test_requests_error_stream(self):
"""Test that requests errors are handled properly."""
with self.assertRaises(DownloadError):
download(
url=self.bad_url,
path=self.path_for_bad_url,
backend="requests",
stream=True,
)
self.assertFalse(self.path_for_bad_url.is_file())

def test_requests_error_sync(self):
"""Test that requests errors are handled properly."""
with self.assertRaises(DownloadError):
download(
url=self.bad_url,
path=self.path_for_bad_url,
backend="requests",
stream=False,
)
self.assertFalse(self.path_for_bad_url.is_file())


class TestHashing(unittest.TestCase):
"""Tests for hexdigest checking."""

Expand Down

0 comments on commit 4024603

Please sign in to comment.