diff --git a/tests/unit/vertex_rag/test_rag_store.py b/tests/unit/vertex_rag/test_rag_store.py index 5971eb866a..89d0be4b52 100644 --- a/tests/unit/vertex_rag/test_rag_store.py +++ b/tests/unit/vertex_rag/test_rag_store.py @@ -22,6 +22,39 @@ @pytest.mark.usefixtures("google_auth_mock") class TestRagStoreValidations: + def test_retrieval_tool_success(self): + tool = Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, + ), + ) + ) + assert tool is not None + + def test_retrieval_tool_vector_similarity_success(self): + tool = Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, + ), + ) + ) + assert tool is not None + + def test_retrieval_tool_no_rag_resources(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, + ), + ) + ) + e.match("rag_resources must be specified.") + def test_retrieval_tool_invalid_name(self): with pytest.raises(ValueError) as e: Tool.from_retrieval( diff --git a/tests/unit/vertex_rag/test_rag_store_preview.py b/tests/unit/vertex_rag/test_rag_store_preview.py index 969ea22108..0529cfafa8 100644 --- a/tests/unit/vertex_rag/test_rag_store_preview.py +++ b/tests/unit/vertex_rag/test_rag_store_preview.py @@ -73,6 +73,30 @@ def test_retrieval_tool_ranking_config_success(self): ) ) + def test_empty_retrieval_tool_success(self): + tool = Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + rag_retrieval_config=rag.RagRetrievalConfig(), + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + assert tool is not None + + def test_retrieval_tool_no_rag_resources(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, + ), + ) + ) + e.match("rag_resources or rag_corpora must be specified.") + def test_retrieval_tool_invalid_name(self): with pytest.raises(ValueError) as e: Tool.from_retrieval( diff --git a/vertexai/rag/rag_store.py b/vertexai/rag/rag_store.py index 88982d2045..57cef287b3 100644 --- a/vertexai/rag/rag_store.py +++ b/vertexai/rag/rag_store.py @@ -19,10 +19,10 @@ import re from typing import List, Optional, Union -from google.cloud import aiplatform_v1 +from google.cloud import aiplatform_v1beta1 from google.cloud.aiplatform import initializer -from google.cloud.aiplatform_v1.types import tool as gapic_tool_types -from vertexai.preview import generative_models +from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types +from vertexai import generative_models from vertexai.rag.utils import _gapic_utils from vertexai.rag.utils import resources @@ -103,7 +103,7 @@ def __init__( ) # If rag_retrieval_config is not specified, set it to default values. - api_retrieval_config = aiplatform_v1.RagRetrievalConfig() + api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig() # If rag_retrieval_config is specified, populate the default config. if rag_retrieval_config: api_retrieval_config.top_k = rag_retrieval_config.top_k @@ -128,17 +128,11 @@ def __init__( rag_retrieval_config.filter.vector_similarity_threshold ) - if rag_resources: - gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource( - rag_corpus=rag_corpus_name, - rag_file_ids=rag_resources[0].rag_file_ids, - ) - self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( - rag_resources=[gapic_rag_resource], - rag_retrieval_config=api_retrieval_config, - ) - else: - self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( - rag_corpora=[rag_corpus_name], - rag_retrieval_config=api_retrieval_config, - ) + gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resources[0].rag_file_ids, + ) + self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( + rag_resources=[gapic_rag_resource], + rag_retrieval_config=api_retrieval_config, + )