Task config (#289)
* add new params to config class

* clean up task/config

* connect dataset revision and filter

* add tests for filtering/revision

* nit

* nit+1

* remove redundant check

---------

Co-authored-by: Hynek Kydlicek <kydliceh.hynek@gmail.com>
hynky1999 and Hynek Kydlicek authored Sep 13, 2024
1 parent 5034a96 commit 919be47
Showing 3 changed files with 153 additions and 62 deletions.
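In short, this commit lets a task pin its dataset to an exact Hub revision and pre-filter its rows before evaluation. A minimal sketch of the new options, mirroring the tests added below (the task name, repo, revision hash, and filter predicate here are illustrative placeholders):

from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig

cfg = LightevalTaskConfig(
    name="my_task",                                 # hypothetical task name
    prompt_function=lambda item, task_name: item["text"],
    hf_repo="my-org/my-dataset",                    # hypothetical dataset repo
    hf_subset="default",
    metric=[],
    evaluation_splits=["train"],
    hf_revision="abc1234",                          # pin an exact dataset commit
    hf_filter=lambda row: row["text"] == "hi",      # keep only matching rows
)
task = LightevalTask("my_task", cfg)
docs = task.eval_docs()  # loads the pinned revision, then applies the filter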
122 changes: 66 additions & 56 deletions src/lighteval/tasks/lighteval_task.py
@@ -23,11 +23,11 @@
import collections
import inspect
import random
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from multiprocessing import Pool
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

from datasets import DatasetDict
from huggingface_hub import TextGenerationInputGrammarType
from pytablewriter import MarkdownTableWriter

@@ -54,7 +54,7 @@
RequestType,
SampleUid,
)
from lighteval.utils.utils import as_list, download_dataset_worker
from lighteval.utils.utils import ListLike, as_list, download_dataset_worker


if TYPE_CHECKING:
@@ -82,55 +82,58 @@ class LightevalTaskConfig:
original_num_docs (int): Number of documents in the task
effective_num_docs (int): Number of documents used in a specific evaluation
truncated_num_docs (bool): Whether less than the total number of documents were used
output_regex (str)
frozen (bool)
trust_dataset (bool): Whether to trust the dataset at execution or not
version (int): The version of the task. Defaults to 0. Can be increased if the underlying dataset or the prompt changes.
output_regex (str)
frozen (bool)
"""

name: str
prompt_function: Callable # [[dict, str], Doc]
prompt_function: Callable[[dict, str], Doc]
hf_repo: str
hf_subset: str
metric: Tuple[Union[Metric, Metrics]]
hf_avail_splits: Optional[Tuple[str]] = None
evaluation_splits: Optional[Tuple[str]] = None
metric: ListLike[Metric | Metrics]

# Additional hf dataset config
hf_revision: Optional[str] = None
hf_filter: Optional[Callable[[dict], bool]] = None
hf_avail_splits: Optional[ListLike[str]] = field(default_factory=lambda: ["train", "validation", "test"])
# We default to false, to reduce security issues
trust_dataset: bool = False

# Splits
evaluation_splits: ListLike[str] = field(default_factory=lambda: ["validation"])
few_shots_split: Optional[str] = None
few_shots_select: Optional[str] = None

# Generation args
generation_size: Optional[int] = None
generation_grammar: Optional[TextGenerationInputGrammarType] = None
stop_sequence: Optional[Tuple[str]] = None
stop_sequence: Optional[ListLike[str]] = None
output_regex: Optional[str] = None
num_samples: Optional[list[int]] = None

frozen: bool = False
suite: Optional[Tuple[str]] = None
suite: ListLike[str] = field(default_factory=lambda: ["custom"])

original_num_docs: int = -1
effective_num_docs: int = -1

trust_dataset: bool = None

must_remove_duplicate_docs: bool = None
must_remove_duplicate_docs: bool = False

version: int = 0

def __post_init__(self):
if self.suite is None:
self.suite = ["custom"]
if self.hf_avail_splits is None:
self.hf_avail_splits = ["train", "validation", "test"]
if self.evaluation_splits is None:
self.evaluation_splits = ["validation"]
# Currently unused
frozen: bool = False

def __post_init__(self):
# If we got a Metrics enums instead of a Metric, we convert
self.metric = [metric.value if isinstance(metric, Metrics) else metric for metric in self.metric]

# Convert list to tuple for hashing
self.metric = tuple(self.metric)
self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None
self.suite = tuple(self.suite) if self.suite is not None else None
self.evaluation_splits = tuple(self.evaluation_splits)
self.suite = tuple(self.suite)
self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None

def print(self):
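Mutable defaults such as ["validation"] now go through dataclasses.field rather than None-checks in __post_init__. A generic illustration of why (plain Python, not lighteval code):

from dataclasses import dataclass, field

@dataclass
class Cfg:
    # splits: list[str] = ["validation"] would raise ValueError (mutable default)
    splits: list[str] = field(default_factory=lambda: ["validation"])

a, b = Cfg(), Cfg()
a.splits.append("train")
assert b.splits == ["validation"]  # each instance gets its own list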
Expand Down Expand Up @@ -175,31 +178,27 @@ def __init__( # noqa: C901
"""
self.name = name
self.version = cfg.version
self.is_main_process = False
self.cache_dir = cache_dir
self._cfg = cfg

# Dataset info
self.hf_repo = cfg.hf_repo
self.hf_subset = cfg.hf_subset
self.dataset_path = self.hf_repo
self.dataset_config_name = self.hf_subset
self.dataset = None # Delayed download
self.dataset_path = cfg.hf_repo
self.dataset_config_name = cfg.hf_subset
self.dataset_revision = cfg.hf_revision
self.dataset_filter = cfg.hf_filter
self.trust_dataset = cfg.trust_dataset
self.dataset: Optional[DatasetDict] = None # Delayed download
hlog(f"{self.dataset_path} {self.dataset_config_name}")
self._fewshot_docs = None
self._docs = None

# Managing splits and few shot
self.all_available_splits = as_list(cfg.hf_avail_splits)
if cfg.evaluation_splits is None:
raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

self.evaluation_split = as_list(cfg.evaluation_splits)

self.fewshot_split: list[str] | None
if cfg.few_shots_split is not None:
self.fewshot_split = as_list(cfg.few_shots_split)
else:
self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
self.fewshot_split = self.get_first_possible_fewshot_splits(cfg.hf_avail_splits or [])
self.fewshot_selection = cfg.few_shots_select

# Metrics
@@ -223,30 +222,20 @@ def __init__( # noqa: C901
if "maj@" in metric_name:
self.num_samples.append(int(metric_name.replace("maj@", "").split("_")[0]))

if not isinstance(cfg.prompt_function, Callable):
raise TypeError(
f"Prompt formatting function ({str(cfg.prompt_function)}) should have been passed as a callable, was {type(cfg.prompt_function)} instead."
)
self.formatter = cfg.prompt_function

self.generation_size = cfg.generation_size
self.generation_grammar = cfg.generation_grammar
self.stop_sequence = cfg.stop_sequence
self.output_regex = cfg.output_regex
self.must_remove_duplicate_docs = cfg.must_remove_duplicate_docs
if self.must_remove_duplicate_docs is None:
self.must_remove_duplicate_docs = False

# Save options
self.save_queries: bool = False
self.logfile_name: Optional[Path] = None
self.is_main_process: bool = False

@property
def cfg(self):
return self._cfg

def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[str]:
def get_first_possible_fewshot_splits(
self, available_splits: ListLike[str], number_of_splits: int = 1
) -> list[str] | None:
"""
Parses the possible fewshot split keys in order: train, then validation
keys and matches them with the available keys. Returns the first
@@ -260,7 +249,7 @@ def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[s
list[str]: List of the first available fewshot splits.
"""
# Possible few shot splits are the available splits not used for evaluation
possible_fewshot_splits = [k for k in self.all_available_splits if k not in self.evaluation_split]
possible_fewshot_splits = [k for k in available_splits if k not in self.evaluation_split]
stored_splits = []

# We look at these keys in order (first the training sets, then the validation sets)
@@ -289,7 +278,13 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
list[Doc]: List of documents.
"""
if self.dataset is None:
self.dataset = download_dataset_worker((self.dataset_path, self.dataset_config_name, self.trust_dataset))
self.dataset = download_dataset_worker(
self.dataset_path,
self.dataset_config_name,
self.trust_dataset,
self.dataset_filter,
self.dataset_revision,
)
splits = as_list(splits)

docs = []
@@ -326,7 +321,7 @@ def fewshot_docs(self) -> list[Doc]:
self._fewshot_docs = []

# If we have no available few shot split, the few shot data is the eval data!
if self.fewshot_split in [None, [None]]:
if self.fewshot_split is None:
self._fewshot_docs = self._get_docs_from_split(self.evaluation_split, few_shots=True)
else: # Normal case
self._fewshot_docs = self._get_docs_from_split(self.fewshot_split, few_shots=True)
@@ -552,14 +547,29 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =

if dataset_loading_processes <= 1:
datasets = [
download_dataset_worker((task.dataset_path, task.dataset_config_name, task.trust_dataset))
download_dataset_worker(
task.dataset_path,
task.dataset_config_name,
task.trust_dataset,
task.dataset_filter,
task.dataset_revision,
)
for task in tasks
]
else:
with Pool(processes=dataset_loading_processes) as pool:
datasets = pool.map(
datasets = pool.starmap(
download_dataset_worker,
[(task.dataset_path, task.dataset_config_name, task.trust_dataset) for task in tasks],
[
(
task.dataset_path,
task.dataset_config_name,
task.trust_dataset,
task.dataset_filter,
task.dataset_revision,
)
for task in tasks
],
)

for task, dataset in zip(tasks, datasets):
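Since download_dataset_worker now takes positional arguments instead of a single packed tuple, the parallel path switches from pool.map to pool.starmap, which unpacks each tuple into the call. A generic illustration (not lighteval code):

from multiprocessing import Pool

def worker(path: str, name: str, trust: bool) -> str:
    return f"{path}/{name} (trust={trust})"

if __name__ == "__main__":
    jobs = [("repo-a", "default", False), ("repo-b", "default", True)]
    with Pool(processes=2) as pool:
        # pool.map would pass each tuple as one argument; starmap unpacks it
        results = pool.starmap(worker, jobs)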
32 changes: 26 additions & 6 deletions src/lighteval/utils/utils.py
@@ -13,10 +13,10 @@
# limitations under the License.
import os
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Union
from typing import Callable, TypeVar, Union

import numpy as np
from datasets import load_dataset
from datasets import DatasetDict, load_dataset
from pytablewriter import MarkdownTableWriter


@@ -109,7 +109,14 @@ def sanitize_numpy(example_dict: dict) -> dict:
return output_dict


def as_list(item: Union[list, tuple, Any]) -> list:
ListLikeTypeVar = TypeVar("ListLikeTypeVar")
ListLike = list[ListLikeTypeVar] | tuple[ListLikeTypeVar, ...]


ElementType = TypeVar("ElementType")


def as_list(item: ListLike[ElementType] | ElementType) -> list[ElementType]:
"""
Convert the given item into a list.
@@ -126,8 +133,10 @@ def as_list(item: Union[list, tuple, Any]) -> list:
"""
if isinstance(item, list):
return item

elif isinstance(item, tuple):
return list(item)

return [item]


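The new ListLike alias keeps as_list generic over its element type. Expected behavior, for illustration:

from lighteval.utils.utils import as_list

assert as_list(["a", "b"]) == ["a", "b"]  # lists pass through unchanged
assert as_list(("a", "b")) == ["a", "b"]  # tuples are converted to lists
assert as_list("a") == ["a"]              # scalars are wrapped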
@@ -205,21 +214,32 @@ def boolstring_to_bool(x: Union[str, bool, int]) -> Union[bool, None]:
raise ValueError(f"You tried to convert {x} to a boolean but it's not possible.")


def download_dataset_worker(args):
def download_dataset_worker(
dataset_path: str,
dataset_config_name: str,
trust_dataset: bool,
dataset_filter: Callable[[dict], bool] | None = None,
revision: str | None = None,
) -> DatasetDict:
"""
Worker function to download a dataset from the HuggingFace Hub.
Used for parallel dataset loading.
"""
dataset_path, dataset_config_name, trust_dataset = args
dataset = load_dataset(
path=dataset_path,
name=dataset_config_name,
data_dir=None,
cache_dir=None,
download_mode=None,
trust_remote_code=trust_dataset,
revision=revision,
)
return dataset

if dataset_filter is not None:
dataset = dataset.filter(dataset_filter)

# It returns DatasetDict because we don't specify a split
return dataset # type: ignore


def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray:
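Putting it together, the worker forwards the pinned revision to load_dataset and applies the row filter afterwards. A direct call might look like this (the repo path and revision are illustrative placeholders):

from lighteval.utils.utils import download_dataset_worker

dataset = download_dataset_worker(
    "my-org/my-dataset",                # hypothetical dataset path
    "default",                          # config name
    False,                              # trust_dataset
    lambda row: row["text"] == "hi",    # dataset_filter
    "abc1234",                          # revision
)
print(dataset)  # DatasetDict with every split filtered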
61 changes: 61 additions & 0 deletions tests/tasks/test_lighteval_task.py
@@ -0,0 +1,61 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig


def dummy_prompt_function(item, task_name):
return item["text"]


def test_revision_check():
# Test with a different revision
cfg_with_revision = LightevalTaskConfig(
name="test_task_revision",
prompt_function=dummy_prompt_function,
hf_repo="lighteval-tests-datasets/dataset-test-1",
hf_subset="default",
evaluation_splits=["train"],
metric=[],
hf_revision="25175defadfde48b131b7cd7573ad6f59f868306",
)
task_with_revision = LightevalTask("test_task_revision", cfg_with_revision)
assert task_with_revision.eval_docs() == ["hi", "how are you?"]


def test_dataset_filter():
# Setup

cfg = LightevalTaskConfig(
name="test_task",
prompt_function=dummy_prompt_function,
hf_repo="lighteval-tests-datasets/dataset-test-1",
hf_subset="default",
hf_filter=lambda x: x["text"] == "hi",
metric=[],
evaluation_splits=["train"],
)
task = LightevalTask("test_task", cfg)

filtered_docs = task.eval_docs()
assert len(filtered_docs) == 1
assert filtered_docs[0] == "hi"

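Assuming a development install of lighteval, the new tests can be run with pytest:

pytest tests/tasks/test_lighteval_task.py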