Skip to content

Commit

Permalink
fix: RAG Fix v1 rag_store compatibility with generative_models Tool b…
Browse files Browse the repository at this point in the history
…y changing back to v1beta1

PiperOrigin-RevId: 702520155
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Dec 4, 2024
1 parent 0537fec commit e220312
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 18 deletions.
33 changes: 33 additions & 0 deletions tests/unit/vertex_rag/test_rag_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@

@pytest.mark.usefixtures("google_auth_mock")
class TestRagStoreValidations:
def test_retrieval_tool_success(self):
tool = Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[tc.TEST_RAG_RESOURCE],
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
),
)
)
assert tool is not None

def test_retrieval_tool_vector_similarity_success(self):
tool = Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[tc.TEST_RAG_RESOURCE],
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
),
)
)
assert tool is not None

def test_retrieval_tool_no_rag_resources(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
),
)
)
e.match("rag_resources must be specified.")

def test_retrieval_tool_invalid_name(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/vertex_rag/test_rag_store_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,30 @@ def test_retrieval_tool_ranking_config_success(self):
)
)

def test_empty_retrieval_tool_success(self):
tool = Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
rag_retrieval_config=rag.RagRetrievalConfig(),
similarity_top_k=3,
vector_distance_threshold=0.4,
),
)
)
assert tool is not None

def test_retrieval_tool_no_rag_resources(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
),
)
)
e.match("rag_resources or rag_corpora must be specified.")

def test_retrieval_tool_invalid_name(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand Down
30 changes: 12 additions & 18 deletions vertexai/rag/rag_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import re
from typing import List, Optional, Union

from google.cloud import aiplatform_v1
from google.cloud import aiplatform_v1beta1
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1.types import tool as gapic_tool_types
from vertexai.preview import generative_models
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
from vertexai import generative_models
from vertexai.rag.utils import _gapic_utils
from vertexai.rag.utils import resources

Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(
)

# If rag_retrieval_config is not specified, set it to default values.
api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig()
# If rag_retrieval_config is specified, populate the default config.
if rag_retrieval_config:
api_retrieval_config.top_k = rag_retrieval_config.top_k
Expand All @@ -128,17 +128,11 @@ def __init__(
rag_retrieval_config.filter.vector_similarity_threshold
)

if rag_resources:
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
rag_corpus=rag_corpus_name,
rag_file_ids=rag_resources[0].rag_file_ids,
)
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
rag_resources=[gapic_rag_resource],
rag_retrieval_config=api_retrieval_config,
)
else:
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
rag_corpora=[rag_corpus_name],
rag_retrieval_config=api_retrieval_config,
)
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
rag_corpus=rag_corpus_name,
rag_file_ids=rag_resources[0].rag_file_ids,
)
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
rag_resources=[gapic_rag_resource],
rag_retrieval_config=api_retrieval_config,
)

0 comments on commit e220312

Please sign in to comment.