
Implement cache embeddings (resolves #200) #208

Merged
merged 16 commits into NVIDIA:develop from feature/200/embeddings/cache on Feb 15, 2024

Conversation

@Pouyanpi (Collaborator) commented Dec 4, 2023

Summary

This PR introduces a caching mechanism for embeddings in the BasicEmbeddingsIndex class. A cache_embeddings decorator caches the results of the _get_embeddings method. Key generation and cache storage are implemented behind abstract base classes to allow for extensibility.

Changes

  • Added cache_embeddings decorator to cache the results of the _get_embeddings method.
  • Made _get_embeddings asynchronous.
  • Introduced KeyGenerator and CacheStore abstract base classes for key generation and cache storage; the base classes also play the role of factories (see the sketch after this list).
  • Implemented HashKeyGenerator, MD5KeyGenerator, InMemoryCacheStore, FilesystemCacheStore, and RedisCacheStore as concrete implementations of the abstract base classes.
  • Updated BasicEmbeddingsIndex to use the caching mechanism.
  • Implemented EmbeddingsCacheConfig
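
A minimal sketch of the shape these abstractions take; the method names (generate_key, get, set) are illustrative assumptions, not necessarily the exact API:

import hashlib
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class KeyGenerator(ABC):
    """Generates a cache key for a piece of text."""

    @abstractmethod
    def generate_key(self, text: str) -> str:
        ...


class MD5KeyGenerator(KeyGenerator):
    def generate_key(self, text: str) -> str:
        return hashlib.md5(text.encode("utf-8")).hexdigest()


class CacheStore(ABC):
    """Stores and retrieves cached embeddings by key."""

    @abstractmethod
    def get(self, key: str) -> Optional[List[float]]:
        ...

    @abstractmethod
    def set(self, key: str, value: List[float]) -> None:
        ...


class InMemoryCacheStore(CacheStore):
    def __init__(self):
        self._cache: Dict[str, List[float]] = {}

    def get(self, key: str) -> Optional[List[float]]:
        return self._cache.get(key)

    def set(self, key: str, value: List[float]) -> None:
        self._cache[key] = value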

Limitations

  • Cache eviction policy is not implemented.
  • The cache key is generated from the text content alone, so switching between embedding engines/models can produce false cache hits: a text embedded earlier with a different model is served from the cache.
  • The OpenAIEmbedding model does not return List[List[float]] but rather …

Future Work

  • Implement a cache eviction policy.
  • Update the key generation to include the model name and engine to avoid cache hits across different models (easy; see the sketch below).
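
As a quick illustration of that fix, the key can hash the engine and model together with the text. make_cache_key is a hypothetical helper, not part of this PR:

import hashlib

def make_cache_key(text: str, engine: str, model: str) -> str:
    # Different engine/model combinations yield different keys for the same text.
    payload = f"{engine}:{model}:{text}"
    return hashlib.md5(payload.encode("utf-8")).hexdigest()

assert make_cache_key("hi", "openai", "text-embedding-ada-002") != \
    make_cache_key("hi", "FastEmbed", "all-MiniLM-L6-v2")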

Remarks

An object that uses embedding caching must instantiate these two objects, a KeyGenerator and a CacheStore (otherwise default values are used). The cache_embeddings decorator can be applied to any method that accepts a list of texts and returns a list of embeddings, so it could have been applied to EmbeddingModel as well; it seems more relevant in BasicEmbeddingsIndex, however, since that class is instantiated from the configs within the project. Furthermore, the CacheStore and KeyGenerator could be defined within the embedding config, and the constructor could also accept a boolean to define whether to use caching at all. A sketch of the decorator follows.
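
A rough sketch of such a decorator, including the selective lookup that preserves the provider's bulk-computation behavior; the cache_config, key_generator, and cache_store attributes assumed on self are illustrative, not the exact implementation:

import functools
from typing import Callable, Dict, List


def cache_embeddings(func: Callable) -> Callable:
    @functools.wraps(func)
    async def wrapper(self, texts: List[str]) -> List[List[float]]:
        # Fall through to the wrapped method when caching is disabled.
        if not getattr(self, "cache_config", None) or not self.cache_config.enabled:
            return await func(self, texts)

        keys = [self.key_generator.generate_key(t) for t in texts]
        cached: Dict[str, List[float]] = {}
        for key in keys:
            value = self.cache_store.get(key)
            if value is not None:
                cached[key] = value

        # Forward only the cache misses to the wrapped method, in one bulk call.
        misses = [t for t, k in zip(texts, keys) if k not in cached]
        if misses:
            for text, embedding in zip(misses, await func(self, misses)):
                key = self.key_generator.generate_key(text)
                self.cache_store.set(key, embedding)
                cached[key] = embedding

        return [cached[k] for k in keys]

    return wrapper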

Issue

This PR resolves issue #200.

Example Usage

Using the openai embedding engine

from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

# Example usage with OpenAI
embedding_index = BasicEmbeddingsIndex(
    embedding_engine="openai",
    embedding_model="text-embedding-ada-002",
    cache_config={"enabled": True, "store": "filesystem"},
)

embeddings = await embedding_index._get_embeddings(
    [
        "this is a text to test",
        "another text to test",
    ]
)

print(embeddings)

# Running it a second time fetches the embeddings from the filesystem cache
embeddings = await embedding_index._get_embeddings(
    [
        "this is a text to test",
        "another text to test",
    ]
)

print(embeddings)
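
These snippets use top-level await, which works in a notebook or async REPL; in a plain script the calls would be wrapped in a coroutine, e.g.:

import asyncio

async def main():
    # Reuses the embedding_index created above.
    embeddings = await embedding_index._get_embeddings(["this is a text to test"])
    print(embeddings)

asyncio.run(main())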

Using the FastEmbed embedding engine

from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

embedding_index = BasicEmbeddingsIndex(
    embedding_engine="FastEmbed",
    embedding_model="all-MiniLM-L6-v2",
    cache_config = {"enabled": True},
)


# Removing "MODIFIED" would produce cache hits from the earlier run,
# even with a different model (the limitation discussed above)
embeddings = await embedding_index._get_embeddings(
    [
        "MODIFIED this is a text to test",
        "MODIFIED another text to test",
    ]
)

print(embeddings)

Using the in-memory cache store

embedding_index = BasicEmbeddingsIndex(
    embedding_engine="FastEmbed",
    embedding_model="all-MiniLM-L6-v2",
    cache_config={"enabled": True, "store": "in_memory"}

)

embeddings = await embedding_index._get_embeddings(
    [
        "MODIFIED this is a text to test",
        "MODIFIED another text to test",
    ]
)

print(embeddings)
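
Note that the in-memory store lives only for the duration of the process, so unlike the filesystem store, cache hits occur only within the same run.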

@drazvan drazvan self-assigned this Dec 4, 2023

@drazvan (Collaborator) left a comment


Looks good! 👍 You should fix the "bulk computing" problem.
We also need to add documentation for this and potentially enable the caching behavior to be controlled from the config. Let me know if you want me to make a first proposal for that.

Two review threads on nemoguardrails/embeddings/cache.py were marked outdated and resolved.
@drazvan (Collaborator) commented Dec 4, 2023

Thanks for implementing this @Pouyanpi! 👍
On the more boring admin side, I see that your first commit is not signed. Feel free to sign it and force-push the branch again.

@Pouyanpi force-pushed the feature/200/embeddings/cache branch 2 times, most recently from ebe1653 to 3f11484, on December 5, 2023
@Pouyanpi (Collaborator, Author) commented Dec 5, 2023

Thank you for your feedback, @drazvan. I've made the changes; please let me know if you need further modifications.

If you'd like, I can also update/add the documentation to reflect these updates. Could you please advise on which file I should update and the level of detail you'd prefer?
Thanks!

@drazvan drazvan added this to the v0.7.0 milestone Dec 14, 2023
@drazvan drazvan modified the milestones: v0.7.0, v0.8.0 Jan 16, 2024
@drazvan drazvan added the enhancement label Jan 16, 2024
@drazvan drazvan linked an issue Jan 16, 2024 that may be closed by this pull request
@drazvan (Collaborator) commented Feb 6, 2024

@Pouyanpi: do you have availability to wrap up this PR this week? We'd like to include it in the 0.8.0 release.

@Pouyanpi (Collaborator, Author) commented Feb 6, 2024

Sure @drazvan, I will have it done later today.


@classmethod
def from_dict(cls, d):
    key_generator = d.get("key_generator", MD5KeyGenerator())

@drazvan (Collaborator) commented:

The d dictionary here can contain strings rather than actual instances; this is the case when called with the parameters data from esp_provider. There should be logic that translates from str to actual instances, e.g.:

if key_generator == "md5":
    key_generator = MD5KeyGenerator()
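
A dictionary dispatch generalizes this to the other names; the name-to-class mapping below is an assumption, mirroring the factory methods that were eventually added to KeyGenerator and CacheStore:

# Hypothetical resolution helper; assumes the concrete classes from this PR
# are importable.
KEY_GENERATORS = {"md5": MD5KeyGenerator, "hash": HashKeyGenerator}
CACHE_STORES = {
    "in_memory": InMemoryCacheStore,
    "filesystem": FilesystemCacheStore,
    "redis": RedisCacheStore,
}

def resolve_cache_components(d: dict):
    key_generator = d.get("key_generator", "md5")
    if isinstance(key_generator, str):
        key_generator = KEY_GENERATORS[key_generator]()
    store = d.get("store", "filesystem")
    if isinstance(store, str):
        store = CACHE_STORES[store]()
    return key_generator, store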

@drazvan (Collaborator) commented Feb 6, 2024

@Pouyanpi: Also, let's make the configuration of the embeddings cache a first-class citizen, like this:

from typing import Any, Dict

from pydantic import BaseModel, Field


class EmbeddingsCache(BaseModel):
    """Configuration for caching embeddings."""

    enabled: bool = Field(
        default=False,
        description="Whether caching of the embeddings should be enabled or not.",
    )
    key_generator: str = Field(
        default="md5",
        description="The method to use for generating the cache keys.",
    )
    store: str = Field(
        default="filesystem",
        description="What type of store to use for the cached embeddings.",
    )
    store_config: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional configuration options required for the store. "
        "For example, path for `filesystem` or `host`/`port`/`db` for redis.",
    )


class EmbeddingSearchProvider(BaseModel):
    """Configuration of an embedding search provider."""

    name: str = Field(
        default="default",
        description="The name of the embedding search provider. If not specified, default is used.",
    )
    parameters: Dict[str, Any] = Field(default_factory=dict)
    cache: EmbeddingsCache = Field(
        default_factory=EmbeddingsCache,
        description="Configuration option for whether to use a cache for embeddings or not.",
    )

And the configuration will be done as per:
https://github.com/NVIDIA/NeMo-Guardrails/blob/develop/docs/user_guides/advanced/embedding-search-providers.md

core:
  embedding_search_provider:
    name: default
    parameters:
      embedding_engine: FastEmbed
      embedding_model: all-MiniLM-L6-v2
    cache:
      enabled: True
      # The rest of the settings are optional; this will use filesystem + md5
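
For illustration, a fully spelled-out variant using the redis store; the store_config keys follow the host/port/db hint in the field description above, and the values are placeholders:

core:
  embedding_search_provider:
    name: default
    parameters:
      embedding_engine: FastEmbed
      embedding_model: all-MiniLM-L6-v2
    cache:
      enabled: True
      key_generator: md5
      store: redis
      store_config:
        host: localhost
        port: 6379
        db: 0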

@drazvan (Collaborator) commented Feb 6, 2024

And last but not least, have a look at this as well: #267. This makes the _get_embeddings method async. So, maybe update this (i.e. make the method async) in your PR as well. Not sure in what order we'll merge.

Commits

  • Add  and  parameters to the  constructor to allow for optional caching of embeddings. This can improve performance by avoiding recomputation of embeddings.
  • Implement selective caching in the CacheEmbeddings class to maintain the bulk-computing behavior of embedding providers. The updated logic checks the cache before making a call and only processes uncached items, reducing unnecessary computation and API calls.
  • Add the  method in the  base class to allow custom embeddings to benefit from the @cache_embeddings decorator.
  • Update the  initialization to use the  configuration. This includes optional , , and  parameters.
  • Add a new `EmbeddingsCacheConfig` class to support customizable embedding caching strategies, with options for enabling the cache, cache-key generation methods, storage types, and additional store-specific configuration. This facilitates performance improvements by allowing the reuse of embeddings across tasks. Integrate the cache configuration into the `EmbeddingSearchProvider` class to streamline usage.
  • Introduce type annotations and name attributes on the key-generator and cache-store classes for consistent object creation. Add factory methods to the KeyGenerator and CacheStore classes, simplifying instantiation by name. Refactor CacheEmbeddings to EmbeddingsCache, which can now convert its configuration to a dictionary and be created from a dictionary or an EmbeddingsCacheConfig. Update the `cache_embeddings` decorator to use a class's cache_config attribute, enabling asynchronous retrieval and caching of embeddings. Enhance logging to display longer text snippets.
  • Update BasicEmbeddingsIndex to use a more flexible cache configuration: replace the bool flag and the CacheEmbeddings instance with a single unified `cache_config` parameter that accepts either a dictionary or an EmbeddingsCacheConfig instance. Also make the `_get_embeddings` method asynchronous to allow non-blocking I/O per NVIDIA#267.
  • Enhance the EmbeddingsIndex class by adding a cache_config attribute for customizable cache management, and update the _get_embeddings method to be asynchronous per NVIDIA#267.
  • Remove the direct dependency on `CacheEmbeddings` in the `LLMRails` class, streamlining the embedding cache setup; configuration now relies on a generic `cache_config` pulled directly from the `esp_config`.
  • Update the documentation to reflect the new caching feature for embedding searches.
@Pouyanpi (Collaborator, Author) commented Feb 7, 2024

@drazvan : I have implemented the changes based on your suggestions and have pushed the updates for review.

Currently, the caching feature is disabled by default. Please let me know if this aligns with your expectations.

Regarding the factory-related behavior of CacheStore and KeyGenerator, I acknowledge it might not be the cleanest approach. I have alternative solutions in mind and can share preliminary drafts for future releases if you're interested. However, I believe the current implementation is satisfactory for the moment.

Feel free to review the modifications at your convenience.

Signed-off-by: Razvan Dinu <rdinu@nvidia.com>
@drazvan merged commit aec4c42 into NVIDIA:develop on Feb 15, 2024 (1 check passed)
Labels: enhancement
Linked issue: Durable embeddings (#200)
3 participants