Skip to content

Commit

Permalink
Merge branch 'master' into model_card/update_sentence_transformers_mo…
Browse files Browse the repository at this point in the history
…del_id
  • Loading branch information
tomaarsen committed Jun 4, 2024
2 parents c1f22d4 + 2224477 commit 3c649c1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
11 changes: 5 additions & 6 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,12 +1004,11 @@ def save(
modules_config = []

# Save some model info
if "__version__" not in self._model_config:
self._model_config["__version__"] = {
"sentence_transformers": __version__,
"transformers": transformers.__version__,
"pytorch": torch.__version__,
}
self._model_config["__version__"] = {
"sentence_transformers": __version__,
"transformers": transformers.__version__,
"pytorch": torch.__version__,
}

with open(os.path.join(path, "config_sentence_transformers.json"), "w") as fOut:
config = self._model_config.copy()
Expand Down
5 changes: 5 additions & 0 deletions sentence_transformers/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ class SentenceTransformerModelCardData(CardData):
citations: Dict[str, str] = field(default_factory=dict, init=False)
best_model_step: Optional[int] = field(default=None, init=False)
trainer: Optional["SentenceTransformerTrainer"] = field(default=None, init=False, repr=False)
datasets: List[str] = field(default_factory=list, init=False, repr=False)

# Utility fields
first_save: bool = field(default=True, init=False)
Expand Down Expand Up @@ -357,6 +358,10 @@ def validate_datasets(self, dataset_list, infer_languages: bool = True) -> None:
if language not in self.language:
self.language.append(language)

# Track dataset IDs for the metadata
if info.id not in self.datasets:
self.datasets.append(info.id)

output_dataset_list.append(dataset)
return output_dataset_list

Expand Down
11 changes: 11 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,14 @@ def test_model_card_save_update_model_id(stsb_bert_tiny_model: SentenceTransform
with open(Path(tmp_folder) / "README.md", "r", encoding="utf8") as f:
model_card_text = f.read()
assert 'model = SentenceTransformer("test_user/test_model"' in model_card_text


def test_override_config_versions(stsb_bert_tiny_model: SentenceTransformer) -> None:
    """Check that saving a model rewrites the version info stored in its config.

    The fixture model's config reports sentence_transformers version "2.2.2".
    After saving to a temporary directory and loading from it, the reloaded
    model's config must no longer report that version.
    """
    legacy_version = "2.2.2"
    original = stsb_bert_tiny_model

    # Precondition: the fixture still carries the legacy version in its config.
    assert original._model_config["__version__"]["sentence_transformers"] == legacy_version

    with tempfile.TemporaryDirectory() as save_dir:
        original.save(save_dir)
        reloaded = SentenceTransformer(save_dir)
        # The save above should have replaced the stored version, so the
        # reloaded config must differ from the legacy value.
        reloaded_versions = reloaded._model_config["__version__"]
        assert reloaded_versions["sentence_transformers"] != legacy_version

0 comments on commit 3c649c1

Please sign in to comment.