-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathSentenceTransformer.py
1866 lines (1633 loc) · 85.1 KB
/
SentenceTransformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import copy
import importlib
import json
import logging
import math
import os
import queue
import shutil
import sys
import tempfile
import traceback
import warnings
from collections import OrderedDict
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from multiprocessing import Queue
from pathlib import Path
from typing import Any, Callable, Literal, overload
import numpy as np
import numpy.typing as npt
import torch
import torch.multiprocessing as mp
import transformers
from huggingface_hub import HfApi
from torch import Tensor, device, nn
from tqdm.autonotebook import trange
from transformers import is_torch_npu_available
from transformers.dynamic_module_utils import get_class_from_dynamic_module, get_relative_import_files
from sentence_transformers.model_card import SentenceTransformerModelCardData, generate_model_card
from sentence_transformers.similarity_functions import SimilarityFunction
from . import __MODEL_HUB_ORGANIZATION__, __version__
from .evaluation import SentenceEvaluator
from .fit_mixin import FitMixin
from .models import Normalize, Pooling, Transformer
from .peft_mixin import PeftAdapterMixin
from .quantization import quantize_embeddings
from .util import (
batch_to_device,
get_device_name,
import_from_string,
is_sentence_transformer_model,
load_dir_path,
load_file_path,
save_to_hub_args_decorator,
truncate_embeddings,
)
# Module-level logger named after this module; used for the info/warning messages emitted below.
logger = logging.getLogger(__name__)
class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin):
"""
Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.
Args:
model_name_or_path (str, optional): If it is a filepath on disc, it loads the model from that path. If it is not a path,
it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model
from the Hugging Face Hub with that name.
modules (Iterable[nn.Module], optional): A list of torch Modules that should be called sequentially, can be used to create custom
SentenceTransformer models from scratch.
device (str, optional): Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU
can be used.
prompts (Dict[str, str], optional): A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text.
The prompt text will be prepended before any text to encode. For example:
`{"query": "query: ", "passage": "passage: "}` or `{"clustering": "Identify the main category based on the
titles in "}`.
default_prompt_name (str, optional): The name of the prompt that should be used by default. If not set,
no prompt will be applied.
similarity_fn_name (str or SimilarityFunction, optional): The name of the similarity function to use. Valid options are "cosine", "dot",
"euclidean", and "manhattan". If not set, it is automatically set to "cosine" if `similarity` or
`similarity_pairwise` are called while `model.similarity_fn_name` is still `None`.
cache_folder (str, optional): Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
trust_remote_code (bool, optional): Whether or not to allow for custom models defined on the Hub in their own modeling files.
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
revision (str, optional): The specific model version to use. It can be a branch name, a tag name, or a commit id,
for a stored model on Hugging Face.
local_files_only (bool, optional): Whether or not to only look at local files (i.e., do not try to download the model).
token (bool or str, optional): Hugging Face authentication token to download private models.
use_auth_token (bool or str, optional): Deprecated argument. Please use `token` instead.
truncate_dim (int, optional): The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
only applicable during inference when :meth:`SentenceTransformer.encode` is called.
model_kwargs (Dict[str, Any], optional): Additional model configuration parameters to be passed to the Hugging Face Transformers model.
Particularly useful options are:
- ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
The different options are:
1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified
``dtype``, ignoring the model's ``config.torch_dtype`` if one exists. If not specified - the model will
get loaded in ``torch.float`` (fp32).
2. ``"auto"`` - A ``torch_dtype`` entry in the ``config.json`` file of the model will be
attempted to be used. If this entry isn't found then next check the ``dtype`` of the first weight in
the checkpoint that's of a floating point type and use that as ``dtype``. This will load the model
using the ``dtype`` it was saved in at the end of the training. It can't be used as an indicator of how
the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
- ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
`"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
<https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"`
implementation.
- ``provider``: If backend is "onnx", this is the provider to use for inference, for example "CPUExecutionProvider",
"CUDAExecutionProvider", etc. See https://onnxruntime.ai/docs/execution-providers/ for all ONNX execution providers.
- ``file_name``: If backend is "onnx" or "openvino", this is the file name to load, useful for loading optimized
or quantized ONNX or OpenVINO models.
- ``export``: If backend is "onnx" or "openvino", then this is a boolean flag specifying whether this model should
be exported to the backend. If not specified, the model will be exported only if the model repository or directory
does not already contain an exported model.
See the `PreTrainedModel.from_pretrained
<https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
documentation for more details.
tokenizer_kwargs (Dict[str, Any], optional): Additional tokenizer configuration parameters to be passed to the Hugging Face Transformers tokenizer.
See the `AutoTokenizer.from_pretrained
<https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
documentation for more details.
config_kwargs (Dict[str, Any], optional): Additional model configuration parameters to be passed to the Hugging Face Transformers config.
See the `AutoConfig.from_pretrained
<https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
documentation for more details.
model_card_data (:class:`~sentence_transformers.model_card.SentenceTransformerModelCardData`, optional): A model
card data object that contains information about the model. This is used to generate a model card when saving
the model. If not set, a default model card data object is created.
backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino".
See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information
on the different backends.
Example:
::
from sentence_transformers import SentenceTransformer
# Load a pre-trained SentenceTransformer model
model = SentenceTransformer('all-mpnet-base-v2')
# Encode some texts
sentences = [
"The weather is lovely today.",
"It's so sunny outside!",
"He drove to the stadium.",
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# (3, 768)
# Get the similarity scores between all sentences
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.6817, 0.0492],
# [0.6817, 1.0000, 0.0421],
# [0.0492, 0.0421, 1.0000]])
"""
def __init__(
    self,
    model_name_or_path: str | None = None,
    modules: Iterable[nn.Module] | None = None,
    device: str | None = None,
    prompts: dict[str, str] | None = None,
    default_prompt_name: str | None = None,
    similarity_fn_name: str | SimilarityFunction | None = None,
    cache_folder: str | None = None,
    trust_remote_code: bool = False,
    revision: str | None = None,
    local_files_only: bool = False,
    token: bool | str | None = None,
    use_auth_token: bool | str | None = None,
    truncate_dim: int | None = None,
    model_kwargs: dict[str, Any] | None = None,
    tokenizer_kwargs: dict[str, Any] | None = None,
    config_kwargs: dict[str, Any] | None = None,
    model_card_data: SentenceTransformerModelCardData | None = None,
    backend: Literal["torch", "onnx", "openvino"] = "torch",
) -> None:
    """Build the model from ``model_name_or_path`` or from an explicit list of ``modules``.

    When ``model_name_or_path`` is a non-empty string, the modules are loaded from it (overriding
    any ``modules`` argument); otherwise the given ``modules`` are used as-is.

    Raises:
        ValueError: If both ``token`` and ``use_auth_token`` are given, if ``model_name_or_path``
            does not exist locally and looks like a filesystem path (contains a backslash or more
            than one "/"), or if ``default_prompt_name`` is not a key of the configured prompts.
    """
    # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
    self.prompts = prompts or {}
    self.default_prompt_name = default_prompt_name
    self.similarity_fn_name = similarity_fn_name
    self.trust_remote_code = trust_remote_code
    self.truncate_dim = truncate_dim
    self.model_card_data = model_card_data or SentenceTransformerModelCardData()
    self.module_kwargs = None
    self._model_card_vars = {}
    self._model_card_text = None
    self._model_config = {}
    self.backend = backend

    # Backwards compatibility: `use_auth_token` is the deprecated spelling of `token`
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4 of SentenceTransformers.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = use_auth_token

    if cache_folder is None:
        cache_folder = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    if device is None:
        device = get_device_name()
        logger.info(f"Use pytorch device_name: {device}")

    if device == "hpu":
        # A bare `import importlib` does not guarantee that the `importlib.util` submodule is
        # loaded, so import it explicitly before using `find_spec`.
        import importlib.util

        if importlib.util.find_spec("optimum") is not None:
            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

            adapt_transformers_to_gaudi()

    if model_name_or_path is not None and model_name_or_path != "":
        logger.info(f"Load pretrained SentenceTransformer: {model_name_or_path}")

        # Old models that don't belong to any organization
        basic_transformer_models = [
            "albert-base-v1",
            "albert-base-v2",
            "albert-large-v1",
            "albert-large-v2",
            "albert-xlarge-v1",
            "albert-xlarge-v2",
            "albert-xxlarge-v1",
            "albert-xxlarge-v2",
            "bert-base-cased-finetuned-mrpc",
            "bert-base-cased",
            "bert-base-chinese",
            "bert-base-german-cased",
            "bert-base-german-dbmdz-cased",
            "bert-base-german-dbmdz-uncased",
            "bert-base-multilingual-cased",
            "bert-base-multilingual-uncased",
            "bert-base-uncased",
            "bert-large-cased-whole-word-masking-finetuned-squad",
            "bert-large-cased-whole-word-masking",
            "bert-large-cased",
            "bert-large-uncased-whole-word-masking-finetuned-squad",
            "bert-large-uncased-whole-word-masking",
            "bert-large-uncased",
            "camembert-base",
            "ctrl",
            "distilbert-base-cased-distilled-squad",
            "distilbert-base-cased",
            "distilbert-base-german-cased",
            "distilbert-base-multilingual-cased",
            "distilbert-base-uncased-distilled-squad",
            "distilbert-base-uncased-finetuned-sst-2-english",
            "distilbert-base-uncased",
            "distilgpt2",
            "distilroberta-base",
            "gpt2-large",
            "gpt2-medium",
            "gpt2-xl",
            "gpt2",
            "openai-gpt",
            "roberta-base-openai-detector",
            "roberta-base",
            "roberta-large-mnli",
            "roberta-large-openai-detector",
            "roberta-large",
            "t5-11b",
            "t5-3b",
            "t5-base",
            "t5-large",
            "t5-small",
            "transfo-xl-wt103",
            "xlm-clm-ende-1024",
            "xlm-clm-enfr-1024",
            "xlm-mlm-100-1280",
            "xlm-mlm-17-1280",
            "xlm-mlm-en-2048",
            "xlm-mlm-ende-1024",
            "xlm-mlm-enfr-1024",
            "xlm-mlm-enro-1024",
            "xlm-mlm-tlm-xnli15-1024",
            "xlm-mlm-xnli15-1024",
            "xlm-roberta-base",
            "xlm-roberta-large-finetuned-conll02-dutch",
            "xlm-roberta-large-finetuned-conll02-spanish",
            "xlm-roberta-large-finetuned-conll03-english",
            "xlm-roberta-large-finetuned-conll03-german",
            "xlm-roberta-large",
            "xlnet-base-cased",
            "xlnet-large-cased",
        ]

        if not os.path.exists(model_name_or_path):
            # Not a path, load from hub
            if "\\" in model_name_or_path or model_name_or_path.count("/") > 1:
                raise ValueError(f"Path {model_name_or_path} not found")

            if "/" not in model_name_or_path and model_name_or_path.lower() not in basic_transformer_models:
                # A model from sentence-transformers
                model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path

        if is_sentence_transformer_model(
            model_name_or_path,
            token,
            cache_folder=cache_folder,
            revision=revision,
            local_files_only=local_files_only,
        ):
            modules, self.module_kwargs = self._load_sbert_model(
                model_name_or_path,
                token=token,
                cache_folder=cache_folder,
                revision=revision,
                trust_remote_code=trust_remote_code,
                local_files_only=local_files_only,
                model_kwargs=model_kwargs,
                tokenizer_kwargs=tokenizer_kwargs,
                config_kwargs=config_kwargs,
            )
        else:
            modules = self._load_auto_model(
                model_name_or_path,
                token=token,
                cache_folder=cache_folder,
                revision=revision,
                trust_remote_code=trust_remote_code,
                local_files_only=local_files_only,
                model_kwargs=model_kwargs,
                tokenizer_kwargs=tokenizer_kwargs,
                config_kwargs=config_kwargs,
            )

    # Name the modules "0", "1", ... unless the caller already supplied an OrderedDict
    if modules is not None and not isinstance(modules, OrderedDict):
        modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])

    super().__init__(modules)

    # Ensure all tensors in the model are of the same dtype as the first tensor
    # This is necessary if the first module has been given a lower precision via
    # model_kwargs["torch_dtype"]. The rest of the model should be loaded in the same dtype
    # See #2887 for more details
    try:
        dtype = next(self.parameters()).dtype
        self.to(dtype)
    except StopIteration:
        # The model has no parameters at all, so there is no dtype to align
        pass

    self.to(device)
    self.is_hpu_graph_enabled = False

    if self.default_prompt_name is not None and self.default_prompt_name not in self.prompts:
        raise ValueError(
            f"Default prompt name '{self.default_prompt_name}' not found in the configured prompts "
            f"dictionary with keys {list(self.prompts.keys())!r}."
        )

    if self.prompts:
        logger.info(f"{len(self.prompts)} prompts are loaded, with the keys: {list(self.prompts.keys())}")
    if self.default_prompt_name:
        logger.warning(
            f"Default prompt name is set to '{self.default_prompt_name}'. "
            "This prompt will be applied to all `encode()` calls, except if `encode()` "
            "is called with `prompt` or `prompt_name` parameters."
        )

    # Ideally, INSTRUCTOR models should set `include_prompt=False` in their pooling configuration, but
    # that would be a breaking change for users currently using the InstructorEmbedding project.
    # So, instead we hardcode setting it for the main INSTRUCTOR models, and otherwise give a warning if we
    # suspect the user is using an INSTRUCTOR model.
    if model_name_or_path in ("hkunlp/instructor-base", "hkunlp/instructor-large", "hkunlp/instructor-xl"):
        self.set_pooling_include_prompt(include_prompt=False)
    elif (
        model_name_or_path
        and "/" in model_name_or_path
        and "instructor" in model_name_or_path.split("/")[1].lower()
    ):
        if any(module.include_prompt for module in self if isinstance(module, Pooling)):
            logger.warning(
                "Instructor models require `include_prompt=False` in the pooling configuration. "
                "Either update the model configuration or call `model.set_pooling_include_prompt(False)` after loading the model."
            )

    # Pass the model to the model card data for later use in generating a model card upon saving this model
    self.model_card_data.register_model(self)
