Skip to content

Commit

Permalink
fix bugs for huggingface dataset loading; add sample config
Browse files Browse the repository at this point in the history
  • Loading branch information
cyruszhang committed Jan 27, 2025
1 parent acccc01 commit 1823cd6
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 16 deletions.
22 changes: 22 additions & 0 deletions configs/demo/process-huggingface.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Example process config that loads a dataset from the HuggingFace Hub

# global parameters
project_name: 'demo-process'
dataset:
configs:
- type: 'remote'
source: 'huggingface'
path: 'hugfaceguy0001/retarded_bar'
name: 'question'
split: 'train'

np: 4 # number of subprocesses used to process your dataset

export_path: './outputs/demo-process/demo-processed.jsonl'

# process schedule
# a list of several process operators with their arguments
process:
- language_id_score_filter:
lang: 'zh'
min_score: 0.8
11 changes: 7 additions & 4 deletions data_juicer/core/data/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ def __init__(self, cfg: Namespace, executor_type: str = 'default'):
# initialize data loading strategy
data_type = ds_config.get('type', None)
data_source = ds_config.get('source', None)
self.load_strategies.append(
DataLoadStrategyRegistry.get_strategy_class(
self.executor_type, data_type, data_source)(ds_config,
cfg=self.cfg))
stra = DataLoadStrategyRegistry.get_strategy_class(
self.executor_type, data_type, data_source)(ds_config,
cfg=self.cfg)
if stra is None:
raise ValueError(f'No data load strategy found for'
f' {data_type} {data_source}')
self.load_strategies.append(stra)

# initialize the sample numbers
self.max_sample_num = ds_configs.get('max_sample_num', None)
Expand Down
37 changes: 25 additions & 12 deletions data_juicer/core/data/load_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, Optional, Type, Union

import datasets
from loguru import logger

from data_juicer.core.data import DJDataset, RayDataset
from data_juicer.core.data.config_validator import ConfigValidator
Expand Down Expand Up @@ -78,6 +79,11 @@ def get_strategy_class(
1. Exact match
2. Wildcard matches from most specific to most general
"""
logger.info(f'Getting strategy class for '
f'exec: {executor_type}, '
f'data_type: {data_type}, '
f'data_source: {data_source}')

# default to wildcard if not provided
executor_type = executor_type or '*'
data_type = data_type or '*'
Expand Down Expand Up @@ -113,9 +119,12 @@ def specificity_score(key: StrategyKey) -> int:
if part == '*')

matching_strategies.sort(key=lambda x: specificity_score(x[0]))
return matching_strategies[0][1]
found = matching_strategies[0][1]
logger.info(f'Found matching strategies: {found}')
return found

# No matching strategy found
logger.warning('No matching strategy found')
return None

@classmethod
Expand Down Expand Up @@ -247,7 +256,8 @@ class DefaultHuggingfaceDataLoadStrategy(DefaultDataLoadStrategy):

CONFIG_VALIDATION_RULES = {
'required_fields': ['path'],
'optional_fields': ['split', 'limit', 'name'],
'optional_fields':
['split', 'limit', 'name', 'data_files', 'data_dir'],
'field_types': {
'path': str
},
Expand All @@ -256,16 +266,19 @@ class DefaultHuggingfaceDataLoadStrategy(DefaultDataLoadStrategy):

def load_data(self, **kwargs):
num_proc = kwargs.pop('num_proc', 1)
ds = datasets.load_dataset(self.ds_config['path'],
split=self.ds_config.split,
name=self.ds_config.name,
limit=self.ds_config.limit,
num_proc=num_proc,
**kwargs)
ds = unify_format(ds,
text_keys=self.text_keys,
num_proc=num_proc,
global_cfg=self.cfg)
ds = datasets.load_dataset(
self.ds_config['path'],
split=self.ds_config.get('split', None),
data_files=self.ds_config.get('data_files', None),
data_dir=self.ds_config.get('data_dir', None),
name=self.ds_config.get('name', None),
limit=self.ds_config.get('limit', None),
num_proc=num_proc,
**kwargs)
return unify_format(ds,
text_keys=self.cfg.text_keys,
num_proc=num_proc,
global_cfg=self.cfg)


@DataLoadStrategyRegistry.register('default', 'remote', 'modelscope')
Expand Down

0 comments on commit 1823cd6

Please sign in to comment.