Skip to content

Commit

Permalink
Add tests and separate embed and query instructions (#836)
Browse files Browse the repository at this point in the history
@enoreyes @hwchase17 I have taken a stab at the cleanups. Note that in
my testing the separate query instruction was quite important (i.e.
performance suffered if you didn't separate them) - so I have put that
in to.
  • Loading branch information
seanaedmiston authored Feb 2, 2023
1 parent d9fa5e4 commit 24957ba
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
20 changes: 13 additions & 7 deletions langchain/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from langchain.embeddings.base import Embeddings

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCTION = "Represent the following text:"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
"Represent the question for retrieving supporting documents: "
)


class HuggingFaceEmbeddings(BaseModel, Embeddings):
Expand Down Expand Up @@ -80,15 +84,17 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
.. code-block:: python
from langchain.embeddings import HuggingFaceInstructEmbeddings
model_name = TODO
model_name = "hkunlp/instructor-large"
hf = HuggingFaceInstructEmbeddings(model_name=model_name)
"""

client: Any #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
model_name: str = DEFAULT_INSTRUCT_MODEL
"""Model name to use."""
instruction: str = DEFAULT_INSTRUCTION
"""Instruction to use."""
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
"""Instruction to use for embedding documents."""
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
"""Instruction to use for embedding query."""

def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
Expand Down Expand Up @@ -119,7 +125,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
instruction_pairs = []
for text in texts:
instruction_pairs.append([self.instruction, text])
instruction_pairs.append([self.embed_instruction, text])
embeddings = self.client.encode(instruction_pairs)
return embeddings.tolist()

Expand All @@ -132,6 +138,6 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embeddings for the text.
"""
instruction_pair = [self.instruction, text]
instruction_pair = [self.query_instruction, text]
embedding = self.client.encode([instruction_pair])[0]
return embedding.tolist()
22 changes: 21 additions & 1 deletion tests/integration_tests/embeddings/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Test huggingface embeddings."""
import unittest

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.huggingface import (
HuggingFaceEmbeddings,
HuggingFaceInstructEmbeddings,
)


@unittest.skip("This test causes a segfault.")
Expand All @@ -21,3 +24,20 @@ def test_huggingface_embedding_query() -> None:
embedding = HuggingFaceEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 768


def test_huggingface_instructor_embedding_documents() -> None:
"""Test huggingface embeddings."""
documents = ["foo bar"]
embedding = HuggingFaceInstructEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768


def test_huggingface_instructor_embedding_query() -> None:
"""Test huggingface embeddings."""
query = "foo bar"
embedding = HuggingFaceInstructEmbeddings()
output = embedding.embed_query(query)
assert len(output) == 768

0 comments on commit 24957ba

Please sign in to comment.