def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
    """Report which inference backend this model is configured with.

    Returns:
        str: The value stored in ``self.backend``; one of "torch", "onnx", or "openvino".
    """
    configured_backend = self.backend
    return configured_backend
# Typing overloads for `encode`. They only describe how the return type depends on the
# input type and on the `convert_to_numpy` / `convert_to_tensor` flags.

# Single string, both conversions disabled -> one torch Tensor.
@overload
def encode(
    self,
    sentences: str,
    prompt_name: str | None = ...,
    prompt: str | None = ...,
    batch_size: int = ...,
    show_progress_bar: bool | None = ...,
    output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
    convert_to_numpy: Literal[False] = ...,
    convert_to_tensor: Literal[False] = ...,
    device: str | None = ...,
    normalize_embeddings: bool = ...,
    **kwargs,
) -> Tensor: ...
# String or list of strings, numpy conversion enabled and tensor conversion disabled -> numpy array.
@overload
def encode(
    self,
    sentences: str | list[str],
    prompt_name: str | None = ...,
    prompt: str | None = ...,
    batch_size: int = ...,
    show_progress_bar: bool | None = ...,
    output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
    convert_to_numpy: Literal[True] = ...,
    convert_to_tensor: Literal[False] = ...,
    device: str | None = ...,
    normalize_embeddings: bool = ...,
    **kwargs,
) -> np.ndarray: ...
# Tensor conversion enabled (regardless of `convert_to_numpy`) -> torch Tensor.
@overload
def encode(
    self,
    sentences: str | list[str],
    prompt_name: str | None = ...,
    prompt: str | None = ...,
    batch_size: int = ...,
    show_progress_bar: bool | None = ...,
    output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
    convert_to_numpy: bool = ...,
    convert_to_tensor: Literal[True] = ...,
    device: str | None = ...,
    normalize_embeddings: bool = ...,
    **kwargs,
) -> Tensor: ...
# List (or array) of inputs, both conversions disabled -> list of torch Tensors.
@overload
def encode(
    self,
    sentences: list[str] | np.ndarray,
    prompt_name: str | None = ...,
    prompt: str | None = ...,
    batch_size: int = ...,
    show_progress_bar: bool | None = ...,
    output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
    convert_to_numpy: Literal[False] = ...,
    convert_to_tensor: Literal[False] = ...,
    device: str | None = ...,
    normalize_embeddings: bool = ...,
    **kwargs,
) -> list[Tensor]: ...
def encode(
    self,
    sentences: str | list[str],
    prompt_name: str | None = None,
    prompt: str | None = None,
    batch_size: int = 32,
    show_progress_bar: bool | None = None,
    output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding",
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
    convert_to_numpy: bool = True,
    convert_to_tensor: bool = False,
    device: str | None = None,
    normalize_embeddings: bool = False,
    **kwargs,
) -> list[Tensor] | np.ndarray | Tensor:
    """
    Computes sentence embeddings.

    Args:
        sentences (Union[str, List[str]]): The sentences to embed.
        prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
            which is either set in the constructor or loaded from the model configuration. For example if
            ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
            is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
            is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
        prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
            sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
            because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
        batch_size (int, optional): The batch size used for the computation. Defaults to 32.
        show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. If None, the bar
            is shown when the logger's effective level is INFO or DEBUG. Defaults to None.
        output_value (Optional[Literal["sentence_embedding", "token_embeddings"]], optional): The type of embeddings to return:
            "sentence_embedding" to get sentence embeddings, "token_embeddings" to get wordpiece token embeddings, and `None`,
            to get all output values. Defaults to "sentence_embedding".
        precision (Literal["float32", "int8", "uint8", "binary", "ubinary"], optional): The precision to use for the embeddings.
            Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions are quantized embeddings.
            Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy. They are useful for
            reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to "float32".
        convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors.
            Defaults to True.
        convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
            Defaults to False.
        device (str, optional): Which :class:`torch.device` to use for the computation. If None, the model's current
            device is used. Defaults to None.
        normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case,
            the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.

    Returns:
        Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
        If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``,
        a torch Tensor is returned instead. If ``self.truncate_dim <= output_dimension`` then output_dimension is ``self.truncate_dim``.

    Raises:
        ValueError: If ``prompt`` is None and ``prompt_name`` is not a key of ``self.prompts``.

    Example:
        ::

            from sentence_transformers import SentenceTransformer

            # Load a pre-trained SentenceTransformer model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Encode some texts
            sentences = [
                "The weather is lovely today.",
                "It's so sunny outside!",
                "He drove to the stadium.",
            ]
            embeddings = model.encode(sentences)
            print(embeddings.shape)
            # (3, 768)
    """
    # On HPU devices, wrap the model in an HPU graph once, the first time encode is called
    if self.device.type == "hpu" and not self.is_hpu_graph_enabled:
        import habana_frameworks.torch as ht

        ht.hpu.wrap_in_hpu_graph(self, disable_tensor_cache=True)
        self.is_hpu_graph_enabled = True

    self.eval()
    if show_progress_bar is None:
        show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)

    if convert_to_tensor:
        convert_to_numpy = False

    # Token embeddings and "all outputs" are returned as lists, never stacked or converted
    if output_value != "sentence_embedding":
        convert_to_tensor = False
        convert_to_numpy = False

    input_was_string = False
    if isinstance(sentences, str) or not hasattr(
        sentences, "__len__"
    ):  # Cast an individual sentence to a list with length 1
        sentences = [sentences]
        input_was_string = True

    # Resolve the prompt: explicit `prompt` > `prompt_name` > the model's default prompt
    if prompt is None:
        if prompt_name is not None:
            try:
                prompt = self.prompts[prompt_name]
            except KeyError as err:
                # Chain the original lookup failure so the cause stays visible in tracebacks
                raise ValueError(
                    f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.prompts.keys())!r}."
                ) from err
        elif self.default_prompt_name is not None:
            prompt = self.prompts.get(self.default_prompt_name, None)
    else:
        if prompt_name is not None:
            logger.warning(
                "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
                "Ignoring the `prompt_name` in favor of `prompt`."
            )

    extra_features = {}
    if prompt is not None:
        sentences = [prompt + sentence for sentence in sentences]

        # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
        # Tracking the prompt length allow us to remove the prompt during pooling
        tokenized_prompt = self.tokenize([prompt])
        if "input_ids" in tokenized_prompt:
            extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1

    if device is None:
        device = self.device

    self.to(device)

    all_embeddings = []
    # Process the inputs from longest to shortest; the original order is restored after the loop
    length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
    sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

    for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
        sentences_batch = sentences_sorted[start_index : start_index + batch_size]
        features = self.tokenize(sentences_batch)
        if self.device.type == "hpu":
            if "input_ids" in features:
                # Pad the sequence dimension up to the next power of two
                curr_tokenize_len = features["input_ids"].shape
                additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
                features["input_ids"] = torch.cat(
                    (
                        features["input_ids"],
                        torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
                    ),
                    -1,
                )
                features["attention_mask"] = torch.cat(
                    (
                        features["attention_mask"],
                        torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
                    ),
                    -1,
                )
                if "token_type_ids" in features:
                    features["token_type_ids"] = torch.cat(
                        (
                            features["token_type_ids"],
                            torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
                        ),
                        -1,
                    )

        features = batch_to_device(features, device)
        features.update(extra_features)

        with torch.no_grad():
            out_features = self.forward(features, **kwargs)
            if self.device.type == "hpu":
                out_features = copy.deepcopy(out_features)

            out_features["sentence_embedding"] = truncate_embeddings(
                out_features["sentence_embedding"], self.truncate_dim
            )

            if output_value == "token_embeddings":
                embeddings = []
                for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
                    # Walk back from the end to drop trailing positions whose attention mask is 0
                    last_mask_id = len(attention) - 1
                    while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                        last_mask_id -= 1

                    embeddings.append(token_emb[0 : last_mask_id + 1])
            elif output_value is None:  # Return all outputs
                embeddings = []
                for sent_idx in range(len(out_features["sentence_embedding"])):
                    row = {name: out_features[name][sent_idx] for name in out_features}
                    embeddings.append(row)
            else:  # Sentence embeddings
                embeddings = out_features[output_value]
                embeddings = embeddings.detach()
                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

                # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
                if convert_to_numpy:
                    embeddings = embeddings.cpu()

            all_embeddings.extend(embeddings)

    # Undo the length-based sort so results line up with the caller's input order
    all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

    if precision and precision != "float32":
        all_embeddings = quantize_embeddings(all_embeddings, precision=precision)

    if convert_to_tensor:
        if len(all_embeddings):
            if isinstance(all_embeddings, np.ndarray):
                all_embeddings = torch.from_numpy(all_embeddings)
            else:
                all_embeddings = torch.stack(all_embeddings)
        else:
            all_embeddings = torch.Tensor()
    elif convert_to_numpy:
        if not isinstance(all_embeddings, np.ndarray):
            if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
                # bfloat16 tensors are upcast to float before the numpy conversion
                all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
            else:
                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
    elif isinstance(all_embeddings, np.ndarray):
        all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]

    if input_was_string:
        all_embeddings = all_embeddings[0]

    return all_embeddings
