Accept datasets stored in s3 #420

Merged · 5 commits · Dec 7, 2021

Changes from 3 commits
7 changes: 4 additions & 3 deletions amlb/datasets/file.py
@@ -15,7 +15,7 @@
 from ..resources import config as rconfig
 from ..utils import Namespace as ns, as_list, lazy_property, list_all_files, memoize, path_from_split, profile, split_path

-from .fileutils import download_file, is_archive, is_valid_url, unarchive_file, url_exists
+from .fileutils import is_archive, is_valid_url, unarchive_file, get_file_handler


 log = logging.getLogger(__name__)
@@ -118,8 +118,9 @@ def _extract_train_test_paths(self, dataset, fold=None):
         elif is_valid_url(dataset):
             cached_file = os.path.join(self._cache_dir, os.path.basename(dataset))
             if not os.path.exists(cached_file):  # don't download if previously done
-                assert url_exists(dataset), f"Invalid path/url: {dataset}"
-                download_file(dataset, cached_file)
+                handler = get_file_handler(dataset)
+                assert handler.exists(dataset), f"Invalid path/url: {dataset}"
+                handler.download(dataset, dest_path=cached_file)
             return self._extract_train_test_paths(cached_file)
         else:
             raise ValueError(f"Invalid dataset description: {dataset}")
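The call site now resolves a scheme-specific handler instead of hard-coding the HTTP helpers, so the caching logic is identical for http, https, s3, and s3a URLs. A minimal sketch of the same pattern in isolation (the fetch helper, cache path, and URL are invented for illustration and are not part of the PR):

    import os

    from amlb.datasets.fileutils import get_file_handler, is_valid_url


    def fetch(url, cache_dir="/tmp/amlb_cache"):   # hypothetical helper, not in the PR
        assert is_valid_url(url), f"Unsupported scheme: {url}"
        cached = os.path.join(cache_dir, os.path.basename(url))
        if not os.path.exists(cached):             # skip the download if already cached
            handler = get_file_handler(url)        # HttpHandler or S3Handler, chosen by scheme
            assert handler.exists(url), f"Invalid path/url: {url}"
            handler.download(url, dest_path=cached)
        return cached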
86 changes: 66 additions & 20 deletions amlb/datasets/fileutils.py
@@ -4,37 +4,84 @@
 import tarfile
 from urllib.error import URLError
 from urllib.parse import urlparse
-from urllib.request import Request, urlopen, urlretrieve
+from urllib.request import Request, urlopen
 import zipfile

 from ..utils import touch

 log = logging.getLogger(__name__)

-SUPPORTED_SCHEMES = ("http", "https")

+class FileHandler:
+    def exists(self, url): pass
+    def download(self, url, dest_path): pass
+
+
+class HttpHandler(FileHandler):
+    def exists(self, url):
+        head_req = Request(url, method='HEAD')
+        try:
+            with urlopen(head_req) as test:
+                return test.status == 200
+        except URLError as e:
+            log.error("Cannot access url %s: %s", url, e)
+            return False
+
+    def download(self, url, dest_path):
+        touch(dest_path)
+        with urlopen(url) as resp, open(dest_path, 'wb') as dest:
+            shutil.copyfileobj(resp, dest)
+
+
+class S3Handler(FileHandler):
+    def exists(self, url):
+        import boto3
+        from botocore.errorfactory import ClientError
+        s3 = boto3.client('s3')
+        bucket, key = self._s3_path_to_bucket_prefix(url)
+        try:
+            s3.head_object(Bucket=bucket, Key=key)
+            return True
+        except ClientError as e:
+            log.error("Cannot access url %s: %s", url, e)
+            return False
+
+    def download(self, url, dest_path):
+        import boto3
+        from botocore.errorfactory import ClientError
+        touch(dest_path)
+        s3 = boto3.resource('s3')
+        bucket, key = self._s3_path_to_bucket_prefix(url)
+        try:
+            s3.Bucket(bucket).download_file(key, dest_path)
+        except ClientError as e:
+            if e.response['Error']['Code'] == "404":
+                log.error("The object does not exist.")
+            else:
+                raise
+
+    def _s3_path_to_bucket_prefix(self, s3_path):
+        s3_path_cleaned = s3_path.split('://', 1)[1]
+        bucket, prefix = s3_path_cleaned.split('/', 1)
+        return bucket, prefix

-def is_valid_url(url):
-    return urlparse(url).scheme in SUPPORTED_SCHEMES

+scheme_handlers = dict(
+    http=HttpHandler(),
+    https=HttpHandler(),
+    s3=S3Handler(),
+    s3a=S3Handler()
+)

-def url_exists(url):
-    if not is_valid_url(url):
-        return False
-    head_req = Request(url, method='HEAD')
-    try:
-        with urlopen(head_req) as test:
-            return test.status == 200
-    except URLError as e:
-        log.error(f"Cannot access url %s: %s", url, e)
-        return False

+SUPPORTED_SCHEMES = list(scheme_handlers.keys())


-def download_file(url, dest_path):
-    touch(dest_path)
-    # urlretrieve(url, filename=dest_path)
-    with urlopen(url) as resp, open(dest_path, 'wb') as dest:
-        shutil.copyfileobj(resp, dest)
+def get_file_handler(url):
+    return scheme_handlers[urlparse(url).scheme]
+
+
+def is_valid_url(url):
+    return urlparse(url).scheme in SUPPORTED_SCHEMES


 def is_archive(path):
@@ -52,4 +99,3 @@ def unarchive_file(path, dest_folder=None):
         with tarfile.open(path) as tf:
             tf.extractall(path=dest_folder)
     return dest
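With the handler table in place, an s3:// URL can be used anywhere a dataset path was previously accepted. A quick sketch of the dispatch and the bucket/key split (the bucket, key, and local path below are invented; the exists and download calls assume boto3 is installed and AWS credentials are available):

    from amlb.datasets.fileutils import get_file_handler, is_valid_url

    url = "s3://example-bucket/datasets/train.csv"        # hypothetical object
    assert is_valid_url(url)                              # 's3' is now a supported scheme
    handler = get_file_handler(url)                       # -> S3Handler instance
    bucket, key = handler._s3_path_to_bucket_prefix(url)  # ('example-bucket', 'datasets/train.csv')
    if handler.exists(url):                               # HEAD-style check via s3.head_object
        handler.download(url, dest_path="/tmp/train.csv")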