diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index dfcc297c8..438c54bd7 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -1094,6 +1094,12 @@ def _create_model_card(
         # we don't generate a new model card, but reuse the old one instead.
         if self._model_card_text and self.model_card_data.trainer is None:
             model_card = self._model_card_text
+            if self.model_card_data.model_id:
+                # A card saved without a model_id holds the "sentence_transformers_model_id" placeholder; swap in the new model_id
+                model_card = model_card.replace(
+                    'model = SentenceTransformer("sentence_transformers_model_id"',
+                    f'model = SentenceTransformer("{self.model_card_data.model_id}"',
+                )
         else:
             try:
                 model_card = generate_model_card(self)
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
index 2d789dd4a..9336a54fe 100644
--- a/tests/test_sentence_transformer.py
+++ b/tests/test_sentence_transformer.py
@@ -566,6 +566,28 @@ def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> Non
     assert np.not_equal(cosine_scores, dot_scores).all()
 
 
+def test_model_card_save_update_model_id(stsb_bert_tiny_model: SentenceTransformer) -> None:
+    model = stsb_bert_tiny_model
+    # Removing the saved model card will cause a fresh one to be generated when we save
+    model._model_card_text = ""
+    with tempfile.TemporaryDirectory() as tmp_folder:
+        model.save(tmp_folder)
+        with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f:
+            model_card_text = f.read()
+            assert 'model = SentenceTransformer("sentence_transformers_model_id"' in model_card_text
+
+        # When we reload this saved model and then re-save it, we want to override the 'sentence_transformers_model_id'
+        # if we have it set
+        loaded_model = SentenceTransformer(tmp_folder)
+
+    with tempfile.TemporaryDirectory() as tmp_folder:
+        loaded_model.save(tmp_folder, model_name="test_user/test_model")
+
+        with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f:
+            model_card_text = f.read()
+            assert 'model = SentenceTransformer("test_user/test_model"' in model_card_text
+
+
 def test_override_config_versions(stsb_bert_tiny_model: SentenceTransformer) -> None:
     model = stsb_bert_tiny_model