Task config #289

Merged
merged 10 commits on Sep 13, 2024
122 changes: 66 additions & 56 deletions src/lighteval/tasks/lighteval_task.py
@@ -23,11 +23,11 @@
import collections
import inspect
import random
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from multiprocessing import Pool
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

from datasets import DatasetDict
from huggingface_hub import TextGenerationInputGrammarType
from pytablewriter import MarkdownTableWriter

@@ -54,7 +54,7 @@
RequestType,
SampleUid,
)
from lighteval.utils.utils import as_list, download_dataset_worker
from lighteval.utils.utils import ListLike, as_list, download_dataset_worker


if TYPE_CHECKING:
@@ -82,55 +82,58 @@ class LightevalTaskConfig:
original_num_docs (int): Number of documents in the task
effective_num_docs (int): Number of documents used in a specific evaluation
truncated_num_docs (bool): Whether less than the total number of documents were used
output_regex (str)
frozen (bool)
trust_dataset (bool): Whether to trust the dataset at execution or not
version (int): The version of the task. Defaults to 0. Can be increased if the underlying dataset or the prompt changes.
output_regex (str)
frozen (bool)
"""

name: str
prompt_function: Callable # [[dict, str], Doc]
prompt_function: Callable[[dict, str], Doc]
hf_repo: str
hf_subset: str
metric: Tuple[Union[Metric, Metrics]]
hf_avail_splits: Optional[Tuple[str]] = None
evaluation_splits: Optional[Tuple[str]] = None
metric: ListLike[Metric | Metrics]

# Additional hf dataset config
hf_revision: Optional[str] = None
hf_filter: Optional[Callable[[dict], bool]] = None
hf_avail_splits: Optional[ListLike[str]] = field(default_factory=lambda: ["train", "validation", "test"])
# We default to false, to reduce security issues
trust_dataset: bool = False

# Splits
evaluation_splits: ListLike[str] = field(default_factory=lambda: ["validation"])
few_shots_split: Optional[str] = None
few_shots_select: Optional[str] = None

# Generation args
generation_size: Optional[int] = None
generation_grammar: Optional[TextGenerationInputGrammarType] = None
stop_sequence: Optional[Tuple[str]] = None
stop_sequence: Optional[ListLike[str]] = None
output_regex: Optional[str] = None
num_samples: Optional[list[int]] = None

frozen: bool = False
suite: Optional[Tuple[str]] = None
suite: ListLike[str] = field(default_factory=lambda: ["custom"])

original_num_docs: int = -1
effective_num_docs: int = -1

trust_dataset: bool = None

must_remove_duplicate_docs: bool = None
must_remove_duplicate_docs: bool = False

version: int = 0

def __post_init__(self):
if self.suite is None:
self.suite = ["custom"]
if self.hf_avail_splits is None:
self.hf_avail_splits = ["train", "validation", "test"]
if self.evaluation_splits is None:
self.evaluation_splits = ["validation"]
# Currently unused
frozen: bool = False

def __post_init__(self):
# If we got a Metrics enums instead of a Metric, we convert
self.metric = [metric.value if isinstance(metric, Metrics) else metric for metric in self.metric]

# Convert list to tuple for hashing
self.metric = tuple(self.metric)
self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None
self.suite = tuple(self.suite) if self.suite is not None else None
self.evaluation_splits = tuple(self.evaluation_splits)
self.suite = tuple(self.suite)
self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None

def print(self):
@@ -175,31 +178,27 @@ def __init__(  # noqa: C901
"""
self.name = name
self.version = cfg.version
self.is_main_process = False
self.cache_dir = cache_dir
self._cfg = cfg

# Dataset info
self.hf_repo = cfg.hf_repo
self.hf_subset = cfg.hf_subset
self.dataset_path = self.hf_repo
self.dataset_config_name = self.hf_subset
self.dataset = None # Delayed download
self.dataset_path = cfg.hf_repo
self.dataset_config_name = cfg.hf_subset
self.dataset_revision = cfg.hf_revision
self.dataset_filter = cfg.hf_filter
self.trust_dataset = cfg.trust_dataset
self.dataset: Optional[DatasetDict] = None # Delayed download
hlog(f"{self.dataset_path} {self.dataset_config_name}")
self._fewshot_docs = None
self._docs = None

# Managing splits and few shot
self.all_available_splits = as_list(cfg.hf_avail_splits)
if cfg.evaluation_splits is None:
raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

self.evaluation_split = as_list(cfg.evaluation_splits)

self.fewshot_split: list[str] | None
if cfg.few_shots_split is not None:
self.fewshot_split = as_list(cfg.few_shots_split)
else:
self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
self.fewshot_split = self.get_first_possible_fewshot_splits(cfg.hf_avail_splits or [])
self.fewshot_selection = cfg.few_shots_select

# Metrics
@@ -223,30 +222,20 @@ def __init__(  # noqa: C901
if "maj@" in metric_name:
self.num_samples.append(int(metric_name.replace("maj@", "").split("_")[0]))

if not isinstance(cfg.prompt_function, Callable):
raise TypeError(
f"Prompt formatting function ({str(cfg.prompt_function)}) should have been passed as a callable, was {type(cfg.prompt_function)} instead."
)
self.formatter = cfg.prompt_function

self.generation_size = cfg.generation_size
self.generation_grammar = cfg.generation_grammar
self.stop_sequence = cfg.stop_sequence
self.output_regex = cfg.output_regex
self.must_remove_duplicate_docs = cfg.must_remove_duplicate_docs
if self.must_remove_duplicate_docs is None:
self.must_remove_duplicate_docs = False

# Save options
self.save_queries: bool = False
self.logfile_name: Optional[Path] = None
self.is_main_process: bool = False

@property
def cfg(self):
return self._cfg

def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[str]:
def get_first_possible_fewshot_splits(
self, available_splits: ListLike[str], number_of_splits: int = 1
) -> list[str] | None:
"""
Parses the possible fewshot split keys in order: train, then validation
keys and matches them with the available keys. Returns the first
@@ -260,7 +249,7 @@ def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[s
list[str]: List of the first available fewshot splits.
"""
# Possible few shot splits are the available splits not used for evaluation
possible_fewshot_splits = [k for k in self.all_available_splits if k not in self.evaluation_split]
possible_fewshot_splits = [k for k in available_splits if k not in self.evaluation_split]
stored_splits = []

# We look at these keys in order (first the training sets, then the validation sets)
@@ -289,7 +278,13 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
list[Doc]: List of documents.
"""
if self.dataset is None:
self.dataset = download_dataset_worker((self.dataset_path, self.dataset_config_name, self.trust_dataset))
self.dataset = download_dataset_worker(
self.dataset_path,
self.dataset_config_name,
self.trust_dataset,
self.dataset_filter,
self.dataset_revision,
)
splits = as_list(splits)

docs = []
@@ -326,7 +321,7 @@ def fewshot_docs(self) -> list[Doc]:
self._fewshot_docs = []

# If we have no available few shot split, the few shot data is the eval data!
if self.fewshot_split in [None, [None]]:
if self.fewshot_split is None:
self._fewshot_docs = self._get_docs_from_split(self.evaluation_split, few_shots=True)
else: # Normal case
self._fewshot_docs = self._get_docs_from_split(self.fewshot_split, few_shots=True)
@@ -552,14 +547,29 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =

if dataset_loading_processes <= 1:
datasets = [
download_dataset_worker((task.dataset_path, task.dataset_config_name, task.trust_dataset))
download_dataset_worker(
task.dataset_path,
task.dataset_config_name,
task.trust_dataset,
task.dataset_filter,
task.dataset_revision,
)
for task in tasks
]
else:
with Pool(processes=dataset_loading_processes) as pool:
datasets = pool.map(
datasets = pool.starmap(
download_dataset_worker,
[(task.dataset_path, task.dataset_config_name, task.trust_dataset) for task in tasks],
[
(
task.dataset_path,
task.dataset_config_name,
task.trust_dataset,
task.dataset_filter,
task.dataset_revision,
)
for task in tasks
],
)

for task, dataset in zip(tasks, datasets):
32 changes: 26 additions & 6 deletions src/lighteval/utils/utils.py
@@ -13,10 +13,10 @@
# limitations under the License.
import os
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Union
from typing import Callable, TypeVar, Union

import numpy as np
from datasets import load_dataset
from datasets import DatasetDict, load_dataset
from pytablewriter import MarkdownTableWriter


@@ -109,7 +109,14 @@ def sanitize_numpy(example_dict: dict) -> dict:
return output_dict


def as_list(item: Union[list, tuple, Any]) -> list:
ListLikeTypeVar = TypeVar("ListLikeTypeVar")
ListLike = list[ListLikeTypeVar] | tuple[ListLikeTypeVar, ...]


ElementType = TypeVar("ElementType")


def as_list(item: ListLike[ElementType] | ElementType) -> list[ElementType]:
"""
Convert the given item into a list.

@@ -126,8 +133,10 @@ def as_list(item: Union[list, tuple, Any]) -> list:
"""
if isinstance(item, list):
return item

elif isinstance(item, tuple):
return list(item)

return [item]


@@ -205,21 +214,32 @@ def boolstring_to_bool(x: Union[str, bool, int]) -> Union[bool, None]:
raise ValueError(f"You tried to convert {x} to a boolean but it's not possible.")


def download_dataset_worker(args):
def download_dataset_worker(
dataset_path: str,
dataset_config_name: str,
trust_dataset: bool,
dataset_filter: Callable[[dict], bool] | None = None,
revision: str | None = None,
) -> DatasetDict:
"""
Worker function to download a dataset from the HuggingFace Hub.
Used for parallel dataset loading.
"""
dataset_path, dataset_config_name, trust_dataset = args
dataset = load_dataset(
path=dataset_path,
name=dataset_config_name,
data_dir=None,
cache_dir=None,
download_mode=None,
trust_remote_code=trust_dataset,
revision=revision,
)
return dataset

if dataset_filter is not None:
dataset = dataset.filter(dataset_filter)

# It returns DatasetDict because we don't specify a split
return dataset # type: ignore
Member:
why do we ignore the type here?

@hynky1999 (Collaborator, Author), Sep 13, 2024:

Because it doesn't have the correct type. load_dataset returns Dataset | DatasetDict | IterableDataset | IterableDatasetDict depending on the input args, and AFAIK there is an unspecified contract that if we don't pass the streaming and split args we get a DatasetDict. However, there is no way to express this at the typing level, so I just ignore the error.

If the question is why I put the type: ignore there at all: even though we don't run a type checker in the quality checks, I have one enabled in my VS Code (pyright), and it shows red when there is a typing problem.
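
For illustration only, a minimal sketch (not part of this PR) of how the return type could be narrowed explicitly instead of using the ignore, assuming the usual call without split= or streaming=True; the helper name load_dataset_dict is hypothetical:

```python
from datasets import DatasetDict, load_dataset


def load_dataset_dict(path: str, name: str, revision: str | None = None) -> DatasetDict:
    """Hypothetical helper: load a dataset and narrow the union return type."""
    dataset = load_dataset(path=path, name=name, revision=revision)
    # Without split= or streaming=True, load_dataset is expected to return a
    # DatasetDict, but its annotation is a union of four types, so we check it.
    if not isinstance(dataset, DatasetDict):
        raise TypeError(f"Expected a DatasetDict, got {type(dataset).__name__}")
    return dataset
```

Either approach satisfies pyright; the `# type: ignore` keeps the call site unchanged at the cost of suppressing the checker for that one line.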



def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray:
61 changes: 61 additions & 0 deletions tests/tasks/test_lighteval_task.py
@@ -0,0 +1,61 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig


def dummy_prompt_function(item, task_name):
return item["text"]


def test_revision_check():
# Test with a different revision
cfg_with_revision = LightevalTaskConfig(
name="test_task_revision",
prompt_function=dummy_prompt_function,
hf_repo="lighteval-tests-datasets/dataset-test-1",
hf_subset="default",
evaluation_splits=["train"],
metric=[],
hf_revision="25175defadfde48b131b7cd7573ad6f59f868306",
)
task_with_revision = LightevalTask("test_task_revision", cfg_with_revision)
assert task_with_revision.eval_docs() == ["hi", "how are you?"]


def test_dataset_filter():
# Setup

cfg = LightevalTaskConfig(
name="test_task",
prompt_function=dummy_prompt_function,
hf_repo="lighteval-tests-datasets/dataset-test-1",
hf_subset="default",
hf_filter=lambda x: x["text"] == "hi",
metric=[],
evaluation_splits=["train"],
)
task = LightevalTask("test_task", cfg)

filtered_docs = task.eval_docs()
assert len(filtered_docs) == 1
assert filtered_docs[0] == "hi"