def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
    """Run the features through every child module in order.

    If ``self.module_kwargs`` is None, the parent class' ``forward`` is used and ``kwargs`` are
    not forwarded. Otherwise each child module only receives the keyword arguments whose names
    are listed for it in ``self.module_kwargs``.

    Args:
        input (dict[str, Tensor]): The features passed to the first module.
        **kwargs: Optional keyword arguments, routed per module via ``self.module_kwargs``.

    Returns:
        dict[str, Tensor]: The output of the last module.
    """
    if self.module_kwargs is None:
        return super().forward(input)

    features = input
    for child_name, child in self.named_children():
        # Names of the keyword arguments this particular module accepts
        accepted_keys = self.module_kwargs.get(child_name, [])
        routed_kwargs = {}
        for kwarg_name, kwarg_value in kwargs.items():
            if kwarg_name in accepted_keys:
                routed_kwargs[kwarg_name] = kwarg_value
        features = child(features, **routed_kwargs)
    return features
@property
def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
    """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.

    Returns:
        str: The name of the similarity function. If no name has been set yet, it is set to
        "cosine" on this first access and that value is returned.

    Example:
        >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        >>> model.similarity_fn_name
        'dot'
    """
    if self._similarity_fn_name is None:
        # Lazily fall back to cosine similarity; this goes through the setter below
        self.similarity_fn_name = SimilarityFunction.COSINE
    return self._similarity_fn_name

