From c52e3e4ea63e43346b439c3eaf6b264c83bf1c25 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 26 Nov 2024 14:31:30 -0800 Subject: [PATCH] feat: Add compatibility for RagRetrievalConfig in rag_store and rag_retrieval feat: Add deprecation warnings for use of similarity_top_k, vector_search_alpha, and vector_distance_threshold in retrieval_query, use RagRetrievalConfig instead. PiperOrigin-RevId: 700462404 --- .../vertex_rag/test_rag_constants_preview.py | 12 ++ .../vertex_rag/test_rag_retrieval_preview.py | 89 +++++++++++- .../unit/vertex_rag/test_rag_store_preview.py | 71 +++++++++- vertexai/preview/rag/__init__.py | 6 + vertexai/preview/rag/rag_retrieval.py | 129 ++++++++++++++---- vertexai/preview/rag/rag_store.py | 82 +++++++++-- vertexai/preview/rag/utils/resources.py | 91 ++++++++++++ 7 files changed, 432 insertions(+), 48 deletions(-) diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 7bf576c21d..ebac9f1edc 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -20,10 +20,13 @@ from vertexai.preview.rag import ( EmbeddingModelConfig, + Filter, + HybridSearch, Pinecone, RagCorpus, RagFile, RagResource, + RagRetrievalConfig, SharePointSource, SharePointSources, SlackChannelsSource, @@ -529,3 +532,12 @@ rag_corpus="213lkj-1/23jkl/", rag_file_ids=[TEST_RAG_FILE_ID], ) +TEST_RAG_RETRIEVAL_CONFIG = RagRetrievalConfig( + top_k=2, + filter=Filter(vector_distance_threshold=0.5), +) +TEST_RAG_RETRIEVAL_CONFIG_ALPHA = RagRetrievalConfig( + top_k=2, + filter=Filter(vector_distance_threshold=0.5), + hybrid_search=HybridSearch(alpha=0.5), +) diff --git a/tests/unit/vertex_rag/test_rag_retrieval_preview.py b/tests/unit/vertex_rag/test_rag_retrieval_preview.py index 21bdc7b4cd..9c79636a8c 100644 --- a/tests/unit/vertex_rag/test_rag_retrieval_preview.py +++ b/tests/unit/vertex_rag/test_rag_retrieval_preview.py @@ 
-73,12 +73,34 @@ def teardown_method(self): @pytest.mark.usefixtures("retrieve_contexts_mock") def test_retrieval_query_rag_resources_success(self): + with pytest.warns(DeprecationWarning): + response = rag.retrieval_query( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + vector_search_alpha=0.5, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + + @pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_resources_config_success(self): + response = rag.retrieval_query( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_ALPHA, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + + @pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_resources_default_config_success(self): response = rag.retrieval_query( rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], text=test_rag_constants_preview.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, - vector_search_alpha=0.5, ) retrieve_contexts_eq( response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE @@ -86,11 +108,23 @@ def test_retrieval_query_rag_resources_success(self): @pytest.mark.usefixtures("retrieve_contexts_mock") def test_retrieval_query_rag_corpora_success(self): + with pytest.warns(DeprecationWarning): + response = rag.retrieval_query( + rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + + @pytest.mark.usefixtures("retrieve_contexts_mock") + def 
test_retrieval_query_rag_corpora_config_success(self): response = rag.retrieval_query( rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID], text=test_rag_constants_preview.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, ) retrieve_contexts_eq( response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE @@ -107,6 +141,16 @@ def test_retrieval_query_failure(self): ) e.match("Failed in retrieving contexts due to") + @pytest.mark.usefixtures("rag_client_mock_exception") + def test_retrieval_query_config_failure(self): + with pytest.raises(RuntimeError) as e: + rag.retrieval_query( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ) + e.match("Failed in retrieving contexts due to") + def test_retrieval_query_invalid_name(self): with pytest.raises(ValueError) as e: rag.retrieval_query( @@ -119,6 +163,17 @@ def test_retrieval_query_invalid_name(self): ) e.match("Invalid RagCorpus name") + def test_retrieval_query_invalid_name_config(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME + ], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ) + e.match("Invalid RagCorpus name") + def test_retrieval_query_multiple_rag_corpora(self): with pytest.raises(ValueError) as e: rag.retrieval_query( @@ -132,6 +187,18 @@ def test_retrieval_query_multiple_rag_corpora(self): ) e.match("Currently only support 1 RagCorpus") + def test_retrieval_query_multiple_rag_corpora_config(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_corpora=[ + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ], + 
text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ) + e.match("Currently only support 1 RagCorpus") + def test_retrieval_query_multiple_rag_resources(self): with pytest.raises(ValueError) as e: rag.retrieval_query( @@ -144,3 +211,15 @@ def test_retrieval_query_multiple_rag_resources(self): vector_distance_threshold=0.5, ) e.match("Currently only support 1 RagResource") + + def test_retrieval_query_multiple_rag_resources_config(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE, + test_rag_constants_preview.TEST_RAG_RESOURCE, + ], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ) + e.match("Currently only support 1 RagResource") diff --git a/tests/unit/vertex_rag/test_rag_store_preview.py b/tests/unit/vertex_rag/test_rag_store_preview.py index 6d733b7baf..6ff3d4cef5 100644 --- a/tests/unit/vertex_rag/test_rag_store_preview.py +++ b/tests/unit/vertex_rag/test_rag_store_preview.py @@ -21,7 +21,32 @@ @pytest.mark.usefixtures("google_auth_mock") -class TestRagStoreValidations: +class TestRagStore: + def test_retrieval_tool_success(self): + with pytest.warns(DeprecationWarning): + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + + def test_retrieval_tool_config_success(self): + with pytest.warns(DeprecationWarning): + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_corpora=[ + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ), + ) + ) + def test_retrieval_tool_invalid_name(self): with pytest.raises(ValueError) as e: Tool.from_retrieval( @@ -37,6 +62,20 @@ 
def test_retrieval_tool_invalid_name(self): ) e.match("Invalid RagCorpus name") + def test_retrieval_tool_invalid_name_config(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME + ], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ), + ) + ) + e.match("Invalid RagCorpus name") + def test_retrieval_tool_multiple_rag_corpora(self): with pytest.raises(ValueError) as e: Tool.from_retrieval( @@ -53,6 +92,21 @@ def test_retrieval_tool_multiple_rag_corpora(self): ) e.match("Currently only support 1 RagCorpus") + def test_retrieval_tool_multiple_rag_corpora_config(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_corpora=[ + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ), + ) + ) + e.match("Currently only support 1 RagCorpus") + def test_retrieval_tool_multiple_rag_resources(self): with pytest.raises(ValueError) as e: Tool.from_retrieval( @@ -68,3 +122,18 @@ def test_retrieval_tool_multiple_rag_resources(self): ) ) e.match("Currently only support 1 RagResource") + + def test_retrieval_tool_multiple_rag_resources_config(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE, + test_rag_constants_preview.TEST_RAG_RESOURCE, + ], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, + ), + ) + ) + e.match("Currently only support 1 RagResource") diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index 054cb4482d..4065616746 100644 --- a/vertexai/preview/rag/__init__.py +++ 
b/vertexai/preview/rag/__init__.py @@ -39,6 +39,8 @@ ) from vertexai.preview.rag.utils.resources import ( EmbeddingModelConfig, + Filter, + HybridSearch, JiraQuery, JiraSource, Pinecone, @@ -46,6 +48,7 @@ RagFile, RagManagedDb, RagResource, + RagRetrievalConfig, SharePointSource, SharePointSources, SlackChannel, @@ -58,6 +61,8 @@ __all__ = ( "EmbeddingModelConfig", + "Filter", + "HybridSearch", "JiraQuery", "JiraSource", "Pinecone", @@ -65,6 +70,7 @@ "RagFile", "RagManagedDb", "RagResource", + "RagRetrievalConfig", "Retrieval", "SharePointSource", "SharePointSources", diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py index 1d8bbb5612..8e1b096aee 100644 --- a/vertexai/preview/rag/rag_retrieval.py +++ b/vertexai/preview/rag/rag_retrieval.py @@ -18,27 +18,23 @@ import re from typing import List, Optional -from google.cloud.aiplatform import initializer +import warnings -from google.cloud.aiplatform_v1beta1 import ( - RagQuery, - RetrieveContextsRequest, - RetrieveContextsResponse, -) -from vertexai.preview.rag.utils import ( - _gapic_utils, -) -from vertexai.preview.rag.utils.resources import RagResource +from google.cloud import aiplatform_v1beta1 +from google.cloud.aiplatform import initializer +from vertexai.preview.rag.utils import _gapic_utils +from vertexai.preview.rag.utils import resources def retrieval_query( text: str, - rag_resources: Optional[List[RagResource]] = None, + rag_resources: Optional[List[resources.RagResource]] = None, rag_corpora: Optional[List[str]] = None, - similarity_top_k: Optional[int] = 10, - vector_distance_threshold: Optional[float] = 0.3, - vector_search_alpha: Optional[float] = 0.5, -) -> RetrieveContextsResponse: + similarity_top_k: Optional[int] = None, + vector_distance_threshold: Optional[float] = None, + vector_search_alpha: Optional[float] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, +) -> aiplatform_v1beta1.RetrieveContextsResponse: """Retrieve top k 
relevant docs/chunks. Example usage: @@ -73,6 +69,9 @@ def retrieval_query( sparse vector search results. The range is [0, 1], where 0 means sparse vector search only and 1 means dense vector search only. The default value is 0.5. + rag_retrieval_config: Optional. The config containing the retrieval + parameters (top_k, filter.vector_distance_threshold and + hybrid_search.alpha); preferred over the deprecated arguments. Returns: RetrieveContextsResonse. @@ -89,6 +88,12 @@ def retrieval_query( if len(rag_corpora) > 1: raise ValueError("Currently only support 1 RagCorpus.") name = rag_corpora[0] + warnings.warn( + f"rag_corpora is deprecated. Please use rag_resources instead." + f" After {resources.DEPRECATION_DATE} using" + " rag_corpora will raise error", + DeprecationWarning, + ) else: raise ValueError("rag_resources or rag_corpora must be specified.") @@ -99,32 +104,98 @@ def retrieval_query( rag_corpus_name = parent + "/ragCorpora/" + name else: raise ValueError( - "Invalid RagCorpus name: %s. Proper format should be: projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}", - rag_corpora, + f"Invalid RagCorpus name: {rag_corpora}.
Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" ) if rag_resources: - gapic_rag_resource = RetrieveContextsRequest.VertexRagStore.RagResource( - rag_corpus=rag_corpus_name, - rag_file_ids=rag_resources[0].rag_file_ids, + gapic_rag_resource = ( + aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resources[0].rag_file_ids, + ) ) - vertex_rag_store = RetrieveContextsRequest.VertexRagStore( + vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( rag_resources=[gapic_rag_resource], ) else: - vertex_rag_store = RetrieveContextsRequest.VertexRagStore( + vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( rag_corpora=[rag_corpus_name], ) - vertex_rag_store.vector_distance_threshold = vector_distance_threshold - query = RagQuery( + # Check for deprecated parameters and raise warnings. + if similarity_top_k: + # If similarity_top_k is specified, throw deprecation warning. + warnings.warn( + "similarity_top_k is deprecated. Please use" + " rag_retrieval_config.top_k instead." + f" After {resources.DEPRECATION_DATE} using" + " similarity_top_k will raise error", + DeprecationWarning, + ) + if vector_search_alpha: + # If vector_search_alpha is specified, throw deprecation warning. + warnings.warn( + "vector_search_alpha is deprecated. Please use" + " rag_retrieval_config.alpha instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_search_alpha will raise error", + DeprecationWarning, + ) + if vector_distance_threshold: + # If vector_distance_threshold is specified, throw deprecation warning. + warnings.warn( + "vector_distance_threshold is deprecated. Please use" + " rag_retrieval_config.filter.vector_distance_threshold instead." 
+ f" After {resources.DEPRECATION_DATE} using" + " vector_distance_threshold will raise error", + DeprecationWarning, + ) + + # If rag_retrieval_config is not specified, set it to default values. + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( + top_k=similarity_top_k, + hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( + alpha=vector_search_alpha, + ), + filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( + vector_distance_threshold=vector_distance_threshold + ), + ) + else: + # Config values win over deprecated args; explicit zeros are kept. + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + api_retrival_config.top_k = ( + rag_retrieval_config.top_k + if rag_retrieval_config.top_k + else similarity_top_k + ) + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha is not None + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + else: + api_retrival_config.hybrid_search.alpha = vector_search_alpha + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold is not None + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + else: + api_retrival_config.filter.vector_distance_threshold = ( + vector_distance_threshold + ) + query = aiplatform_v1beta1.RagQuery( text=text, - similarity_top_k=similarity_top_k, - ranking=RagQuery.Ranking( - alpha=vector_search_alpha, - ), + rag_retrieval_config=api_retrival_config, ) - request = RetrieveContextsRequest( + request = aiplatform_v1beta1.RetrieveContextsRequest( vertex_rag_store=vertex_rag_store, parent=parent, query=query, diff --git a/vertexai/preview/rag/rag_store.py b/vertexai/preview/rag/rag_store.py index c899c5fe42..62012df93d 100644 --- a/vertexai/preview/rag/rag_store.py +++ b/vertexai/preview/rag/rag_store.py @@ -18,11 +18,14 @@ import re from typing import List, Optional,
Union -from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types +import warnings + +from google.cloud import aiplatform_v1beta1 from google.cloud.aiplatform import initializer -from vertexai.preview.rag.utils import _gapic_utils -from vertexai.preview.rag.utils.resources import RagResource +from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types from vertexai.preview import generative_models +from vertexai.preview.rag.utils import _gapic_utils +from vertexai.preview.rag.utils import resources class Retrieval(generative_models.grounding.Retrieval): @@ -44,10 +47,11 @@ class VertexRagStore: def __init__( self, - rag_resources: Optional[List[RagResource]] = None, + rag_resources: Optional[List[resources.RagResource]] = None, rag_corpora: Optional[List[str]] = None, - similarity_top_k: Optional[int] = 10, - vector_distance_threshold: Optional[float] = 0.3, + similarity_top_k: Optional[int] = None, + vector_distance_threshold: Optional[float] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, ): """Initializes a Vertex RAG store tool. @@ -78,7 +82,11 @@ def __init__( similarity_top_k: Number of top k results to return from the selected corpora. vector_distance_threshold (float): - Optional. Only return results with vector distance smaller than the threshold. + Optional. Only return results with vector distance smaller + than the threshold. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including similarity_top_k, hybrid search alpha, + and vector_distance_threshold. """ @@ -89,6 +97,12 @@ def __init__( elif rag_corpora: if len(rag_corpora) > 1: raise ValueError("Currently only support 1 RagCorpus.") + warnings.warn( + "rag_corpora is deprecated. Please use rag_resources instead." 
+ f" After {resources.DEPRECATION_DATE} using" + " rag_corpora will raise error", + DeprecationWarning, + ) name = rag_corpora[0] else: raise ValueError("rag_resources or rag_corpora must be specified.") @@ -101,9 +115,53 @@ def __init__( rag_corpus_name = parent + "/ragCorpora/" + name else: raise ValueError( - "Invalid RagCorpus name: %s. Proper format should be: projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}", - rag_corpora, + f"Invalid RagCorpus name: {rag_corpora}. Proper format should" + + " be: projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" ) + + # Check for deprecated parameters and raise warnings. + if similarity_top_k: + # If similarity_top_k is specified, throw deprecation warning. + warnings.warn( + "similarity_top_k is deprecated. Please use" + " rag_retrieval_config.top_k instead." + f" After {resources.DEPRECATION_DATE} using" + " similarity_top_k will raise error", + DeprecationWarning, + ) + if vector_distance_threshold: + # If vector_distance_threshold is specified, throw deprecation warning. + warnings.warn( + "vector_distance_threshold is deprecated. Please use" + " rag_retrieval_config.filter.vector_distance_threshold instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_distance_threshold will raise error", + DeprecationWarning, + ) + + # If rag_retrieval_config is not specified, set it to default values. + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( + top_k=similarity_top_k, + filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( + vector_distance_threshold=vector_distance_threshold + ), + ) + else: + # If rag_retrieval_config is specified, check for missing parameters.
+ # Values from rag_retrieval_config take precedence; the deprecated + # arguments only fill in whatever the config leaves unset. + rag_filter = rag_retrieval_config.filter + threshold = rag_filter.vector_distance_threshold if rag_filter else None + if threshold is None: + threshold = vector_distance_threshold + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( + top_k=rag_retrieval_config.top_k or similarity_top_k, + filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( + vector_distance_threshold=threshold + ), + ) + if rag_resources: gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource( rag_corpus=rag_corpus_name, @@ -111,12 +169,10 @@ def __init__( ) self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( rag_resources=[gapic_rag_resource], - similarity_top_k=similarity_top_k, - vector_distance_threshold=vector_distance_threshold, + rag_retrieval_config=api_retrival_config, ) else: self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( rag_corpora=[rag_corpus_name], - similarity_top_k=similarity_top_k, - vector_distance_threshold=vector_distance_threshold, + rag_retrieval_config=api_retrival_config, ) diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index b6ac1ea050..1db8f0899f 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -20,6 +20,8 @@ from google.protobuf import timestamp_pb2 +DEPRECATION_DATE = "June 2025" + @dataclasses.dataclass class RagFile: @@ -282,3 +284,92 @@ class SharePointSources: """ share_point_sources: Sequence[SharePointSource] + + +@dataclasses.dataclass +class Filter: + """Filter. + + Attributes: + vector_distance_threshold: Only returns contexts with vector + distance smaller than the threshold. + vector_similarity_threshold: Only returns contexts with vector + similarity larger than the threshold. + metadata_filter: String for metadata filtering.
+ """ + + vector_distance_threshold: Optional[float] = None + vector_similarity_threshold: Optional[float] = None + metadata_filter: Optional[str] = None + + +@dataclasses.dataclass +class HybridSearch: + """HybridSearch. + + Attributes: + alpha: Alpha value controls the weight between dense and + sparse vector search results. The range is [0, 1], where 0 + means sparse vector search only and 1 means dense vector + search only. The default value is 0.5 which balances sparse + and dense vector search equally. + """ + + alpha: Optional[float] = None + + +@dataclasses.dataclass +class LlmRanker: + """LlmRanker. + + Attributes: + model_name: The model name used for ranking. + """ + + model_name: Optional[str] = None + + +@dataclasses.dataclass +class RankService: + """RankService. + + Attributes: + model_name: The model name of the rank service. Format: + ``semantic-ranker-512@latest`` + """ + + model_name: Optional[str] = None + + +@dataclasses.dataclass +class Ranking: + """Ranking. + + Attributes: + rank_service (RankService): + Config for Rank Service. + llm_ranker (LlmRanker): + Config for LlmRanker. + """ + + rank_service: Optional[RankService] = None + llm_ranker: Optional[LlmRanker] = None + + +@dataclasses.dataclass +class RagRetrievalConfig: + """RagRetrievalConfig. + + Attributes: + top_k: The number of contexts to retrieve. + filter (Filter): Config for filters. + hybrid_search (HybridSearch): + Config for Hybrid Search. + ranking (Ranking): + Config for ranking and reranking. + """ + + top_k: Optional[int] = None + filter: Optional[Filter] = None + hybrid_search: Optional[HybridSearch] = None + ranking: Optional[Ranking] = None