Commit

Added support for dataset files stored on s3 (#420)
* s3 functionality

* Update amlb/datasets/fileutils.py

Co-authored-by: Pieter Gijsbers <p.gijsbers@tue.nl>

* OOD

* add s3n

* move boto3 import

Co-authored-by: Weisu Yin <weisuyin96@gmail.com>
Co-authored-by: Pieter Gijsbers <p.gijsbers@tue.nl>
3 people authored Dec 7, 2021
1 parent aee0891 commit beef024
Showing 2 changed files with 69 additions and 23 deletions.
7 changes: 4 additions & 3 deletions amlb/datasets/file.py
@@ -15,7 +15,7 @@
 from ..resources import config as rconfig
 from ..utils import Namespace as ns, as_list, lazy_property, list_all_files, memoize, path_from_split, profile, split_path
 
-from .fileutils import download_file, is_archive, is_valid_url, unarchive_file, url_exists
+from .fileutils import is_archive, is_valid_url, unarchive_file, get_file_handler
 
 
 log = logging.getLogger(__name__)
@@ -118,8 +118,9 @@ def _extract_train_test_paths(self, dataset, fold=None):
         elif is_valid_url(dataset):
             cached_file = os.path.join(self._cache_dir, os.path.basename(dataset))
             if not os.path.exists(cached_file):  # don't download if previously done
-                assert url_exists(dataset), f"Invalid path/url: {dataset}"
-                download_file(dataset, cached_file)
+                handler = get_file_handler(dataset)
+                assert handler.exists(dataset), f"Invalid path/url: {dataset}"
+                handler.download(dataset, dest_path=cached_file)
             return self._extract_train_test_paths(cached_file)
         else:
             raise ValueError(f"Invalid dataset description: {dataset}")
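The change above replaces the direct `url_exists`/`download_file` calls with scheme-based dispatch: `get_file_handler` picks a handler from the URL scheme, and that handler performs both the existence check and the download. A minimal sketch of the resulting flow, using a hypothetical dataset URL and cache directory:

```python
import os

from amlb.datasets.fileutils import get_file_handler, is_valid_url

# Hypothetical example values, for illustration only.
dataset = "s3://example-bucket/datasets/train.csv"
cache_dir = "/tmp/amlb_cache"

if is_valid_url(dataset):
    cached_file = os.path.join(cache_dir, os.path.basename(dataset))
    if not os.path.exists(cached_file):  # don't download if previously done
        handler = get_file_handler(dataset)  # S3Handler for the "s3" scheme
        assert handler.exists(dataset), f"Invalid path/url: {dataset}"
        handler.download(dataset, dest_path=cached_file)
```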
85 changes: 65 additions & 20 deletions amlb/datasets/fileutils.py
@@ -2,39 +2,85 @@
 import os
 import shutil
 import tarfile
+import boto3
+from botocore.errorfactory import ClientError
 from urllib.error import URLError
 from urllib.parse import urlparse
-from urllib.request import Request, urlopen, urlretrieve
+from urllib.request import Request, urlopen
 import zipfile
 
 from ..utils import touch
 
 log = logging.getLogger(__name__)
 
-SUPPORTED_SCHEMES = ("http", "https")
 
+class FileHandler:
+    def exists(self, url): pass
+    def download(self, url, dest_path): pass
+
+
+class HttpHandler(FileHandler):
+    def exists(self, url):
+        head_req = Request(url, method='HEAD')
+        try:
+            with urlopen(head_req) as test:
+                return test.status == 200
+        except URLError as e:
+            log.error("Cannot access url %s: %s", url, e)
+            return False
+
+    def download(self, url, dest_path):
+        touch(dest_path)
+        with urlopen(url) as resp, open(dest_path, 'wb') as dest:
+            shutil.copyfileobj(resp, dest)
+
+
+class S3Handler(FileHandler):
+    def exists(self, url):
+        s3 = boto3.client('s3')
+        bucket, key = self._s3_path_to_bucket_prefix(url)
+        try:
+            s3.head_object(Bucket=bucket, Key=key)
+            return True
+        except ClientError as e:
+            log.error("Cannot access url %s: %s", url, e)
+            return False
+
+    def download(self, url, dest_path):
+        touch(dest_path)
+        s3 = boto3.resource('s3')
+        bucket, key = self._s3_path_to_bucket_prefix(url)
+        try:
+            s3.Bucket(bucket).download_file(key, dest_path)
+        except ClientError as e:
+            if e.response['Error']['Code'] == "404":
+                log.error("The object does not exist.")
+            else:
+                raise
+
+    def _s3_path_to_bucket_prefix(self, s3_path):
+        s3_path_cleaned = s3_path.split('://', 1)[1]
+        bucket, prefix = s3_path_cleaned.split('/', 1)
+        return bucket, prefix
 
-def is_valid_url(url):
-    return urlparse(url).scheme in SUPPORTED_SCHEMES
 
+scheme_handlers = dict(
+    http=HttpHandler(),
+    https=HttpHandler(),
+    s3=S3Handler(),
+    s3a=S3Handler(),
+    s3n=S3Handler(),
+)
 
-def url_exists(url):
-    if not is_valid_url(url):
-        return False
-    head_req = Request(url, method='HEAD')
-    try:
-        with urlopen(head_req) as test:
-            return test.status == 200
-    except URLError as e:
-        log.error(f"Cannot access url %s: %s", url, e)
-        return False
+SUPPORTED_SCHEMES = list(scheme_handlers.keys())
 
 
-def download_file(url, dest_path):
-    touch(dest_path)
-    # urlretrieve(url, filename=dest_path)
-    with urlopen(url) as resp, open(dest_path, 'wb') as dest:
-        shutil.copyfileobj(resp, dest)
+def get_file_handler(url):
+    return scheme_handlers[urlparse(url).scheme]
+
+
+def is_valid_url(url):
+    return urlparse(url).scheme in SUPPORTED_SCHEMES
 
 
 def is_archive(path):
@@ -52,4 +98,3 @@ def unarchive_file(path, dest_folder=None):
         with tarfile.open(path) as tf:
             tf.extractall(path=dest_folder)
     return dest
-
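Taken together, `scheme_handlers` turns URL schemes into a lookup table, so supporting a new protocol only means registering another handler instance. A small sketch of how the dispatch and the S3 path split behave (bucket and key names here are made up):

```python
from amlb.datasets.fileutils import S3Handler, get_file_handler, is_valid_url

# http/https resolve to HttpHandler; s3, s3a and s3n all resolve to S3Handler.
for url in ("https://example.com/data.csv", "s3://bucket/key", "s3n://bucket/key"):
    print(url, "->", type(get_file_handler(url)).__name__)

# An s3:// URL splits into (bucket, key) on the first '/' after the scheme.
bucket, key = S3Handler()._s3_path_to_bucket_prefix("s3://example-bucket/path/to/train.csv")
assert (bucket, key) == ("example-bucket", "path/to/train.csv")

# Unsupported schemes fail is_valid_url (get_file_handler would raise KeyError).
assert not is_valid_url("ftp://example.com/data.csv")
```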