From 0f84a7bcc090214e361c21ddfb00387b151379ea Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 4 Jun 2024 22:04:31 +0200 Subject: [PATCH 1/3] Replace 'sentence_transformers_model_id' from reused model if possible --- sentence_transformers/SentenceTransformer.py | 5 +++++ tests/test_sentence_transformer.py | 21 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index f6a48ae29..d2bd644b4 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -1095,6 +1095,11 @@ def _create_model_card( # we don't generate a new model card, but reuse the old one instead. if self._model_card_text and self.model_card_data.trainer is None: model_card = self._model_card_text + # If the original model card was saved without a model_id, we replace the model_id with the new model_id + model_card = model_card.replace( + 'model = SentenceTransformer("sentence_transformers_model_id"', + f'model = SentenceTransformer("{self.model_card_data.model_id}"', + ) else: try: model_card = generate_model_card(self) diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 233be3f8c..7d13480c2 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -564,3 +564,24 @@ def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> Non assert loaded_model.similarity_fn_name == "euclidean" dot_scores = model.similarity(embeddings, embeddings) assert np.not_equal(cosine_scores, dot_scores).all() + +def test_model_card_save_update_model_id(stsb_bert_tiny_model: SentenceTransformer) -> None: + model = stsb_bert_tiny_model + # Removing the saved model card will cause a fresh one to be generated when we save + model._model_card_text = "" + with tempfile.TemporaryDirectory() as tmp_folder: + model.save(tmp_folder) + with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f: + model_card_text = f.read() + assert 'model = SentenceTransformer("sentence_transformers_model_id"' in model_card_text + + # When we reload this saved model and then re-save it, we want to override the 'sentence_transformers_model_id' + # if we have it set + loaded_model = SentenceTransformer(tmp_folder) + + with tempfile.TemporaryDirectory() as tmp_folder: + loaded_model.save(tmp_folder, model_name="test_user/test_model") + + with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f: + model_card_text = f.read() + assert 'model = SentenceTransformer("test_user/test_model"' in model_card_text From 0dbd34989c2c0749fe0fb6e2eb13c9d8cce2e29f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 4 Jun 2024 22:10:54 +0200 Subject: [PATCH 2/3] Ensure that self.model_card_data.model_id is set --- sentence_transformers/SentenceTransformer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index d2bd644b4..935e1ea6d 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -1095,11 +1095,12 @@ def _create_model_card( # we don't generate a new model card, but reuse the old one instead. if self._model_card_text and self.model_card_data.trainer is None: model_card = self._model_card_text - # If the original model card was saved without a model_id, we replace the model_id with the new model_id - model_card = model_card.replace( - 'model = SentenceTransformer("sentence_transformers_model_id"', - f'model = SentenceTransformer("{self.model_card_data.model_id}"', - ) + if self.model_card_data.model_id: + # If the original model card was saved without a model_id, we replace the model_id with the new model_id + model_card = model_card.replace( + 'model = SentenceTransformer("sentence_transformers_model_id"', + f'model = SentenceTransformer("{self.model_card_data.model_id}"', + ) else: try: model_card = generate_model_card(self) From c1f22d41917b8a60379a19788d3abd728e0c756f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 4 Jun 2024 22:22:04 +0200 Subject: [PATCH 3/3] Reformat --- tests/test_sentence_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 7d13480c2..5247a3668 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -565,6 +565,7 @@ def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> Non dot_scores = model.similarity(embeddings, embeddings) assert np.not_equal(cosine_scores, dot_scores).all() + def test_model_card_save_update_model_id(stsb_bert_tiny_model: SentenceTransformer) -> None: model = stsb_bert_tiny_model # Removing the saved model card will cause a fresh one to be generated when we save @@ -578,7 +579,7 @@ def test_model_card_save_update_model_id(stsb_bert_tiny_model: SentenceTransform # When we reload this saved model and then re-save it, we want to override the 'sentence_transformers_model_id' # if we have it set loaded_model = SentenceTransformer(tmp_folder) - + with tempfile.TemporaryDirectory() as tmp_folder: loaded_model.save(tmp_folder, model_name="test_user/test_model")