Skip to content

Commit

Permalink
Fix convention for src/melt/tools/data/loader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
minhtrung23 committed Sep 7, 2024
1 parent c6f8769 commit 22519f4
Showing 1 changed file with 123 additions and 83 deletions.
206 changes: 123 additions & 83 deletions src/melt/tools/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,130 @@
"""Module for loading datasets from various sources."""

import os
from pathlib import Path
from datasets import load_dataset
from transformers.utils.versions import require_version
from ..utils.constants import FILEEXT2TYPE
from typing import Tuple, Any

# Third-party imports
try:
from transformers.utils.versions import require_version
except ImportError:
require_version = None

def load_a_dataset(dataset_attr, args):
dataset_training, _ = _load_single_dataset(
dataset_attr, args, dataset_attr.train_split
)
dataset_testing, _ = _load_single_dataset(
dataset_attr, args, dataset_attr.test_split
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
except ImportError:
MsDataset = None
MS_DATASETS_CACHE = None

try:
from datasets import load_dataset
except ImportError:
load_dataset = None

# First-party imports
try:
from melt.utils.constants import FILEEXT2TYPE
except ImportError:
FILEEXT2TYPE = {}

def _load_single_dataset(dataset_attr, args, mode) -> Tuple[Any, Any]:
    """Dispatch loading of one dataset split to the matching backend loader.

    Args:
        dataset_attr: Dataset descriptor; its ``load_from`` field selects the
            backend ("hf_hub", "ms_hub" or "file").
        args: Configuration namespace forwarded untouched to the backend.
        mode: The split to load (e.g. 'train', 'test').

    Returns:
        A tuple of (loaded dataset, dataset_attr).

    Raises:
        NotImplementedError: If ``dataset_attr.load_from`` is not a known
            backend identifier.
        ImportError: Propagated from a backend whose optional dependency is
            missing.
    """
    print(f"Loading {mode} dataset {dataset_attr}...")

    source = dataset_attr.load_from
    if source == "hf_hub":
        loader = _load_from_hf_hub
    elif source == "ms_hub":
        loader = _load_from_ms_hub
    elif source == "file":
        loader = _load_from_file
    else:
        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

    return loader(dataset_attr, args, mode)

def _load_from_hf_hub(dataset_attr, args, mode):
    """Load one split of a dataset from the Hugging Face Hub.

    Args:
        dataset_attr: Dataset descriptor providing ``dataset_name``,
            ``subset`` and ``folder``.
        args: Configuration namespace; only ``hf_hub_token`` is read here.
        mode: The split to load (e.g. 'train', 'test').

    Returns:
        A tuple of (loaded dataset, dataset_attr).

    Raises:
        ImportError: If the 'datasets' library could not be imported at
            module load time.
    """
    if load_dataset is None:
        raise ImportError("The 'datasets' library is not installed.")

    hub_kwargs = {
        "path": dataset_attr.dataset_name,
        "name": dataset_attr.subset,
        "data_dir": dataset_attr.folder,
        "split": mode,
        "token": args.hf_hub_token,
        "trust_remote_code": True,
    }
    return load_dataset(**hub_kwargs), dataset_attr

def _load_from_ms_hub(dataset_attr, args, mode):
    """Load one split of a dataset from the ModelScope hub.

    Args:
        dataset_attr: Dataset descriptor providing ``dataset_name``,
            ``subset`` and ``folder``.
        args: Configuration namespace; only ``ms_hub_token`` is read here.
        mode: The split to load (e.g. 'train', 'test').

    Raises:
        ImportError: If the optional ModelScope packages or the
            'transformers' library are unavailable.
    """
    # ModelScope is an optional dependency: the module-level try/except
    # leaves these names as None when the import failed.
    if MsDataset is None or MS_DATASETS_CACHE is None:
        raise ImportError("ModelScope packages are not installed or not available.")

    # require_version comes from transformers, also imported defensively.
    if require_version is None:
        raise ImportError("The 'transformers' library is not installed.")

    require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")

    dataset = MsDataset.load(
        dataset_name=dataset_attr.dataset_name,
        subset_name=dataset_attr.subset,
        data_dir=dataset_attr.folder,
        split=mode,
        cache_dir=MS_DATASETS_CACHE,
        token=args.ms_hub_token,
    )
    # NOTE(review): this diff view cuts the function off here — the return
    # statement (presumably converting an MsDataset to an HF dataset and
    # returning (dataset, dataset_attr), as the sibling loaders do) is not
    # visible in this span. Confirm against the full file.
return dataset_training, dataset_testing


def _load_single_dataset(dataset_attr, args, mode):
    """Load one split of a dataset from the HF hub, ModelScope hub, or local files.

    NOTE(review): this is the pre-refactor version shown in the diff; the
    commit replaces it with the dispatcher + per-source helpers above. In the
    rendered page both definitions coexist and this later one would shadow
    the earlier.

    Args:
        dataset_attr: Dataset descriptor; ``load_from`` selects the source.
        args: Configuration namespace (``dataset_dir``, hub tokens).
        mode: The split to load (e.g. 'train', 'test').

    Returns:
        A tuple of (loaded dataset, dataset_attr).

    Raises:
        ValueError: For missing/ambiguous local files or unsupported types.
        NotImplementedError: For an unknown ``load_from`` value.
    """
    print("Loading {} dataset {}...".format(mode, dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
        # Hub sources: path/name/dir come straight from the descriptor.
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "file":
        data_files = {}
        local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name)

        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                # A file belongs to this split when its stem ends in "_<mode>".
                if Path(file_name).stem.split("_")[-1] == mode:
                    data_files[mode] = os.path.join(local_path, file_name)
                    if data_path is None:
                        # Map the file extension to a datasets builder type.
                        data_path = FILEEXT2TYPE.get(
                            file_name.split(".")[-1], None
                        )
                    elif data_path != FILEEXT2TYPE.get(
                        file_name.split(".")[-1], None
                    ):
                        raise ValueError("File types should be identical.")

            if len(data_files) < 1:
                # NOTE(review): "approriate" is a typo in the original message;
                # left untouched here because it is runtime-visible text.
                raise ValueError("File name is not approriate.")
        # elif os.path.isfile(local_path):  # is file
        #     data_files.append(local_path)
        #     data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
        else:
            raise ValueError("File {} not found.".format(local_path))

        if data_path is None:
            raise ValueError(
                "Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))
            )
    else:
        raise NotImplementedError(
            "Unknown load type: {}.".format(dataset_attr.load_from)
        )

    if dataset_attr.load_from == "ms_hub":
        require_version(
            "modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0"
        )
        # Imported lazily so ModelScope is only required on this path.
        from modelscope import MsDataset
        from modelscope.utils.config_ds import MS_DATASETS_CACHE

        cache_dir = MS_DATASETS_CACHE
        dataset = MsDataset.load(
            dataset_name=data_path,
            subset_name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=mode,
            cache_dir=cache_dir,
            token=args.ms_hub_token,
        )
        if isinstance(dataset, MsDataset):
            dataset = dataset.to_hf_dataset()
    else:
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=mode,
            token=args.hf_hub_token,
            trust_remote_code=True,
        )

    # NOTE(review): redundant with the conversion inside the ms_hub branch;
    # on the non-ms_hub path `MsDataset` is only bound if a module-level
    # import exists — otherwise this line would raise NameError. Verify
    # against the full file's imports.
    if isinstance(dataset, MsDataset):
        dataset = dataset.to_hf_dataset()

    return dataset, dataset_attr

def _load_from_file(dataset_attr, args, mode):
local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name)
if not os.path.isdir(local_path):
raise ValueError(f"Directory {local_path} not found.")

data_files = {}
data_path = None

for file_name in os.listdir(local_path):
if Path(file_name).stem.split("_")[-1] == mode:
data_files[mode] = os.path.join(local_path, file_name)
file_ext = file_name.split(".")[-1]
current_data_path = FILEEXT2TYPE.get(file_ext)

if data_path is None:
data_path = current_data_path
elif data_path != current_data_path:
raise ValueError("File types should be identical.")

if not data_files:
raise ValueError("No appropriate file found.")

if data_path is None:
raise ValueError(f"Allowed file types: {', '.join(FILEEXT2TYPE.keys())}.")

if load_dataset is None:
raise ImportError("The 'datasets' library is not installed.")

return load_dataset(
path=data_path,
data_files=data_files,
split=mode,
token=args.hf_hub_token,
trust_remote_code=True,
), dataset_attr

0 comments on commit 22519f4

Please sign in to comment.