diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index 6d6c35ecc..c86b0c68d 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -15,7 +15,7 @@
 from transformers.training_args import ParallelMode
 
 from sentence_transformers.data_collator import SentenceTransformerDataCollator
-from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
+from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
 from sentence_transformers.losses.CoSENTLoss import CoSENTLoss
 from sentence_transformers.model_card import ModelCardCallback
 from sentence_transformers.models.Transformer import Transformer
@@ -79,9 +79,12 @@ class SentenceTransformerTrainer(Trainer):
             dataset names to functions that return a loss class instance given a model. In practice, the latter two
             are primarily used for hyper-parameter optimization. Will default to
             :class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
-        evaluator (:class:`~sentence_transformers.evaluation.SentenceEvaluator`, *optional*):
-            The evaluator class to use for evaluation alongside the evaluation dataset. An evaluator will display more
-            useful metrics than the loss function.
+        evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
+            List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
+            The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
+            or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
+            are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
+            wrapped in a :class:`~sentence_transformers.evaluation.SequentialEvaluator` to run them sequentially.
         callbacks (List of [:class:`transformers.TrainerCallback`], *optional*):
             A list of callbacks to customize the training loop. Will add those to the list of default callbacks
             detailed in [here](callback).
@@ -123,7 +126,7 @@ def __init__(
                 Dict[str, Callable[["SentenceTransformer"], torch.nn.Module]],
             ]
         ] = None,
-        evaluator: Optional[SentenceEvaluator] = None,
+        evaluator: Optional[Union[SentenceEvaluator, List[SentenceEvaluator]]] = None,
         data_collator: Optional[DataCollator] = None,
         tokenizer: Optional[Union[PreTrainedTokenizerBase, Callable]] = None,
         model_init: Optional[Callable[[], "SentenceTransformer"]] = None,
@@ -218,6 +221,9 @@ def __init__(
                     )
         else:
             self.loss = self.prepare_loss(loss, model)
+        # If the evaluator is a list, we wrap it in a SequentialEvaluator.
+        # Guard against None: `not isinstance(None, SentenceEvaluator)` is True,
+        # so without this check we would wrap None in a SequentialEvaluator.
+        if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
+            evaluator = SequentialEvaluator(evaluator)
         self.evaluator = evaluator
 
         # Add a callback responsible for automatically tracking data required for the automatic model card generation
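
For reviewers, a minimal usage sketch of the new behavior. The model name, toy datasets, and the two evaluators below are illustrative assumptions, not part of this diff; any `SentenceEvaluator` subclasses would work the same way.

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, TripletEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")

# Tiny illustrative training set of (anchor, positive) pairs.
train_dataset = Dataset.from_dict(
    {
        "anchor": ["What is Python?", "Where is Berlin?"],
        "positive": ["Python is a programming language.", "Berlin is in Germany."],
    }
)

# Two unrelated evaluators with toy data.
sts_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["A man is eating food.", "A plane is taking off."],
    sentences2=["A man is eating a meal.", "A bird is flying."],
    scores=[0.9, 0.2],
    name="sts-dev",
)
triplet_evaluator = TripletEvaluator(
    anchors=["What is Python?"],
    positives=["Python is a programming language."],
    negatives=["The python is a large snake."],
    name="triplet-dev",
)

# New in this diff: a plain list is accepted; __init__ wraps it in a
# SequentialEvaluator, so both evaluators run one after the other.
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=MultipleNegativesRankingLoss(model),
    evaluator=[sts_evaluator, triplet_evaluator],
)
```

This is equivalent to passing `SequentialEvaluator([sts_evaluator, triplet_evaluator])` yourself; the wrapping in `__init__` only saves that step.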