From fe96913af74cad4c4a81a6916b97dccb5d30ca01 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Tue, 4 Jun 2024 13:05:40 +0200
Subject: [PATCH 1/2] Always override the originally saved __version__ in the ST config

---
 sentence_transformers/SentenceTransformer.py | 11 +++++------
 tests/test_sentence_transformer.py           | 11 +++++++++++
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index f6a48ae29..dfcc297c8 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -1004,12 +1004,11 @@ def save(
         modules_config = []

         # Save some model info
-        if "__version__" not in self._model_config:
-            self._model_config["__version__"] = {
-                "sentence_transformers": __version__,
-                "transformers": transformers.__version__,
-                "pytorch": torch.__version__,
-            }
+        self._model_config["__version__"] = {
+            "sentence_transformers": __version__,
+            "transformers": transformers.__version__,
+            "pytorch": torch.__version__,
+        }

         with open(os.path.join(path, "config_sentence_transformers.json"), "w") as fOut:
             config = self._model_config.copy()
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
index 233be3f8c..2d789dd4a 100644
--- a/tests/test_sentence_transformer.py
+++ b/tests/test_sentence_transformer.py
@@ -564,3 +564,14 @@ def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> Non
     assert loaded_model.similarity_fn_name == "euclidean"
     dot_scores = model.similarity(embeddings, embeddings)
     assert np.not_equal(cosine_scores, dot_scores).all()
+
+
+def test_override_config_versions(stsb_bert_tiny_model: SentenceTransformer) -> None:
+    model = stsb_bert_tiny_model
+
+    assert model._model_config["__version__"]["sentence_transformers"] == "2.2.2"
+    with tempfile.TemporaryDirectory() as tmp_folder:
+        model.save(tmp_folder)
+        loaded_model = SentenceTransformer(tmp_folder)
+    # Verify that the version has now been updated when saving the model again
+    assert loaded_model._model_config["__version__"]["sentence_transformers"] != "2.2.2"

From 9aef3c4899e12b8dd3ea1f3eb6f9bdde37edb505 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Tue, 4 Jun 2024 14:36:16 +0200
Subject: [PATCH 2/2] Also include HF datasets in the model card metadata

---
 sentence_transformers/model_card.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/sentence_transformers/model_card.py b/sentence_transformers/model_card.py
index ea67e4dee..149dae236 100644
--- a/sentence_transformers/model_card.py
+++ b/sentence_transformers/model_card.py
@@ -301,6 +301,7 @@ class SentenceTransformerModelCardData(CardData):
     citations: Dict[str, str] = field(default_factory=dict, init=False)
     best_model_step: Optional[int] = field(default=None, init=False)
     trainer: Optional["SentenceTransformerTrainer"] = field(default=None, init=False, repr=False)
+    datasets: List[str] = field(default_factory=list, init=False, repr=False)

     # Utility fields
     first_save: bool = field(default=True, init=False)
@@ -357,6 +358,10 @@ def validate_datasets(self, dataset_list, infer_languages: bool = True) -> None:
                                 if language not in self.language:
                                     self.language.append(language)

+                    # Track dataset IDs for the metadata
+                    if info.id not in self.datasets:
+                        self.datasets.append(info.id)
+
             output_dataset_list.append(dataset)

         return output_dataset_list