forked from stair-lab/melt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix convention for src/melt/tools/data/loader.py
- Loading branch information
1 parent
c6f8769
commit 22519f4
Showing
1 changed file
with
123 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,130 @@ | ||
"""Module for loading datasets from various sources.""" | ||
|
||
import os | ||
from pathlib import Path | ||
from datasets import load_dataset | ||
from transformers.utils.versions import require_version | ||
from ..utils.constants import FILEEXT2TYPE | ||
from typing import Tuple, Any | ||
|
||
# Third-party imports | ||
try: | ||
from transformers.utils.versions import require_version | ||
except ImportError: | ||
require_version = None | ||
|
||
def load_a_dataset(dataset_attr, args):
    """Load both the train and test splits of a dataset.

    Args:
        dataset_attr: Attributes of the dataset to load (must provide
            ``train_split`` and ``test_split`` names).
        args: Arguments containing configuration options.

    Returns:
        A tuple ``(train_dataset, test_dataset)``.
    """
    dataset_training, _ = _load_single_dataset(
        dataset_attr, args, dataset_attr.train_split
    )
    dataset_testing, _ = _load_single_dataset(
        dataset_attr, args, dataset_attr.test_split
    )
    # The visible block was truncated in the diff; restore the return so the
    # function yields both splits instead of falling off the end (None).
    return dataset_training, dataset_testing
try: | ||
from modelscope import MsDataset | ||
from modelscope.utils.config_ds import MS_DATASETS_CACHE | ||
except ImportError: | ||
MsDataset = None | ||
MS_DATASETS_CACHE = None | ||
|
||
try: | ||
from datasets import load_dataset | ||
except ImportError: | ||
load_dataset = None | ||
|
||
# First-party imports | ||
try: | ||
from melt.utils.constants import FILEEXT2TYPE | ||
except ImportError: | ||
FILEEXT2TYPE = {} | ||
|
||
def _load_single_dataset(dataset_attr, args, mode) -> Tuple[Any, Any]:
    """Dispatch loading of one dataset split to the matching backend loader.

    Args:
        dataset_attr: Attributes of the dataset to load.
        args: Arguments containing configuration options.
        mode: The split to load (e.g. 'train', 'test').

    Returns:
        A tuple containing the loaded dataset and its attributes.

    Raises:
        NotImplementedError: If ``dataset_attr.load_from`` names no known source.
        ImportError: Propagated from the backend loader when its library is missing.
    """
    print(f"Loading {mode} dataset {dataset_attr}...")

    dispatch = {
        "hf_hub": _load_from_hf_hub,
        "ms_hub": _load_from_ms_hub,
        "file": _load_from_file,
    }

    loader = dispatch.get(dataset_attr.load_from)
    if loader is None:
        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

    return loader(dataset_attr, args, mode)
|
||
def _load_from_hf_hub(dataset_attr, args, mode):
    """Load one dataset split from the Hugging Face Hub.

    Raises:
        ImportError: If the 'datasets' library could not be imported.
    """
    if load_dataset is None:
        raise ImportError("The 'datasets' library is not installed.")
    dataset = load_dataset(
        path=dataset_attr.dataset_name,
        name=dataset_attr.subset,
        data_dir=dataset_attr.folder,
        split=mode,
        token=args.hf_hub_token,
        trust_remote_code=True,
    )
    return dataset, dataset_attr
|
||
def _load_from_ms_hub(dataset_attr, args, mode):
    """Load one dataset split from the ModelScope hub.

    Args:
        dataset_attr: Attributes of the dataset to load.
        args: Arguments containing configuration options (``ms_hub_token``).
        mode: The split to load (e.g. 'train', 'test').

    Returns:
        A tuple containing the loaded dataset and its attributes.

    Raises:
        ImportError: If modelscope or transformers is not available.
    """
    if MsDataset is None or MS_DATASETS_CACHE is None:
        raise ImportError("ModelScope packages are not installed or not available.")

    if require_version is None:
        raise ImportError("The 'transformers' library is not installed.")

    require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")

    dataset = MsDataset.load(
        dataset_name=dataset_attr.dataset_name,
        subset_name=dataset_attr.subset,
        data_dir=dataset_attr.folder,
        split=mode,
        cache_dir=MS_DATASETS_CACHE,
        token=args.ms_hub_token,
    )
    # Unwrap to a plain HF dataset so callers get the same type from every
    # backend. The visible block instead returned dataset_training /
    # dataset_testing, neither of which exists here (NameError).
    if isinstance(dataset, MsDataset):
        dataset = dataset.to_hf_dataset()
    return dataset, dataset_attr
