-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2537 from jkamelin/jk/md_api
tools/downloader: separate download engine
- Loading branch information
Showing
11 changed files
with
623 additions
and
519 deletions.
There are no files selected for viewing
310 changes: 56 additions & 254 deletions
310
tools/downloader/src/open_model_zoo/model_tools/_configuration.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
100 changes: 100 additions & 0 deletions
100
tools/downloader/src/open_model_zoo/model_tools/download_engine/cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright (c) 2021 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import hashlib | ||
import shutil | ||
import sys | ||
import tempfile | ||
|
||
from pathlib import Path | ||
|
||
# I/O chunk size: small (32 KiB) when stdout is a terminal so the progress
# display refreshes responsively, large (1 MiB) otherwise for throughput.
if sys.stdout.isatty():
    CHUNK_SIZE = 1 << 15
else:
    CHUNK_SIZE = 1 << 20
|
||
|
||
class NullCache:
    """A cache that stores nothing.

    Used when no cache directory is configured: every lookup misses and
    every store is silently discarded.
    """

    def has(self, hash):
        """Report whether a file with this digest is cached; always False."""
        return False

    def get(self, model_file, path, reporter):
        """Copy a cached file to *path*; always fails (returns False)."""
        return False

    def put(self, hash, path):
        """Store *path* under *hash*; deliberately a no-op."""
        pass
|
||
|
||
class DirCache:
    """An on-disk cache of files keyed by their SHA-256 digest.

    Entries live under ``cache_dir/<_FORMAT>/<first 2 hex digits>/<rest>``;
    a staging subdirectory is used so that entries are only ever installed
    by an atomic rename.
    """

    _FORMAT = 1  # increment if backwards-incompatible changes to the format are made
    _HASH_LEN = hashlib.sha256().digest_size

    def __init__(self, cache_dir):
        self._cache_dir = cache_dir / str(self._FORMAT)
        self._cache_dir.mkdir(parents=True, exist_ok=True)

        self._staging_dir = self._cache_dir / 'staging'
        self._staging_dir.mkdir(exist_ok=True)

    def _hash_path(self, hash):
        """Map a raw SHA-256 digest to its location inside the cache."""
        assert len(hash) == self._HASH_LEN
        hex_digest = hash.hex().lower()
        # Two-level layout keeps any single directory from growing too large.
        return self._cache_dir / hex_digest[:2] / hex_digest[2:]

    def has(self, hash):
        """Return True if an entry for this digest exists in the cache."""
        return self._hash_path(hash).exists()

    def get(self, model_file, path, reporter):
        """Copy the cached entry for *model_file* to *path*.

        The copy is verified against the expected size and SHA-256 digest.
        Returns True on success, False if the cached entry is corrupt.
        """
        entry_path = self._hash_path(model_file.sha256)
        digest = hashlib.sha256()
        copied = 0

        with open(entry_path, 'rb') as cache_file, open(path, 'wb') as destination_file:
            while True:
                chunk = cache_file.read(CHUNK_SIZE)
                if not chunk:
                    break
                copied += len(chunk)
                # Stop immediately rather than writing past the expected size.
                if copied > model_file.size:
                    reporter.log_error("Cached file is longer than expected ({} B), copying aborted", model_file.size)
                    return False
                digest.update(chunk)
                destination_file.write(chunk)
        if copied < model_file.size:
            reporter.log_error("Cached file is shorter ({} B) than expected ({} B)", copied, model_file.size)
            return False
        return verify_hash(reporter, digest.digest(), model_file.sha256, path)

    def put(self, hash, path):
        """Install *path* into the cache under the given digest."""
        staging_path = None

        try:
            # A file in the cache must have the hash implied by its name. So when we upload a file,
            # we first copy it to a temporary file and then atomically move it to the desired name.
            # This prevents interrupted runs from corrupting the cache.
            with path.open('rb') as source_file:
                with tempfile.NamedTemporaryFile(dir=str(self._staging_dir), delete=False) as staging_file:
                    staging_path = Path(staging_file.name)
                    shutil.copyfileobj(source_file, staging_file)

            final_path = self._hash_path(hash)
            final_path.parent.mkdir(parents=True, exist_ok=True)
            staging_path.replace(final_path)
            staging_path = None  # installed successfully; nothing left to clean up
        finally:
            # If we failed to complete our temporary file or to move it into place,
            # get rid of it.
            if staging_path:
                staging_path.unlink()
