Skip to content

Commit

Permalink
Fix Tokenization + misc fixes (#354)
Browse files Browse the repository at this point in the history
* fix loggers with multimetric, fix pairwise tokenization + fallback for non-pairwise, add pairwise to vllm, use multiprocess for dataset loading

* fix test + implementation

* finally do the tokenization correctly

* pair_wise -> pairwise

---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
  • Loading branch information
hynky1999 and NathanHB authored Oct 10, 2024
1 parent 78cda93 commit 1dfd77d
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/lighteval/config/lighteval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class LightEvalTasksArgs:

dataset_loading_processes: int = 8
multichoice_continuations_start_space: Optional[bool] = None
pair_wise_tokenization: bool = False
pairwise_tokenization: bool = False


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/logging/info_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ class MetricsLogger:
default_factory=lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
metric_aggregated: dict[str, dict[str, float]] = field(
default_factory=lambda: collections.defaultdict(lambda: collections.defaultdict(dict))
default_factory=lambda: collections.defaultdict(lambda: collections.defaultdict(float))
)

def log(self, task_name: str, metrics: dict) -> None:
Expand Down
29 changes: 17 additions & 12 deletions src/lighteval/models/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,28 @@ def tok_encode_pair(self, context, continuation, pairwise: bool = False):
context = context[:-n_spaces]

if pairwise:
context_enc, continuation_enc = self.tok_encode(context), self.tok_encode(continuation)
if self.add_special_tokens:
tokenized_with_special_tokens = self.tokenizer.build_inputs_with_special_tokens(
context_enc + continuation_enc
)
# If this fails something went wrong as the function above should only add special tokens
first_non_prefix_token_idx = tokenized_with_special_tokens.index(context_enc[0])
last_context_token_idx = first_non_prefix_token_idx + len(context_enc)
context_enc, continuation_enc = (
tokenized_with_special_tokens[:last_context_token_idx],
tokenized_with_special_tokens[last_context_token_idx:],
)
# We don't add special tokens to the continuation as if bos is added
# models tend to completely ignore the context
context_enc, continuation_enc = (
self.tok_encode(context, add_special_tokens=self.add_special_tokens),
self.tok_encode(continuation, add_special_tokens=False),
)

# In theory the context_enc can be ended with eos token, this would again
# cause the model to ignore the context. We thus strip the eos token from context_enc
if len(context_enc) > 0 and context_enc[-1] == self.tokenizer.eos_token_id:
context_enc = context_enc[:-1]

return context_enc, continuation_enc

whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
# In case continuation tokens merge with context tokens we use the merged token as continuation
if len(context_enc) == len(whole_enc):
context_enc_len = len(context_enc) - 1
context_enc = whole_enc[:context_enc_len]

continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
model_size=model_size,
)

self.pair_wise_tokenization = config.pair_wise_tokenization
self.pairwise_tokenization = config.pairwise_tokenization

@property
def tokenizer(self):
Expand Down Expand Up @@ -697,7 +697,7 @@ def loglikelihood(
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice, pairwise=self.pair_wise_tokenization
request.context, request.choice, pairwise=self.pairwise_tokenization
)

return self._loglikelihood_tokens(requests, override_bs=override_bs)
Expand Down
6 changes: 4 additions & 2 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class BaseModelConfig:
For example, context: "What is the capital of France?" and choices: "Paris", "London".
Will be tokenized as: "What is the capital of France? Paris" and "What is the capital of France? London".
True adds a space, False strips a space, None does nothing
pair_wise_tokenization (bool): Whether to tokenize the context and continuation as separately or together.
pairwise_tokenization (bool): Whether to tokenize the context and continuation separately or together.
subfolder (Optional[str]): The subfolder within the model repository.
revision (str): The revision of the model.
batch_size (int): The batch size for model training.
Expand Down Expand Up @@ -100,7 +100,7 @@ class BaseModelConfig:
accelerator: "Accelerator" = None
tokenizer: Optional[str] = None
multichoice_continuations_start_space: Optional[bool] = None
pair_wise_tokenization: bool = False
pairwise_tokenization: bool = False
subfolder: Optional[str] = None
revision: str = "main"
batch_size: int = -1
Expand Down Expand Up @@ -226,6 +226,8 @@ class VLLMModelConfig:
multichoice_continuations_start_space: bool = (
True # whether to add a space at the start of each continuation in multichoice generation
)
pairwise_tokenization: bool = False # whether to tokenize the context and continuation separately or together.

subfolder: Optional[str] = None


Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/models/nanotron_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(
self.input_pp_rank, self.output_pp_rank = get_min_max_rank(module=self.model)

self.multichoice_continuations_start_space = multichoice_continuations_start_space
self.pair_wise_tokenization = nanotron_config.lighteval_config.tasks.pair_wise_tokenization
self.pairwise_tokenization = nanotron_config.lighteval_config.tasks.pairwise_tokenization

self.model_info = ModelInfo(
model_name=f"{nanotron_config.nanotron_config.general.run}/{nanotron_config.nanotron_config.general.step}"
Expand Down Expand Up @@ -447,7 +447,7 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None)
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice, self.pair_wise_tokenization
request.context, request.choice, self.pairwise_tokenization
)

return self._loglikelihood_tokens(
Expand Down
3 changes: 2 additions & 1 deletion src/lighteval/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
self.precision = _get_dtype(config.dtype, config=self._config)

self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
self.pairwise_tokenization = config.pairwise_tokenization

@property
def tokenizer(self):
Expand Down Expand Up @@ -352,7 +353,7 @@ def loglikelihood(
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice
request.context, request.choice, pairwise=self.pairwise_tokenization
)
return self._loglikelihood_tokens(requests, override_bs=override_bs)

Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import inspect
import random
from dataclasses import asdict, dataclass, field
from multiprocessing import Pool
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

from datasets import DatasetDict
from huggingface_hub import TextGenerationInputGrammarType
from multiprocess import Pool
from pytablewriter import MarkdownTableWriter

from lighteval.logging.hierarchical_logger import hlog, hlog_warn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_tok_encode_pair():
continuation = "1"
non_pairwise_tokens = model.tok_encode_pair(context, continuation, pairwise=False)
pairwise_tokens = model.tok_encode_pair(context, continuation, pairwise=True)
# Problematic case where the completion tokens are empty despite the chars are non-empty
assert non_pairwise_tokens == ([6, 47873, 13], [])
# Non-pairwise merged ":1" to one token
assert non_pairwise_tokens == ([6, 47873], [34871])
# Pairwise separated ":" and "1"
assert pairwise_tokens == ([6, 47873, 13], [82])

0 comments on commit 1dfd77d

Please sign in to comment.