Skip to content

Commit

Permalink
Add instructor embedding and update Mistral embedding. (#30)
Browse files Browse the repository at this point in the history
Signed-off-by: zilliz <xy.wang@zilliz.com>
Signed-off-by: wxywb <xy.wang@zilliz.com>
  • Loading branch information
wxywb authored Aug 15, 2024
1 parent a450490 commit 9fe20bf
Show file tree
Hide file tree
Showing 9 changed files with 813 additions and 12 deletions.
6 changes: 6 additions & 0 deletions milvus_model/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"CohereEmbeddingFunction",
"MistralAIEmbeddingFunction",
"NomicEmbeddingFunction",
"InstructorEmbeddingFunction"
]

from milvus_model.utils.lazy_import import LazyImport
Expand All @@ -21,6 +22,7 @@
cohere = LazyImport("cohere", globals(), "milvus_model.dense.cohere")
mistralai = LazyImport("mistralai", globals(), "milvus_model.dense.mistralai")
nomic = LazyImport("nomic", globals(), "milvus_model.dense.nomic")
instructor = LazyImport("instructor", globals(), "milvus_model.dense.instructor")


def JinaEmbeddingFunction(*args, **kwargs):
Expand Down Expand Up @@ -53,3 +55,7 @@ def MistralAIEmbeddingFunction(*args, **kwargs):

def NomicEmbeddingFunction(*args, **kwargs):
    """Proxy to the lazily imported ``nomic.NomicEmbeddingFunction``."""
    target = nomic.NomicEmbeddingFunction
    return target(*args, **kwargs)


def InstructorEmbeddingFunction(*args, **kwargs):
    """Proxy to the lazily imported ``instructor.InstructorEmbeddingFunction``."""
    target = instructor.InstructorEmbeddingFunction
    return target(*args, **kwargs)
77 changes: 77 additions & 0 deletions milvus_model/dense/instructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import List, Optional
import struct
from collections import defaultdict
import numpy as np

from milvus_model.base import BaseEmbeddingFunction
from milvus_model.utils import import_sentence_transformers, import_huggingface_hub

import_sentence_transformers()
import_huggingface_hub()

from .instructor_embedding.instructor_impl import Instructor

class InstructorEmbeddingFunction(BaseEmbeddingFunction):
    """Dense embedding function backed by an Instructor model.

    Instructor models take each input as an ``[instruction, text]`` pair,
    where the instruction steers the embedding toward a task (retrieval
    query vs. document here).
    """

    def __init__(
        self,
        model_name: str = "hkunlp/instructor-xl",
        batch_size: int = 32,
        query_instruction: str = "Represent the question for retrieval:",
        doc_instruction: str = "Represent the document for retrieval:",
        device: str = "cpu",
        normalize_embeddings: bool = True,
        **kwargs,
    ):
        """Load the Instructor model.

        Args:
            model_name: HuggingFace model name or local path.
            batch_size: Batch size used for encoding.
            query_instruction: Instruction prepended (as a pair) to queries.
            doc_instruction: Instruction prepended (as a pair) to documents.
            device: Torch device the model runs on (e.g. "cpu", "cuda").
            normalize_embeddings: Whether to L2-normalize output vectors.
            **kwargs: Extra keyword arguments forwarded to ``Instructor``.
        """
        self.model_name = model_name
        self.query_instruction = query_instruction
        self.doc_instruction = doc_instruction
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

        # kwargs may override model_name_or_path/device if the caller insists.
        _model_config = dict({"model_name_or_path": model_name, "device": device}, **kwargs)
        self.model = Instructor(**_model_config)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Embed *texts* as documents (applies the document instruction)."""
        return self._encode([[self.doc_instruction, text] for text in texts])

    def _encode(self, instructed_texts: List[List[str]]) -> List[np.ndarray]:
        """Encode pre-paired ``[instruction, text]`` inputs in batches."""
        # Bug fix: normalize_embeddings was previously dropped here, so the
        # batched paths (__call__/encode_queries/encode_documents) ignored the
        # constructor flag while the single-item helpers honored it.
        embs = self.model.encode(
            instructed_texts,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
        )
        return list(embs)

    @property
    def dim(self) -> int:
        """Dimensionality of the produced embedding vectors."""
        return self.model.get_sentence_embedding_dimension()

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed *queries* with the query instruction applied."""
        instructed_queries = [[self.query_instruction, query] for query in queries]
        return self._encode(instructed_queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed *documents* with the document instruction applied."""
        instructed_documents = [[self.doc_instruction, document] for document in documents]
        return self._encode(instructed_documents)

    def _encode_query(self, query: str) -> np.ndarray:
        # Consistency fix: previously concatenated instruction+query into one
        # string; every other path uses the [instruction, text] pair format
        # Instructor models expect, so delegate to _encode.
        return self._encode([[self.query_instruction, query]])[0]

    def _encode_document(self, document: str) -> np.ndarray:
        # Same pair-format consistency fix as _encode_query.
        return self._encode([[self.doc_instruction, document]])[0]


Loading

0 comments on commit 9fe20bf

Please sign in to comment.