@similarity_fn_name.setter
def similarity_fn_name(
    self, value: Literal["cosine", "dot", "euclidean", "manhattan"] | SimilarityFunction | None
) -> None:
    """Set the similarity function by name (or enum member) and refresh the cached callables.

    Passing ``None`` only clears the stored name; the cached similarity callables are left
    untouched.
    """
    if isinstance(value, SimilarityFunction):
        value = value.value

    if value is not None:
        # Resolve both callables before touching any state, so that if a lookup raises, the
        # stored name and the cached functions cannot get out of sync.
        similarity_fn = SimilarityFunction.to_similarity_fn(value)
        similarity_pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn(value)
        self._similarity = similarity_fn
        self._similarity_pairwise = similarity_pairwise_fn

    self._similarity_fn_name = value
@overload
def similarity(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...
@overload
def similarity(self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]) -> Tensor: ...
@property
def similarity(self) -> Callable[[Tensor | npt.NDArray[np.float32], Tensor | npt.NDArray[np.float32]], Tensor]:
    """
    Compute the similarity between two collections of embeddings. The result is a matrix holding the
    similarity score of every embedding in the first collection against every embedding in the second
    collection. This is different from `similarity_pairwise`, which scores each pair of embeddings.
    Only fp32 embeddings are supported; quantized embeddings are not.

    This is a property that returns the similarity callable, so ``model.similarity(a, b)`` first
    fetches the callable and then calls it.

    Args:
        embeddings1 (Union[Tensor, ndarray]): [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
        embeddings2 (Union[Tensor, ndarray]): [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.

    Returns:
        Tensor: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.

    Example:
        ::

            >>> model = SentenceTransformer("all-mpnet-base-v2")
            >>> sentences = [
            ...     "The weather is so nice!",
            ...     "It's so sunny outside.",
            ...     "He's driving to the movie theater.",
            ...     "She's going to the cinema.",
            ... ]
            >>> embeddings = model.encode(sentences, normalize_embeddings=True)
            >>> model.similarity(embeddings, embeddings)
            tensor([[1.0000, 0.7235, 0.0290, 0.1309],
                    [0.7235, 1.0000, 0.0613, 0.1129],
                    [0.0290, 0.0613, 1.0000, 0.5027],
                    [0.1309, 0.1129, 0.5027, 1.0000]])
            >>> model.similarity_fn_name
            'cosine'
            >>> model.similarity_fn_name = "euclidean"
            >>> model.similarity(embeddings, embeddings)
            tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
                    [-0.7437, -0.0000, -1.3702, -1.3320],
                    [-1.3935, -1.3702, -0.0000, -0.9973],
                    [-1.3184, -1.3320, -0.9973, -0.0000]])
    """
    # Make sure a similarity function is configured before handing out the callable
    configured_name = self.similarity_fn_name
    if configured_name is None:
        self.similarity_fn_name = SimilarityFunction.COSINE
    return self._similarity
