Added alternative embedding models for sentence transformers and OpenAI #101

Merged 1 commit on Sep 1, 2023
16 changes: 12 additions & 4 deletions nemoguardrails/actions/llm/generation.py
@@ -44,6 +44,7 @@
 from nemoguardrails.kb.kb import KnowledgeBase
 from nemoguardrails.language.parser import parse_colang_file
 from nemoguardrails.llm.params import llm_params
+from nemoguardrails.llm.providers import get_embedding_provider_names
 from nemoguardrails.llm.taskmanager import LLMTaskManager
 from nemoguardrails.llm.types import Task
 from nemoguardrails.rails.llm.config import RailsConfig
@@ -71,7 +72,8 @@ def __init__(
         for model in self.config.models:
             if model.type == "embedding":
                 self.embedding_model = model.model
-                assert model.engine == "SentenceTransformer"
+                assert model.engine in get_embedding_provider_names()
+                self.embedding_engine = model.engine
                 break
 
         # If we have user messages, we build an index with them
@@ -108,7 +110,9 @@ def _init_user_message_index(self):
         if len(items) == 0:
             return
 
-        self.user_message_index = BasicEmbeddingsIndex(self.embedding_model)
+        self.user_message_index = BasicEmbeddingsIndex(
+            embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
+        )
         self.user_message_index.add_items(items)
 
         # NOTE: this should be very fast, otherwise needs to be moved to separate thread.
@@ -129,7 +133,9 @@ def _init_bot_message_index(self):
         if len(items) == 0:
             return
 
-        self.bot_message_index = BasicEmbeddingsIndex(self.embedding_model)
+        self.bot_message_index = BasicEmbeddingsIndex(
+            embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
+        )
         self.bot_message_index.add_items(items)
 
         # NOTE: this should be very fast, otherwise needs to be moved to separate thread.
@@ -163,7 +169,9 @@ def _init_flows_index(self):
         if len(items) == 0:
             return
 
-        self.flows_index = BasicEmbeddingsIndex(self.embedding_model)
+        self.flows_index = BasicEmbeddingsIndex(
+            embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
+        )
         self.flows_index.add_items(items)
 
         # NOTE: this should be very fast, otherwise needs to be moved to separate thread.
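For context, the loop in __init__ above picks up the first model entry with type "embedding" from the rails configuration. A minimal sketch of a configuration that exercises the new path, assuming the standard config.yml schema and the RailsConfig.from_content helper (model names are illustrative; running it also requires OPENAI_API_KEY for the main model):

from nemoguardrails import LLMRails, RailsConfig

# One main LLM plus a separate embedding model. The embedding engine must
# be one of get_embedding_provider_names(), i.e. "openai" or
# "SentenceTransformers".
yaml_content = """
models:
  - type: main
    engine: openai
    model: text-davinci-003
  - type: embedding
    engine: SentenceTransformers
    model: all-MiniLM-L6-v2
"""

config = RailsConfig.from_content(yaml_content=yaml_content)
rails = LLMRails(config)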
52 changes: 48 additions & 4 deletions nemoguardrails/kb/basic.py
@@ -15,10 +15,12 @@
 
 from typing import List
 
+import openai
 from annoy import AnnoyIndex
 from sentence_transformers import SentenceTransformer
+from torch import cuda
 
-from nemoguardrails.kb.index import EmbeddingsIndex, IndexItem
+from nemoguardrails.kb.index import EmbeddingModel, EmbeddingsIndex, IndexItem
 
 
 class BasicEmbeddingsIndex(EmbeddingsIndex):
@@ -28,11 +30,12 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
     It uses Annoy to perform the search.
     """
 
-    def __init__(self, embedding_model=None, index=None):
+    def __init__(self, embedding_model=None, embedding_engine=None, index=None):
         self._model = None
         self._items = []
         self._embeddings = []
         self.embedding_model = embedding_model
+        self.embedding_engine = embedding_engine
 
         # When the index is provided, it means it's from the cache.
         self._index = index
@@ -43,15 +46,17 @@ def embeddings_index(self):
 
     def _init_model(self):
         """Initialize the model used for computing the embeddings."""
-        self._model = SentenceTransformer(self.embedding_model)
+        self._model = init_embedding_model(
+            embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
+        )
 
     def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Compute embeddings for a list of texts."""
         if self._model is None:
             self._init_model()
 
         embeddings = self._model.encode(texts)
-        return [embedding.tolist() for embedding in embeddings]
+        return embeddings
 
     def add_item(self, item: IndexItem):
         """Add a single item to the index."""
@@ -85,3 +90,42 @@ def search(self, text: str, max_results: int = 20) -> List[IndexItem]:
         )
 
         return [self._items[i] for i in results]
+
+
+class SentenceTransformerEmbeddingModel(EmbeddingModel):
+    """Embedding model using sentence-transformers."""
+
+    def __init__(self, embedding_model: str):
+        device = "cuda" if cuda.is_available() else "cpu"
+        self.model = SentenceTransformer(embedding_model, device=device)
+        # Get the embedding dimension of the model
+        self.embedding_size = self.model.get_sentence_embedding_dimension()
+
+    def encode(self, documents: List[str]) -> List[List[float]]:
+        return self.model.encode(documents)
+
+
+class OpenAIEmbeddingModel(EmbeddingModel):
+    """Embedding model using OpenAI API."""
+
+    def __init__(self, embedding_model: str):
+        self.model = embedding_model
+        self.embedding_size = len(self.encode(["test"])[0])
+
+    def encode(self, documents: List[str]) -> List[List[float]]:
+        """Encode a list of documents into embeddings."""
+
+        # Make embedding request to OpenAI API
+        res = openai.Embedding.create(input=documents, engine=self.model)
+        embeddings = [record["embedding"] for record in res["data"]]
+        return embeddings
+
+
+def init_embedding_model(embedding_model: str, embedding_engine: str) -> EmbeddingModel:
+    """Initialize the embedding model."""
+    if embedding_engine == "SentenceTransformers":
+        return SentenceTransformerEmbeddingModel(embedding_model)
+    elif embedding_engine == "openai":
+        return OpenAIEmbeddingModel(embedding_model)
+    else:
+        raise ValueError(f"Invalid embedding engine: {embedding_engine}")
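The factory above can also be used standalone; a quick sketch (the model name is illustrative, and the "openai" path additionally requires OPENAI_API_KEY to be set):

from nemoguardrails.kb.basic import init_embedding_model

# Select the backend by engine name; "openai" would return an
# OpenAIEmbeddingModel instead.
model = init_embedding_model(
    embedding_model="all-MiniLM-L6-v2",
    embedding_engine="SentenceTransformers",
)

embeddings = model.encode(["NeMo Guardrails adds programmable rails."])
print(model.embedding_size)  # 384 for all-MiniLM-L6-v2
print(len(embeddings[0]))    # same dimension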
8 changes: 8 additions & 0 deletions nemoguardrails/kb/index.py
@@ -43,3 +43,11 @@ def build(self):
     def search(self, text: str, max_results: int) -> List[IndexItem]:
         """Searches the index for the closes matches to the provided text."""
         raise NotImplementedError()
+
+
+class EmbeddingModel:
+    """The embedding model is responsible for creating the embeddings."""
+
+    def encode(self, documents: List[str]) -> List[List[float]]:
+        """Encode the provided documents into embeddings."""
+        raise NotImplementedError()
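The new base class is intentionally minimal: a backend only needs an encode() method plus an embedding_size attribute (which KnowledgeBase.build() reads when re-loading a cached Annoy index). A hypothetical extra backend, sketched purely for illustration:

from typing import List

from nemoguardrails.kb.index import EmbeddingModel


class HashedBagOfWordsModel(EmbeddingModel):
    """Hypothetical toy backend: fixed-size hashed bag-of-words vectors."""

    def __init__(self, embedding_size: int = 128):
        self.embedding_size = embedding_size

    def encode(self, documents: List[str]) -> List[List[float]]:
        vectors = []
        for doc in documents:
            vec = [0.0] * self.embedding_size
            # Accumulate token counts into hashed buckets.
            for token in doc.lower().split():
                vec[hash(token) % self.embedding_size] += 1.0
            vectors.append(vec)
        return vectors

Wiring such a backend in would also require a new branch in init_embedding_model and an entry in get_embedding_provider_names().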
23 changes: 18 additions & 5 deletions nemoguardrails/kb/kb.py
@@ -21,7 +21,7 @@
 
 from annoy import AnnoyIndex
 
-from nemoguardrails.kb.basic import BasicEmbeddingsIndex
+from nemoguardrails.kb.basic import BasicEmbeddingsIndex, init_embedding_model
 from nemoguardrails.kb.index import IndexItem
 from nemoguardrails.kb.utils import split_markdown_in_topic_chunks
 
@@ -33,11 +33,14 @@
 class KnowledgeBase:
     """Basic implementation of a knowledge base."""
 
-    def __init__(self, documents: List[str], embedding_model: str):
+    def __init__(
+        self, documents: List[str], embedding_model: str, embedding_engine: str
+    ):
         self.documents = documents
         self.chunks = []
         self.index = None
         self.embedding_model = embedding_model
+        self.embedding_engine = embedding_engine
 
     def init(self):
         """Initialize the knowledge base.
@@ -76,16 +79,26 @@ def build(self):
         # If we have already computed this before, we use it
         if os.path.exists(cache_file):
-            # TODO: this should not be hardcoded. Currently set for all-MiniLM-L6-v2.
-            embedding_size = 384
+            # Get embedding size from model
+            model = init_embedding_model(
+                embedding_model=self.embedding_model,
+                embedding_engine=self.embedding_engine,
+            )
+            embedding_size = model.embedding_size
             ann_index = AnnoyIndex(embedding_size, "angular")
             ann_index.load(cache_file)
 
             self.index = BasicEmbeddingsIndex(
-                embedding_model=self.embedding_model, index=ann_index
+                embedding_model=self.embedding_model,
+                embedding_engine=self.embedding_engine,
+                index=ann_index,
             )
             self.index.add_items(index_items)
         else:
-            self.index = BasicEmbeddingsIndex(self.embedding_model)
+            self.index = BasicEmbeddingsIndex(
+                embedding_model=self.embedding_model,
+                embedding_engine=self.embedding_engine,
+            )
             self.index.add_items(index_items)
             self.index.build()
 
5 changes: 5 additions & 0 deletions nemoguardrails/llm/providers.py
@@ -80,3 +80,8 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
 def get_llm_provider_names() -> List[str]:
     """Returns the list of supported LLM providers."""
     return list(sorted(list(_providers.keys())))
+
+
+def get_embedding_provider_names() -> List[str]:
+    """Returns the list of supported embedding providers."""
+    return ["openai", "SentenceTransformers"]
74 changes: 41 additions & 33 deletions nemoguardrails/rails/llm/llmrails.py
@@ -27,7 +27,11 @@
 from nemoguardrails.actions.llm.utils import get_colang_history
 from nemoguardrails.flows.runtime import Runtime
 from nemoguardrails.language.parser import parse_colang_file
-from nemoguardrails.llm.providers import get_llm_provider, get_llm_provider_names
+from nemoguardrails.llm.providers import (
+    get_embedding_provider_names,
+    get_llm_provider,
+    get_llm_provider_names,
+)
 from nemoguardrails.logging.stats import llm_stats
 from nemoguardrails.rails.llm.config import RailsConfig
 from nemoguardrails.rails.llm.utils import get_history_cache_key
@@ -120,39 +124,43 @@ def _init_llms(self):
         # to search for the main model config.
 
         for llm_config in self.config.models:
-            if llm_config.engine not in get_llm_provider_names():
-                raise Exception(f"Unknown LLM engine: {llm_config.engine}")
-
-            provider_cls = get_llm_provider(llm_config)
-            # We need to compute the kwargs for initializing the LLM
-            kwargs = llm_config.parameters
-
-            # We also need to pass the model, if specified
-            if llm_config.model:
-                # Some LLM providers use `model_name` instead of model. For backward compatibility
-                # we keep this hard-coded mapping.
-                if llm_config.engine in [
-                    "azure",
-                    "openai",
-                    "gooseai",
-                    "nlpcloud",
-                    "petals",
-                ]:
-                    kwargs["model_name"] = llm_config.model
-                else:
-                    # The `__fields__` attribute is computed dynamically by pydantic.
-                    if "model" in provider_cls.__fields__:
-                        kwargs["model"] = llm_config.model
-
-            if llm_config.type == "main" or len(self.config.models) == 1:
-                self.llm = provider_cls(**kwargs)
-                self.runtime.register_action_param("llm", self.llm)
+            if llm_config.type == "embedding":
+                if llm_config.engine not in get_embedding_provider_names():
+                    raise Exception(f"Unknown embedding engine: {llm_config.engine}")
             else:
-                model_name = f"{llm_config.type}_llm"
-                setattr(self, model_name, provider_cls(**kwargs))
-                self.runtime.register_action_param(
-                    model_name, getattr(self, model_name)
-                )
+                if llm_config.engine not in get_llm_provider_names():
+                    raise Exception(f"Unknown LLM engine: {llm_config.engine}")
+
+                provider_cls = get_llm_provider(llm_config)
+                # We need to compute the kwargs for initializing the LLM
+                kwargs = llm_config.parameters
+
+                # We also need to pass the model, if specified
+                if llm_config.model:
+                    # Some LLM providers use `model_name` instead of model. For backward compatibility
+                    # we keep this hard-coded mapping.
+                    if llm_config.engine in [
+                        "azure",
+                        "openai",
+                        "gooseai",
+                        "nlpcloud",
+                        "petals",
+                    ]:
+                        kwargs["model_name"] = llm_config.model
+                    else:
+                        # The `__fields__` attribute is computed dynamically by pydantic.
+                        if "model" in provider_cls.__fields__:
+                            kwargs["model"] = llm_config.model
+
+                if llm_config.type == "main" or len(self.config.models) == 1:
+                    self.llm = provider_cls(**kwargs)
+                    self.runtime.register_action_param("llm", self.llm)
+                else:
+                    model_name = f"{llm_config.type}_llm"
+                    setattr(self, model_name, provider_cls(**kwargs))
+                    self.runtime.register_action_param(
+                        model_name, getattr(self, model_name)
+                    )
 
     def _get_events_for_messages(self, messages: List[dict]):
         """Return the list of events corresponding to the provided messages.