Skip to content

Commit

Permalink
Merge pull request #164 from RobustBench/fix-download
Browse files Browse the repository at this point in the history
Fix download
  • Loading branch information
fra31 authored Jan 17, 2024
2 parents 94fa328 + cf2fab1 commit c04877b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
7 changes: 7 additions & 0 deletions robustbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
BenchmarkDataset.imagenet: 1000,
}

CANNED_USER_AGENT="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36" # NOQA


def download_gdrive(gdrive_id, fname_save):
""" source: https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url """
Expand All @@ -55,6 +57,11 @@ def save_response_content(response, fname_save):
url_base = "https://docs.google.com/uc?export=download&confirm=t"
session = requests.Session()

# Fix from https://github.com/wkentaro/gdown/pull/294.
session.headers.update(
{"User-Agent": CANNED_USER_AGENT}
)

response = session.get(url_base, params={'id': gdrive_id}, stream=True)
token = get_confirm_token(response)

Expand Down
3 changes: 2 additions & 1 deletion robustbench/zenodo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def zenodo_download(record_id: str, filenames_to_download: Set[str],
"The hash of the downloaded file does not match"
" the expected one.")
print("Download finished, extracting...")
format = file["type"] if "type" in file.keys() else file["key"].split('.')[-1]
shutil.unpack_archive(filename,
extract_dir=save_dir,
format=file["type"])
format=format)
print("Downloaded and extracted.")

0 comments on commit c04877b

Please sign in to comment.