diff --git a/minari/namespace.py b/minari/namespace.py
index 3de8d593..851b6972 100644
--- a/minari/namespace.py
+++ b/minari/namespace.py
@@ -3,6 +3,7 @@
 import os
 import re
 import warnings
+from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional
 
 from minari.storage import get_dataset_path
@@ -147,8 +148,8 @@ def list_local_namespaces() -> List[str]:
     datasets_path = get_dataset_path()
     namespaces = []
 
-    def recurse_directories(base_path, namespace):
-        parent_dir = os.path.join(base_path, namespace)
+    def recurse_directories(base_path: Path, namespace):
+        parent_dir = base_path.joinpath(namespace)
         for dir_name in list_non_hidden_dirs(parent_dir):
             dir_path = os.path.join(parent_dir, dir_name)
             namespaced_dir_name = os.path.join(namespace, dir_name)
diff --git a/minari/storage/hosting.py b/minari/storage/hosting.py
index 4d794efb..af04faec 100644
--- a/minari/storage/hosting.py
+++ b/minari/storage/hosting.py
@@ -5,7 +5,7 @@
 import warnings
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from minari.dataset.minari_dataset import gen_dataset_id, parse_dataset_id
 from minari.dataset.minari_storage import MinariStorage
@@ -180,6 +180,7 @@ def download_dataset(dataset_id: str, force_download: bool = False):
 def list_remote_datasets(
     latest_version: bool = False,
     compatible_minari_version: bool = False,
+    prefix: Optional[str] = None,
 ) -> Dict[str, Dict[str, str]]:
     """Get the names and metadata of all the Minari datasets in the remote Farama server.
 
@@ -200,7 +201,7 @@ def download_metadata(dataset_id):
         if supported_dataset or not compatible_minari_version:
             return metadata
 
-    dataset_ids = cloud_storage.list_datasets()
+    dataset_ids = cloud_storage.list_datasets(prefix=prefix)
     with ThreadPoolExecutor(max_workers=10) as executor:
         remote_metadatas = executor.map(download_metadata, dataset_ids)
 
diff --git a/minari/storage/local.py b/minari/storage/local.py
index 5045db2e..e2fcc0b9 100644
--- a/minari/storage/local.py
+++ b/minari/storage/local.py
@@ -3,7 +3,7 @@
 import pathlib
 import shutil
 import warnings
-from typing import Dict, Iterable, Tuple, Union
+from typing import Dict, Iterable, Optional, Tuple, Union
 
 from minari.dataset.minari_dataset import (
     MinariDataset,
@@ -19,9 +19,9 @@
 __version__ = importlib.metadata.version("minari")
 
 
-def list_non_hidden_dirs(path: str) -> Iterable[str]:
+def list_non_hidden_dirs(path: pathlib.Path) -> Iterable[str]:
     """List all non-hidden subdirectories."""
-    for d in os.scandir(path):
+    for d in path.iterdir():
         if d.is_dir() and (not d.name.startswith(".")):
             yield d.name
 
@@ -60,6 +60,7 @@ def load_dataset(dataset_id: str, download: bool = False):
 def list_local_datasets(
     latest_version: bool = False,
     compatible_minari_version: bool = False,
+    prefix: Optional[str] = None,
 ) -> Dict[str, Dict[str, Union[str, int, bool]]]:
     """Get the ids and metadata of all the Minari datasets in the local database.
 
@@ -75,8 +76,11 @@ def list_local_datasets(
     datasets_path = get_dataset_path()
     dataset_ids = []
 
-    def recurse_directories(base_path, namespace):
-        parent_dir = os.path.join(base_path, namespace)
+    def recurse_directories(base_path: pathlib.Path, namespace):
+        parent_dir = base_path.joinpath(namespace)
+        if not parent_dir.exists():
+            return
+
         for dir_name in list_non_hidden_dirs(parent_dir):
             dir_path = os.path.join(parent_dir, dir_name)
             namespaced_dir_name = os.path.join(namespace, dir_name)
@@ -86,7 +90,7 @@ def recurse_directories(base_path, namespace):
             else:
                 recurse_directories(base_path, namespaced_dir_name)
 
-    recurse_directories(datasets_path, "")
+    recurse_directories(datasets_path, prefix or "")
 
     dataset_ids = sorted(dataset_ids, key=dataset_id_sort_key)
 
diff --git a/minari/storage/remotes/huggingface.py b/minari/storage/remotes/huggingface.py
index 946592a4..dc7abd35 100644
--- a/minari/storage/remotes/huggingface.py
+++ b/minari/storage/remotes/huggingface.py
@@ -105,31 +105,38 @@ def upload_namespace(self, namespace: str) -> None:
         )
 
     def list_datasets(self, prefix: Optional[str] = None) -> Iterable[str]:
-        if prefix is not None:  # TODO: support prefix
-            raise NotImplementedError("prefix is not supported yet")
-
-        for hf_dataset in self._api.list_datasets(author=self.name):
+        if prefix is not None:
+            group_name, _ = self._decompose_path(prefix)
+        else:
+            prefix = ""
+            group_name = None
+
+        hf_datasets = self._api.list_datasets(author=self.name, dataset_name=group_name)
+        for group_info in hf_datasets:
             try:
                 repo_metadata = self._api.hf_hub_download(
-                    repo_id=hf_dataset.id,
+                    repo_id=group_info.id,
                     filename=_NAMESPACE_METADATA_FILENAME,
                     repo_type="dataset",
                 )
             except EntryNotFoundError:
                 try:
                     self._api.hf_hub_download(
-                        repo_id=hf_dataset.id,
+                        repo_id=group_info.id,
                         filename=f"data/{METADATA_FILE_NAME}",
                         repo_type="dataset",
                     )
-                    yield hf_dataset.id
+                    if group_info.id.startswith(prefix):
+                        yield group_info.id
                 except Exception:
-                    warnings.warn(f"Skipping {hf_dataset.id} as it is malformed.")
+                    warnings.warn(f"Skipping {group_info.id} as it is malformed.")
             else:
                 with open(repo_metadata) as f:
                     namespace_metadata = json.load(f)
-                yield from namespace_metadata.get("datasets", [])
+                group_datasets = namespace_metadata.get("datasets", [])
+                group_datasets = filter(lambda x: x.startswith(prefix), group_datasets)
+                yield from group_datasets
 
     def download_dataset(self, dataset_id: Any, path: Path) -> None:
         repo_id, path_in_repo = self._decompose_path(dataset_id)
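
For reviewers, a minimal usage sketch of the new prefix argument (not part of the diff). It assumes the existing top-level re-exports of list_local_datasets and list_remote_datasets, and "D4RL" is used purely as an illustrative namespace:

    import minari

    # Restrict listing to ids under a namespace instead of scanning everything.
    local_ids = minari.list_local_datasets(prefix="D4RL")
    remote_ids = minari.list_remote_datasets(prefix="D4RL", latest_version=True)

    print(sorted(local_ids))
    print(sorted(remote_ids))

Per the changes above, the local prefix is resolved as a subdirectory of the dataset root (a missing namespace now yields an empty result instead of raising), while the remote prefix is forwarded to cloud_storage.list_datasets(prefix=...) and matched against dataset ids with str.startswith.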