|
||
|
||
def _load_single_dataset(dataset_attr, args, mode): | ||
print("Loading {} dataset {}...".format(mode, dataset_attr)) | ||
data_path, data_name, data_dir, data_files = None, None, None, None | ||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: | ||
data_path = dataset_attr.dataset_name | ||
data_name = dataset_attr.subset | ||
data_dir = dataset_attr.folder | ||
|
||
elif dataset_attr.load_from == "file": | ||
data_files = {} | ||
local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name) | ||
|
||
if os.path.isdir(local_path): # is directory | ||
for file_name in os.listdir(local_path): | ||
if Path(file_name).stem.split("_")[-1] == mode: | ||
data_files[mode] = os.path.join(local_path, file_name) | ||
if data_path is None: | ||
data_path = FILEEXT2TYPE.get( | ||
file_name.split(".")[-1], None | ||
) | ||
elif data_path != FILEEXT2TYPE.get( | ||
file_name.split(".")[-1], None | ||
): | ||
raise ValueError("File types should be identical.") | ||
|
||
if len(data_files) < 1: | ||
raise ValueError("File name is not approriate.") | ||
# elif os.path.isfile(local_path): # is file | ||
# data_files.append(local_path) | ||
# data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) | ||
else: | ||
raise ValueError("File {} not found.".format(local_path)) | ||
|
||
if data_path is None: | ||
raise ValueError( | ||
"Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())) | ||
) | ||
else: | ||
raise NotImplementedError( | ||
"Unknown load type: {}.".format(dataset_attr.load_from) | ||
) | ||
|
||
if dataset_attr.load_from == "ms_hub": | ||
require_version( | ||
"modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0" | ||
) | ||
from modelscope import MsDataset | ||
from modelscope.utils.config_ds import MS_DATASETS_CACHE | ||
|
||
cache_dir = MS_DATASETS_CACHE | ||
dataset = MsDataset.load( | ||
dataset_name=data_path, | ||
subset_name=data_name, | ||
data_dir=data_dir, | ||
data_files=data_files, | ||
split=mode, | ||
cache_dir=cache_dir, | ||
token=args.ms_hub_token, | ||
) | ||
if isinstance(dataset, MsDataset): | ||
dataset = dataset.to_hf_dataset() | ||
else: | ||
dataset = load_dataset( | ||
path=data_path, | ||
name=data_name, | ||
data_dir=data_dir, | ||
data_files=data_files, | ||
split=mode, | ||
token=args.hf_hub_token, | ||
trust_remote_code=True, | ||
) | ||
|
||
if isinstance(dataset, MsDataset): | ||
dataset = dataset.to_hf_dataset() | ||
|
||
return dataset, dataset_attr | ||
|
||
def _load_from_file(dataset_attr, args, mode): | ||
local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name) | ||
if not os.path.isdir(local_path): | ||
raise ValueError(f"Directory {local_path} not found.") | ||
|
||
data_files = {} | ||
data_path = None | ||
|
||
for file_name in os.listdir(local_path): | ||
if Path(file_name).stem.split("_")[-1] == mode: | ||
data_files[mode] = os.path.join(local_path, file_name) | ||
file_ext = file_name.split(".")[-1] | ||
current_data_path = FILEEXT2TYPE.get(file_ext) | ||
|
||
if data_path is None: | ||
data_path = current_data_path | ||
elif data_path != current_data_path: | ||
raise ValueError("File types should be identical.") | ||
|
||
if not data_files: | ||
raise ValueError("No appropriate file found.") | ||
|
||
if data_path is None: | ||
raise ValueError(f"Allowed file types: {', '.join(FILEEXT2TYPE.keys())}.") | ||
|
||
if load_dataset is None: | ||
raise ImportError("The 'datasets' library is not installed.") | ||
|
||
return load_dataset( | ||
path=data_path, | ||
data_files=data_files, | ||
split=mode, | ||
token=args.hf_hub_token, | ||
trust_remote_code=True, | ||
), dataset_attr |