|
||
|
||
def verify_hash(reporter, actual_hash, expected_hash, path):
    """Compare two raw digests, logging a detailed error on mismatch.

    Returns True when *actual_hash* equals *expected_hash*, False otherwise.
    """
    if actual_hash == expected_hash:
        return True
    reporter.log_error('Hash mismatch for "{}"', path)
    reporter.log_details('Expected: {}', expected_hash.hex())
    reporter.log_details('Actual: {}', actual_hash.hex())
    return False
201 changes: 201 additions & 0 deletions
201
tools/downloader/src/open_model_zoo/model_tools/download_engine/downloader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
# Copyright (c) 2021 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import functools | ||
import hashlib | ||
import requests | ||
import ssl | ||
import time | ||
import types | ||
|
||
from open_model_zoo.model_tools.download_engine import cache | ||
|
||
# Default per-request download timeout, in seconds (5 minutes).
DOWNLOAD_TIMEOUT = 5 * 60
|
||
class Downloader:
    """Downloads model files, with retry/resume support and an optional on-disk cache.

    Files are fetched via a caller-supplied ``start_download`` callable and
    written under ``output_dir``; successfully verified files are also
    inserted into the cache so later runs can skip the network.
    """

    def __init__(self, output_dir=None, cache_dir=None, num_attempts=1, timeout=DOWNLOAD_TIMEOUT):
        # Base directory for downloaded files (models go into subdirectories of it).
        self.output_dir = output_dir
        # With no cache directory, fall back to a cache that never hits.
        self.cache = cache.NullCache() if cache_dir is None else cache.DirCache(cache_dir)
        self.num_attempts = num_attempts
        self.timeout = timeout

    def _process_download(self, reporter, chunk_iterable, size, progress, file):
        """Consume *chunk_iterable*, writing chunks to *file*.

        *progress* is a SimpleNamespace whose ``size`` (bytes received) and
        ``hasher`` (running SHA-256) persist across retry attempts so a
        resumed download continues where it left off. *size* is the expected
        total length in bytes. Always closes the progress display, even on
        interruption.
        """
        start_time = time.monotonic()
        start_size = progress.size

        try:
            for chunk in chunk_iterable:
                # Raises if the user interrupted the job; handled by the caller.
                reporter.job_context.check_interrupted()

                if chunk:
                    duration = time.monotonic() - start_time
                    progress.size += len(chunk)
                    progress.hasher.update(chunk)

                    # Speed is measured over this attempt only (since start_size),
                    # in KB/s; '?' on the very first instant to avoid dividing by zero.
                    if duration != 0:
                        speed = int((progress.size - start_size) / (1024 * duration))
                    else:
                        speed = '?'

                    percent = progress.size * 100 // size

                    reporter.print_progress('... {}%, {} KB, {} KB/s, {} seconds passed',
                        percent, progress.size // 1024, speed, int(duration))
                    reporter.emit_event('model_file_download_progress', size=progress.size)

                    file.write(chunk)

                    # don't attempt to finish a file if it's bigger than expected
                    if progress.size > size:
                        break
        finally:
            reporter.end_progress()

    def _try_download(self, reporter, file, start_download, size):
        """Download into *file*, retrying up to ``self.num_attempts`` times.

        *start_download* is a callable taking ``offset`` and ``timeout`` and
        returning ``(chunk_iterable, continue_offset)``, where
        ``continue_offset`` is the byte position the server actually resumed
        from. Returns the SHA-256 digest of the downloaded data on success,
        or None on failure.
        """
        progress = types.SimpleNamespace(size=0)

        for attempt in range(self.num_attempts):
            if attempt != 0:
                retry_delay = 10
                reporter.print("Will retry in {} seconds...", retry_delay, flush=True)
                time.sleep(retry_delay)

            try:
                reporter.job_context.check_interrupted()
                chunk_iterable, continue_offset = start_download(offset=progress.size, timeout=self.timeout)

                if continue_offset not in {0, progress.size}:
                    # Somehow we neither restarted nor continued from where we left off.
                    # Try to restart.
                    chunk_iterable, continue_offset = start_download(offset=0, timeout=self.timeout)
                    if continue_offset != 0:
                        reporter.log_error("Remote server refuses to send whole file, aborting")
                        return None

                # Starting from scratch: discard any partial data and reset the hash.
                if continue_offset == 0:
                    file.seek(0)
                    file.truncate()
                    progress.size = 0
                    progress.hasher = hashlib.sha256()

                self._process_download(reporter, chunk_iterable, size, progress, file)

                if progress.size > size:
                    reporter.log_error("Remote file is longer than expected ({} B), download aborted", size)
                    # no sense in retrying - if the file is longer, there's no way it'll fix itself
                    return None
                elif progress.size < size:
                    reporter.log_error("Downloaded file is shorter ({} B) than expected ({} B)",
                        progress.size, size)
                    # it's possible that we got disconnected before receiving the full file,
                    # so try again
                else:
                    return progress.hasher.digest()
            except (requests.exceptions.RequestException, ssl.SSLError):
                reporter.log_error("Download failed", exc_info=True)

        # All attempts exhausted without receiving the complete file.
        return None

    def _try_retrieve_from_cache(self, reporter, model_file, destination):
        """Copy *model_file* from the cache to *destination* if it is cached.

        Returns True on a verified cache hit; False on a miss or a corrupt
        entry (the caller then downloads from the original source). Any
        unexpected cache error is logged as a warning, never propagated.
        """
        try:
            if self.cache.has(model_file.sha256):
                reporter.job_context.check_interrupted()

                reporter.print_section_heading('Retrieving {} from the cache', destination)
                if not self.cache.get(model_file, destination, reporter):
                    reporter.print('Will retry from the original source.')
                    reporter.print()
                    return False
                reporter.print()
                return True
        except Exception:
            reporter.log_warning('Cache retrieval failed; falling back to downloading', exc_info=True)
            reporter.print()

        return False

    @staticmethod
    def _try_update_cache(reporter, cache, hash, source):
        """Best-effort insertion of *source* into *cache*; failures only warn."""
        try:
            cache.put(hash, source)
        except Exception:
            reporter.log_warning('Failed to update the cache', exc_info=True)

    def _try_retrieve(self, reporter, destination, model_file, start_download):
        """Obtain *model_file* at *destination*, from cache or by downloading.

        Returns True if the file was retrieved and its hash verified.
        """
        destination.parent.mkdir(parents=True, exist_ok=True)

        if self._try_retrieve_from_cache(reporter, model_file, destination):
            return True

        reporter.print_section_heading('Downloading {}', destination)

        success = False

        # 'w+b' so a resumed attempt can seek/truncate the partial file.
        with destination.open('w+b') as f:
            actual_hash = self._try_download(reporter, f, start_download, model_file.size)

        if actual_hash and cache.verify_hash(reporter, actual_hash, model_file.sha256, destination):
            self._try_update_cache(reporter, self.cache, model_file.sha256, destination)
            success = True

        reporter.print()
        return success

    def download_model(self, reporter, session_factory, requested_precisions, model, known_precisions):
        """Download every file of *model*, then run its postprocessing steps.

        Emits begin/end events per model and per file through *reporter*.
        Returns False (after removing the partial file) as soon as any file
        fails; True once all files and postprocessing complete.
        """
        session = session_factory()

        reporter.print_group_heading('Downloading {}', model.name)

        reporter.emit_event('model_download_begin', model=model.name, num_files=len(model.files))

        output = self.output_dir / model.subdirectory
        output.mkdir(parents=True, exist_ok=True)

        for model_file in model.files:
            # Two-component paths appear to use their first component as a
            # precision directory (e.g. FP16/...); skip precisions the user
            # did not request. TODO(review): confirm against config layout.
            if len(model_file.name.parts) == 2:
                p = model_file.name.parts[0]
                if p in known_precisions and p not in requested_precisions:
                    continue

            model_file_reporter = reporter.with_event_context(model=model.name, model_file=model_file.name.as_posix())
            model_file_reporter.emit_event('model_file_download_begin', size=model_file.size)

            destination = output / model_file.name

            if not self._try_retrieve(model_file_reporter, destination, model_file,
                    functools.partial(model_file.source.start_download, session, cache.CHUNK_SIZE)):
                # Don't leave a partial/corrupt file behind on failure.
                try:
                    destination.unlink()
                except FileNotFoundError:
                    pass

                model_file_reporter.emit_event('model_file_download_end', successful=False)
                reporter.emit_event('model_download_end', model=model.name, successful=False)
                return False

            model_file_reporter.emit_event('model_file_download_end', successful=True)

        reporter.emit_event('model_download_end', model=model.name, successful=True)

        if model.postprocessing:
            reporter.emit_event('model_postprocessing_begin', model=model.name)

            for postproc in model.postprocessing:
                postproc.apply(reporter, output)

            reporter.emit_event('model_postprocessing_end', model=model.name)

        reporter.print()

        return True
Oops, something went wrong.