@overload
def similarity_pairwise(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...
@overload
def similarity_pairwise(
self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
) -> Tensor: ...
@property
def similarity_pairwise(
self,
) -> Callable[[Tensor | npt.NDArray[np.float32], Tensor | npt.NDArray[np.float32]], Tensor]:
"""
Compute the similarity between two collections of embeddings. The output will be a vector with the similarity
scores between each pair of embeddings.
This method supports only embeddings with fp32 precision and does not accommodate quantized embeddings.
Args:
embeddings1 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
embeddings2 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
Returns:
Tensor: A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
Example:
::
>>> model = SentenceTransformer("all-mpnet-base-v2")
>>> sentences = [
... "The weather is so nice!",
... "It's so sunny outside.",
... "He's driving to the movie theater.",
... "She's going to the cinema.",
... ]
>>> embeddings = model.encode(sentences, normalize_embeddings=True)
>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
tensor([0.7235, 0.5027])
>>> model.similarity_fn_name
"cosine"
>>> model.similarity_fn_name = "euclidean"
>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
tensor([-0.7437, -0.9973])
"""
if self.similarity_fn_name is None:
self.similarity_fn_name = SimilarityFunction.COSINE
return self._similarity_pairwise
def start_multi_process_pool(
    self, target_devices: list[str] = None
) -> dict[Literal["input", "output", "processes"], Any]:
    """
    Starts a multi-process pool to process the encoding with several independent processes
    via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.

    This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
    to start only one process per GPU. This method works together with encode_multi_process
    and stop_multi_process_pool.

    Before the workers are spawned the model is moved to the CPU and ``share_memory()`` is
    called on it; every worker is a daemon process created from a "spawn" context.

    Args:
        target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
            ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
            is available, then all available CUDA/NPU devices will be used. If target_devices is None and
            CUDA/NPU is not available, then 4 CPU devices will be used.

    Returns:
        Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
    """
    if target_devices is None:
        # Pick a default device list: all CUDA devices, else all NPU devices, else 4 CPU workers.
        if torch.cuda.is_available():
            target_devices = [f"cuda:{index}" for index in range(torch.cuda.device_count())]
        elif is_torch_npu_available():
            target_devices = [f"npu:{index}" for index in range(torch.npu.device_count())]
        else:
            logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
            target_devices = ["cpu"] * 4

    device_summary = ", ".join(map(str, target_devices))
    logger.info("Start multi-process pool on devices: {}".format(device_summary))

    self.to("cpu")
    self.share_memory()

    spawn_ctx = mp.get_context("spawn")
    input_queue = spawn_ctx.Queue()
    output_queue = spawn_ctx.Queue()

    workers = []
    for device in target_devices:
        worker = spawn_ctx.Process(
            target=SentenceTransformer._encode_multi_process_worker,
            args=(device, self, input_queue, output_queue),
            daemon=True,
        )
        worker.start()
        workers.append(worker)

    return {"input": input_queue, "output": output_queue, "processes": workers}
@staticmethod
def stop_multi_process_pool(pool: dict[Literal["input", "output", "processes"], Any]) -> None:
"""
Stops all processes started with start_multi_process_pool.
Args:
pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
Returns:
None
"""
for p in pool["processes"]:
p.terminate()
for p in pool["processes"]:
p.join()
p.close()
pool["input"].close()
pool["output"].close()
def encode_multi_process(
self,
sentences: list[str],
pool: dict[Literal["input", "output", "processes"], Any],
prompt_name: str | None = None,
prompt: str | None = None,
batch_size: int = 32,
chunk_size: int = None,
show_progress_bar: bool | None = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
normalize_embeddings: bool = False,
) -> np.ndarray:
"""
Encodes a list of sentences using multiple processes and GPUs via
:meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
The sentences are chunked into smaller packages and sent to individual processes, which encode them on different
GPUs or CPUs. This method is only suitable for encoding large sets of sentences.
Args:
sentences (List[str]): List of sentences to encode.
pool (Dict[Literal["input", "output", "processes"], Any]): A pool of workers started with
:meth:`SentenceTransformer.start_multi_process_pool <sentence_transformers.SentenceTransformer.start_multi_process_pool>`.
prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
which is either set in the constructor or loaded from the model configuration. For example if
``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
batch_size (int): Encode sentences with batch size. (default: 32)
chunk_size (int): Sentences are chunked and sent to the individual processes. If None, it determines a
sensible size. Defaults to None.
show_progress_bar (bool, optional): Whether to output a progress bar when encode sentences. Defaults to None.
precision (Literal["float32", "int8", "uint8", "binary", "ubinary"]): The precision to use for the
embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions
are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may
have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for
semantic search, among other tasks. Defaults to "float32".
normalize_embeddings (bool): Whether to normalize returned vectors to have length 1. In that case,
the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
Returns:
np.ndarray: A 2D numpy array with shape [num_inputs, output_dimension].
Example:
::
from sentence_transformers import SentenceTransformer
def main():
model = SentenceTransformer("all-mpnet-base-v2")
sentences = ["The weather is so nice!", "It's so sunny outside.", "He's driving to the movie theater.", "She's going to the cinema."] * 1000
pool = model.start_multi_process_pool()
embeddings = model.encode_multi_process(sentences, pool)
model.stop_multi_process_pool(pool)
print(embeddings.shape)
# => (4000, 768)
if __name__ == "__main__":
main()
"""
if chunk_size is None:
chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000)
if show_progress_bar is None:
show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")
input_queue = pool["input"]
last_chunk_id = 0
chunk = []
for sentence in sentences:
chunk.append(sentence)
if len(chunk) >= chunk_size:
input_queue.put(
[last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings]
)
last_chunk_id += 1
chunk = []
if len(chunk) > 0:
input_queue.put([last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings])
last_chunk_id += 1
output_queue = pool["output"]
results_list = sorted(
[output_queue.get() for _ in trange(last_chunk_id, desc="Chunks", disable=not show_progress_bar)],
key=lambda x: x[0],
)
embeddings = np.concatenate([result[1] for result in results_list])
return embeddings
@staticmethod
def _encode_multi_process_worker(
target_device: str, model: SentenceTransformer, input_queue: Queue, results_queue: Queue
) -> None:
"""
Internal working process to encode sentences in multi-process setup
"""
while True:
try:
chunk_id, batch_size, sentences, prompt_name, prompt, precision, normalize_embeddings = (
input_queue.get()
)
embeddings = model.encode(
sentences,