Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Tokenization + misc fixes #354

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lighteval/config/lighteval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class LightEvalTasksArgs:

dataset_loading_processes: int = 8
multichoice_continuations_start_space: Optional[bool] = None
pair_wise_tokenization: bool = False
pairwise_tokenization: bool = False


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/logging/info_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ class MetricsLogger:
default_factory=lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
metric_aggregated: dict[str, dict[str, float]] = field(
default_factory=lambda: collections.defaultdict(lambda: collections.defaultdict(dict))
default_factory=lambda: collections.defaultdict(lambda: collections.defaultdict(float))
)

def log(self, task_name: str, metrics: dict) -> None:
Expand Down
29 changes: 17 additions & 12 deletions src/lighteval/models/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,28 @@ def tok_encode_pair(self, context, continuation, pairwise: bool = False):
context = context[:-n_spaces]

if pairwise:
context_enc, continuation_enc = self.tok_encode(context), self.tok_encode(continuation)
if self.add_special_tokens:
tokenized_with_special_tokens = self.tokenizer.build_inputs_with_special_tokens(
context_enc + continuation_enc
)
# If this fails something went wrong as the function above should only add special tokens
first_non_prefix_token_idx = tokenized_with_special_tokens.index(context_enc[0])
last_context_token_idx = first_non_prefix_token_idx + len(context_enc)
context_enc, continuation_enc = (
tokenized_with_special_tokens[:last_context_token_idx],
tokenized_with_special_tokens[last_context_token_idx:],
)
# We don't add special tokens to the continuation: if a BOS token is added,
# models tend to completely ignore the context
context_enc, continuation_enc = (
self.tok_encode(context, add_special_tokens=self.add_special_tokens),
self.tok_encode(continuation, add_special_tokens=False),
)

# In theory context_enc can end with the EOS token, which would again
# cause the model to ignore the context. We thus strip the EOS token from context_enc
if len(context_enc) > 0 and context_enc[-1] == self.tokenizer.eos_token_id:
context_enc = context_enc[:-1]

return context_enc, continuation_enc

whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
# In case continuation tokens merge with context tokens we use the merged token as continuation
if len(context_enc) == len(whole_enc):
context_enc_len = len(context_enc) - 1
context_enc = whole_enc[:context_enc_len]

continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
model_size=model_size,
)

self.pair_wise_tokenization = config.pair_wise_tokenization
self.pairwise_tokenization = config.pairwise_tokenization

@property
def tokenizer(self):
Expand Down Expand Up @@ -697,7 +697,7 @@ def loglikelihood(
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice, pairwise=self.pair_wise_tokenization
request.context, request.choice, pairwise=self.pairwise_tokenization
)

return self._loglikelihood_tokens(requests, override_bs=override_bs)
Expand Down
6 changes: 4 additions & 2 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class BaseModelConfig:
For example, context: "What is the capital of France?" and choices: "Paris", "London".
Will be tokenized as: "What is the capital of France? Paris" and "What is the capital of France? London".
True adds a space, False strips a space, None does nothing
pair_wise_tokenization (bool): Whether to tokenize the context and continuation as separately or together.
pairwise_tokenization (bool): Whether to tokenize the context and continuation separately or together.
subfolder (Optional[str]): The subfolder within the model repository.
revision (str): The revision of the model.
batch_size (int): The batch size for model training.
Expand Down Expand Up @@ -100,7 +100,7 @@ class BaseModelConfig:
accelerator: "Accelerator" = None
tokenizer: Optional[str] = None
multichoice_continuations_start_space: Optional[bool] = None
pair_wise_tokenization: bool = False
pairwise_tokenization: bool = False
subfolder: Optional[str] = None
revision: str = "main"
batch_size: int = -1
Expand Down Expand Up @@ -226,6 +226,8 @@ class VLLMModelConfig:
multichoice_continuations_start_space: bool = (
True # whether to add a space at the start of each continuation in multichoice generation
)
pairwise_tokenization: bool = False # whether to tokenize the context and continuation separately or together.

subfolder: Optional[str] = None


Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/models/nanotron_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(
self.input_pp_rank, self.output_pp_rank = get_min_max_rank(module=self.model)

self.multichoice_continuations_start_space = multichoice_continuations_start_space
self.pair_wise_tokenization = nanotron_config.lighteval_config.tasks.pair_wise_tokenization
self.pairwise_tokenization = nanotron_config.lighteval_config.tasks.pairwise_tokenization

self.model_info = ModelInfo(
model_name=f"{nanotron_config.nanotron_config.general.run}/{nanotron_config.nanotron_config.general.step}"
Expand Down Expand Up @@ -447,7 +447,7 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None)
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice, self.pair_wise_tokenization
request.context, request.choice, self.pairwise_tokenization
)

return self._loglikelihood_tokens(
Expand Down
3 changes: 2 additions & 1 deletion src/lighteval/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
self.precision = _get_dtype(config.dtype, config=self._config)

self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
self.pairwise_tokenization = config.pairwise_tokenization

@property
def tokenizer(self):
Expand Down Expand Up @@ -352,7 +353,7 @@ def loglikelihood(
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice
request.context, request.choice, pairwise=self.pairwise_tokenization
)
return self._loglikelihood_tokens(requests, override_bs=override_bs)

Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import inspect
import random
from dataclasses import asdict, dataclass, field
from multiprocessing import Pool
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

from datasets import DatasetDict
from huggingface_hub import TextGenerationInputGrammarType
from multiprocess import Pool
from pytablewriter import MarkdownTableWriter

from lighteval.logging.hierarchical_logger import hlog, hlog_warn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_tok_encode_pair():
continuation = "1"
non_pairwise_tokens = model.tok_encode_pair(context, continuation, pairwise=False)
pairwise_tokens = model.tok_encode_pair(context, continuation, pairwise=True)
# Problematic case where the completion tokens are empty even though the chars are non-empty
assert non_pairwise_tokens == ([6, 47873, 13], [])
# Non-pairwise merged ":1" to one token
assert non_pairwise_tokens == ([6, 47873], [34871])
# Pairwise separated ":" and "1"
assert pairwise_tokens == ([6, 47873, 13], [82])
Loading