Skip to content

Commit

Permalink
Merge #54: Fix UTF-8 encoding issue and add embedding model configuration.
Browse files Browse the repository at this point in the history
  • Loading branch information
drazvan committed Jun 30, 2023
2 parents 0669a61 + 4d3bdf0 commit fa3d65e
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
18 changes: 14 additions & 4 deletions nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ def __init__(

self.llm_task_manager = llm_task_manager

# If we have a customized embedding model, we'll use it.
self.embedding_model = "all-MiniLM-L6-v2"
for model in self.config.models:
if "embedding" in model.type:
self.embedding_model = model.model
assert model.engine == "SentenceTransformer"
break

def _init_user_message_index(self):
"""Initializes the index of user messages."""

Expand All @@ -93,7 +101,7 @@ def _init_user_message_index(self):
if len(items) == 0:
return

self.user_message_index = BasicEmbeddingsIndex()
self.user_message_index = BasicEmbeddingsIndex(self.embedding_model)
self.user_message_index.add_items(items)

# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
Expand All @@ -114,7 +122,7 @@ def _init_bot_message_index(self):
if len(items) == 0:
return

self.bot_message_index = BasicEmbeddingsIndex()
self.bot_message_index = BasicEmbeddingsIndex(self.embedding_model)
self.bot_message_index.add_items(items)

# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
Expand Down Expand Up @@ -148,7 +156,7 @@ def _init_flows_index(self):
if len(items) == 0:
return

self.flows_index = BasicEmbeddingsIndex()
self.flows_index = BasicEmbeddingsIndex(self.embedding_model)
self.flows_index.add_items(items)

# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
Expand All @@ -161,7 +169,9 @@ def _init_kb(self):
return

documents = [doc.content for doc in self.config.docs]
self.kb = KnowledgeBase(documents=documents)
self.kb = KnowledgeBase(
documents=documents, embedding_model=self.embedding_model
)
self.kb.init()
self.kb.build()

Expand Down
5 changes: 3 additions & 2 deletions nemoguardrails/kb/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
It uses Annoy to perform the search.
"""

def __init__(self, index=None):
def __init__(self, embedding_model=None, index=None):
self._model = None
self._items = []
self._embeddings = []
self.embedding_model = embedding_model

# When the index is provided, it means it's from the cache.
self._index = index
Expand All @@ -42,7 +43,7 @@ def embeddings_index(self):

def _init_model(self):
"""Initialize the model used for computing the embeddings."""
self._model = SentenceTransformer("all-MiniLM-L6-v2")
self._model = SentenceTransformer(self.embedding_model)

def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Compute embeddings for a list of texts."""
Expand Down
9 changes: 6 additions & 3 deletions nemoguardrails/kb/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@
class KnowledgeBase:
"""Basic implementation of a knowledge base."""

def __init__(self, documents: List[str]):
def __init__(self, documents: List[str], embedding_model: str):
self.documents = documents
self.chunks = []
self.index = None
self.embedding_model = embedding_model

def init(self):
"""Initialize the knowledge base.
Expand Down Expand Up @@ -79,10 +80,12 @@ def build(self):
ann_index = AnnoyIndex(embedding_size, "angular")
ann_index.load(cache_file)

self.index = BasicEmbeddingsIndex(index=ann_index)
self.index = BasicEmbeddingsIndex(
embedding_model=self.embedding_model, index=ann_index
)
self.index.add_items(index_items)
else:
self.index = BasicEmbeddingsIndex()
self.index = BasicEmbeddingsIndex(self.embedding_model)
self.index.add_items(index_items)
self.index.build()

Expand Down
4 changes: 2 additions & 2 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,11 @@ def from_path(
)

elif file.endswith(".yml") or file.endswith(".yaml"):
with open(full_path) as f:
with open(full_path, "r", encoding="utf-8") as f:
_raw_config = yaml.safe_load(f.read())

elif file.endswith(".co"):
with open(full_path) as f:
with open(full_path, "r", encoding="utf-8") as f:
_raw_config = parse_colang_file(file, content=f.read())

# Extract test set if needed before adding the _raw_config to the app config in raw_config
Expand Down

0 comments on commit fa3d65e

Please sign in to comment.