diff --git a/amlb/datasets/file.py b/amlb/datasets/file.py
index bcb941ab2..8e696f135 100644
--- a/amlb/datasets/file.py
+++ b/amlb/datasets/file.py
@@ -15,7 +15,7 @@
 from ..resources import config as rconfig
 from ..utils import Namespace as ns, as_list, lazy_property, list_all_files, memoize, path_from_split, profile, split_path
-from .fileutils import download_file, is_archive, is_valid_url, unarchive_file, url_exists
+from .fileutils import is_archive, is_valid_url, unarchive_file, get_file_handler
 
 log = logging.getLogger(__name__)
 
@@ -118,8 +118,9 @@ def _extract_train_test_paths(self, dataset, fold=None):
         elif is_valid_url(dataset):
             cached_file = os.path.join(self._cache_dir, os.path.basename(dataset))
             if not os.path.exists(cached_file):  # don't download if previously done
-                assert url_exists(dataset), f"Invalid path/url: {dataset}"
-                download_file(dataset, cached_file)
+                handler = get_file_handler(dataset)
+                assert handler.exists(dataset), f"Invalid path/url: {dataset}"
+                handler.download(dataset, dest_path=cached_file)
             return self._extract_train_test_paths(cached_file)
         else:
             raise ValueError(f"Invalid dataset description: {dataset}")
diff --git a/amlb/datasets/fileutils.py b/amlb/datasets/fileutils.py
index bce8cbc4f..9be9a84e3 100644
--- a/amlb/datasets/fileutils.py
+++ b/amlb/datasets/fileutils.py
@@ -2,39 +2,85 @@
 import os
 import shutil
 import tarfile
+import boto3
+from botocore.errorfactory import ClientError
 from urllib.error import URLError
 from urllib.parse import urlparse
-from urllib.request import Request, urlopen, urlretrieve
+from urllib.request import Request, urlopen
 import zipfile
 
 from ..utils import touch
 
 log = logging.getLogger(__name__)
 
-SUPPORTED_SCHEMES = ("http", "https")
+
+class FileHandler:
+    def exists(self, url): pass
+    def download(self, url, dest_path): pass
+
+
+class HttpHandler(FileHandler):
+    def exists(self, url):
+        head_req = Request(url, method='HEAD')
+        try:
+            with urlopen(head_req) as test:
+                return test.status == 200
+        except URLError as e:
+            log.error("Cannot access url %s: %s", url, e)
+            return False
+
+    def download(self, url, dest_path):
+        touch(dest_path)
+        with urlopen(url) as resp, open(dest_path, 'wb') as dest:
+            shutil.copyfileobj(resp, dest)
+
+
+class S3Handler(FileHandler):
+    def exists(self, url):
+        s3 = boto3.client('s3')
+        bucket, key = self._s3_path_to_bucket_prefix(url)
+        try:
+            s3.head_object(Bucket=bucket, Key=key)
+            return True
+        except ClientError as e:
+            log.error("Cannot access url %s: %s", url, e)
+            return False
+
+    def download(self, url, dest_path):
+        touch(dest_path)
+        s3 = boto3.resource('s3')
+        bucket, key = self._s3_path_to_bucket_prefix(url)
+        try:
+            s3.Bucket(bucket).download_file(key, dest_path)
+        except ClientError as e:
+            if e.response['Error']['Code'] == "404":
+                log.error("The object does not exist.")
+            else:
+                raise
+
+    def _s3_path_to_bucket_prefix(self, s3_path):
+        s3_path_cleaned = s3_path.split('://', 1)[1]
+        bucket, prefix = s3_path_cleaned.split('/', 1)
+        return bucket, prefix
 
 
-def is_valid_url(url):
-    return urlparse(url).scheme in SUPPORTED_SCHEMES
+scheme_handlers = dict(
+    http=HttpHandler(),
+    https=HttpHandler(),
+    s3=S3Handler(),
+    s3a=S3Handler(),
+    s3n=S3Handler(),
+)
 
 
-def url_exists(url):
-    if not is_valid_url(url):
-        return False
-    head_req = Request(url, method='HEAD')
-    try:
-        with urlopen(head_req) as test:
-            return test.status == 200
-    except URLError as e:
-        log.error(f"Cannot access url %s: %s", url, e)
-        return False
+SUPPORTED_SCHEMES = list(scheme_handlers.keys())
 
 
-def download_file(url, dest_path):
-    touch(dest_path)
-    # urlretrieve(url, filename=dest_path)
-    with urlopen(url) as resp, open(dest_path, 'wb') as dest:
-        shutil.copyfileobj(resp, dest)
+def get_file_handler(url):
+    return scheme_handlers[urlparse(url).scheme]
+
+
+def is_valid_url(url):
+    return urlparse(url).scheme in SUPPORTED_SCHEMES
 
 
 def is_archive(path):
@@ -52,4 +98,3 @@ def unarchive_file(path, dest_folder=None):
     with tarfile.open(path) as tf:
         tf.extractall(path=dest_folder)
     return dest
-
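
Not part of the diff: a minimal usage sketch of the new handler API in fileutils.py, showing how the pieces above fit together. The URL and destination path are hypothetical, and the s3:// branch assumes boto3 can find AWS credentials in the environment.

    from amlb.datasets.fileutils import get_file_handler, is_valid_url

    url = "s3://some-bucket/datasets/train.csv"    # hypothetical object
    if is_valid_url(url):                          # scheme is one of http/https/s3/s3a/s3n
        handler = get_file_handler(url)            # dict lookup returns the S3Handler instance
        if handler.exists(url):                    # head_object for S3, HEAD request for HTTP
            handler.download(url, dest_path="/tmp/train.csv")

Since scheme_handlers is a plain dict keyed by URL scheme, supporting another protocol only requires registering one more FileHandler subclass, and is_valid_url stays in sync because SUPPORTED_SCHEMES is now derived from the dict keys.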