Skip to content

Commit

Permalink
Merge pull request #2714 from tomaarsen/model_card/update_sentence_tr…
Browse files Browse the repository at this point in the history
…ansformers_model_id

[`model cards`] Replace 'sentence_transformers_model_id' from reused model if possible
  • Loading branch information
tomaarsen authored Jun 4, 2024
2 parents 529edc3 + 3c649c1 commit a3e1b86
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
6 changes: 6 additions & 0 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,12 @@ def _create_model_card(
# we don't generate a new model card, but reuse the old one instead.
if self._model_card_text and self.model_card_data.trainer is None:
model_card = self._model_card_text
if self.model_card_data.model_id:
# If the original model card was saved without a model_id, we replace the model_id with the new model_id
model_card = model_card.replace(
'model = SentenceTransformer("sentence_transformers_model_id"',
f'model = SentenceTransformer("{self.model_card_data.model_id}"',
)
else:
try:
model_card = generate_model_card(self)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,28 @@ def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> Non
assert np.not_equal(cosine_scores, dot_scores).all()


def test_model_card_save_update_model_id(stsb_bert_tiny_model: SentenceTransformer) -> None:
model = stsb_bert_tiny_model
# Removing the saved model card will cause a fresh one to be generated when we save
model._model_card_text = ""
with tempfile.TemporaryDirectory() as tmp_folder:
model.save(tmp_folder)
with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f:
model_card_text = f.read()
assert 'model = SentenceTransformer("sentence_transformers_model_id"' in model_card_text

# When we reload this saved model and then re-save it, we want to override the 'sentence_transformers_model_id'
# if we have it set
loaded_model = SentenceTransformer(tmp_folder)

with tempfile.TemporaryDirectory() as tmp_folder:
loaded_model.save(tmp_folder, model_name="test_user/test_model")

with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f:
model_card_text = f.read()
assert 'model = SentenceTransformer("test_user/test_model"' in model_card_text


def test_override_config_versions(stsb_bert_tiny_model: SentenceTransformer) -> None:
model = stsb_bert_tiny_model

Expand Down

0 comments on commit a3e1b86

Please sign in to comment.