From c4d3e6865243708921a062f4dd801b721799633f Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Mon, 23 Sep 2024 16:07:11 -0700 Subject: [PATCH] chore: Add RagManagedDb to RAG corpus creation as well as default db config when no vector_db specified PiperOrigin-RevId: 677975431 --- vertexai/preview/rag/__init__.py | 2 ++ vertexai/preview/rag/rag_data.py | 12 +++++----- vertexai/preview/rag/utils/_gapic_utils.py | 26 +++++++++++++++++----- vertexai/preview/rag/utils/resources.py | 5 +++++ 4 files changed, 34 insertions(+), 11 deletions(-) diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index 2deed4c630..f009b58f40 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -43,6 +43,7 @@ Pinecone, RagCorpus, RagFile, + RagManagedDb, RagResource, SharePointSource, SharePointSources, @@ -61,6 +62,7 @@ "Pinecone", "RagCorpus", "RagFile", + "RagManagedDb", "RagResource", "Retrieval", "SharePointSource", diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index c4b6f96e18..7a19ce51f8 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -48,6 +48,7 @@ Pinecone, RagCorpus, RagFile, + RagManagedDb, SharePointSources, SlackChannelsSource, VertexFeatureStore, @@ -61,7 +62,7 @@ def create_corpus( description: Optional[str] = None, embedding_model_config: Optional[EmbeddingModelConfig] = None, vector_db: Optional[ - Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone] + Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb] ] = None, ) -> RagCorpus: """Creates a new RagCorpus resource. @@ -102,11 +103,10 @@ def create_corpus( embedding_model_config=embedding_model_config, rag_corpus=rag_corpus, ) - if vector_db is not None: - _gapic_utils.set_vector_db( - vector_db=vector_db, - rag_corpus=rag_corpus, - ) + _gapic_utils.set_vector_db( + vector_db=vector_db, + rag_corpus=rag_corpus, + ) request = CreateRagCorpusRequest( parent=parent, diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 166c3b9107..b0ecc01131 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -42,6 +42,7 @@ Pinecone, RagCorpus, RagFile, + RagManagedDb, SharePointSources, SlackChannelsSource, JiraSource, @@ -107,6 +108,13 @@ def _check_weaviate(gapic_vector_db: RagVectorDbConfig) -> bool: return gapic_vector_db.weaviate.ByteSize() > 0 +def _check_rag_managed_db(gapic_vector_db: RagVectorDbConfig) -> bool: + try: + return gapic_vector_db.__contains__("rag_managed_db") + except AttributeError: + return gapic_vector_db.rag_managed_db.ByteSize() > 0 + + def _check_vertex_feature_store(gapic_vector_db: RagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_feature_store") @@ -130,8 +138,8 @@ def _check_vertex_vector_search(gapic_vector_db: RagVectorDbConfig) -> bool: def convert_gapic_to_vector_db( gapic_vector_db: RagVectorDbConfig, -) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone]: - """Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, or Pinecone.""" +) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]: + """Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone.""" if _check_weaviate(gapic_vector_db): return Weaviate( weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint, @@ -152,6 +160,8 @@ def convert_gapic_to_vector_db( index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint, index=gapic_vector_db.vertex_vector_search.index, ) + elif _check_rag_managed_db(gapic_vector_db): + return RagManagedDb() else: return None @@ -499,11 +509,17 @@ def set_embedding_model_config( def set_vector_db( - vector_db: Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone], + vector_db: Union[ + Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb, None + ], rag_corpus: GapicRagCorpus, ) -> None: """Sets the vector db configuration for the rag corpus.""" - if isinstance(vector_db, Weaviate): + if vector_db is None or isinstance(vector_db, RagManagedDb): + rag_corpus.rag_vector_db_config = RagVectorDbConfig( + rag_managed_db=RagVectorDbConfig.RagManagedDb(), + ) + elif isinstance(vector_db, Weaviate): http_endpoint = vector_db.weaviate_http_endpoint collection_name = vector_db.collection_name api_key = vector_db.api_key @@ -553,5 +569,5 @@ def set_vector_db( ) else: raise TypeError( - "vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, or Pinecone." + "vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone." ) diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 6f86f0a8ad..753d63b0bd 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -115,6 +115,11 @@ class VertexVectorSearch: index: str +@dataclasses.dataclass +class RagManagedDb: + """RagManagedDb.""" + + @dataclasses.dataclass class Pinecone: """Pinecone.