Skip to content

Commit

Permalink
feat: Add compatibility for RagRetrievalConfig in rag_store and rag_r…
Browse files Browse the repository at this point in the history
…etrieval

feat: Add deprecation warnings for use of similarity_top_k, vector_search_alpha, and vector_distance_threshold in retrieval_query, use RagRetrievalConfig instead.

PiperOrigin-RevId: 700462404
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 26, 2024
1 parent 34ed530 commit c52e3e4
Show file tree
Hide file tree
Showing 7 changed files with 432 additions and 48 deletions.
12 changes: 12 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@

from vertexai.preview.rag import (
EmbeddingModelConfig,
Filter,
HybridSearch,
Pinecone,
RagCorpus,
RagFile,
RagResource,
RagRetrievalConfig,
SharePointSource,
SharePointSources,
SlackChannelsSource,
Expand Down Expand Up @@ -529,3 +532,12 @@
rag_corpus="213lkj-1/23jkl/",
rag_file_ids=[TEST_RAG_FILE_ID],
)
TEST_RAG_RETRIEVAL_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
)
TEST_RAG_RETRIEVAL_CONFIG_ALPHA = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
hybrid_search=HybridSearch(alpha=0.5),
)
89 changes: 84 additions & 5 deletions tests/unit/vertex_rag/test_rag_retrieval_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,58 @@ def teardown_method(self):

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_resources_success(self):
with pytest.warns(DeprecationWarning):
response = rag.retrieval_query(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
similarity_top_k=2,
vector_distance_threshold=0.5,
vector_search_alpha=0.5,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_resources_config_success(self):
response = rag.retrieval_query(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_resources_default_config_success(self):
response = rag.retrieval_query(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
similarity_top_k=2,
vector_distance_threshold=0.5,
vector_search_alpha=0.5,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_corpora_success(self):
with pytest.warns(DeprecationWarning):
response = rag.retrieval_query(
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
similarity_top_k=2,
vector_distance_threshold=0.5,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_corpora_config_success(self):
response = rag.retrieval_query(
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
similarity_top_k=2,
vector_distance_threshold=0.5,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
Expand All @@ -107,6 +141,16 @@ def test_retrieval_query_failure(self):
)
e.match("Failed in retrieving contexts due to")

@pytest.mark.usefixtures("rag_client_mock_exception")
def test_retrieval_query_config_failure(self):
with pytest.raises(RuntimeError) as e:
rag.retrieval_query(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
)
e.match("Failed in retrieving contexts due to")

def test_retrieval_query_invalid_name(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
Expand All @@ -119,6 +163,17 @@ def test_retrieval_query_invalid_name(self):
)
e.match("Invalid RagCorpus name")

def test_retrieval_query_invalid_name_config(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_resources=[
test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME
],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
)
e.match("Invalid RagCorpus name")

def test_retrieval_query_multiple_rag_corpora(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
Expand All @@ -132,6 +187,18 @@ def test_retrieval_query_multiple_rag_corpora(self):
)
e.match("Currently only support 1 RagCorpus")

def test_retrieval_query_multiple_rag_corpora_config(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_corpora=[
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
)
e.match("Currently only support 1 RagCorpus")

def test_retrieval_query_multiple_rag_resources(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
Expand All @@ -144,3 +211,15 @@ def test_retrieval_query_multiple_rag_resources(self):
vector_distance_threshold=0.5,
)
e.match("Currently only support 1 RagResource")

def test_retrieval_query_multiple_rag_resources_config(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_resources=[
test_rag_constants_preview.TEST_RAG_RESOURCE,
test_rag_constants_preview.TEST_RAG_RESOURCE,
],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
)
e.match("Currently only support 1 RagResource")
71 changes: 70 additions & 1 deletion tests/unit/vertex_rag/test_rag_store_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,32 @@


@pytest.mark.usefixtures("google_auth_mock")
class TestRagStoreValidations:
class TestRagStore:
def test_retrieval_tool_success(self):
with pytest.warns(DeprecationWarning):
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
similarity_top_k=3,
vector_distance_threshold=0.4,
),
)
)

def test_retrieval_tool_config_success(self):
with pytest.warns(DeprecationWarning):
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_corpora=[
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
),
)
)

def test_retrieval_tool_invalid_name(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand All @@ -37,6 +62,20 @@ def test_retrieval_tool_invalid_name(self):
)
e.match("Invalid RagCorpus name")

def test_retrieval_tool_invalid_name_config(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[
test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME
],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
),
)
)
e.match("Invalid RagCorpus name")

def test_retrieval_tool_multiple_rag_corpora(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand All @@ -53,6 +92,21 @@ def test_retrieval_tool_multiple_rag_corpora(self):
)
e.match("Currently only support 1 RagCorpus")

def test_retrieval_tool_multiple_rag_corpora_config(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_corpora=[
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
),
)
)
e.match("Currently only support 1 RagCorpus")

def test_retrieval_tool_multiple_rag_resources(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand All @@ -68,3 +122,18 @@ def test_retrieval_tool_multiple_rag_resources(self):
)
)
e.match("Currently only support 1 RagResource")

def test_retrieval_tool_multiple_rag_resources_config(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[
test_rag_constants_preview.TEST_RAG_RESOURCE,
test_rag_constants_preview.TEST_RAG_RESOURCE,
],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
),
)
)
e.match("Currently only support 1 RagResource")
6 changes: 6 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
Filter,
HybridSearch,
JiraQuery,
JiraSource,
Pinecone,
RagCorpus,
RagFile,
RagManagedDb,
RagResource,
RagRetrievalConfig,
SharePointSource,
SharePointSources,
SlackChannel,
Expand All @@ -58,13 +61,16 @@

__all__ = (
"EmbeddingModelConfig",
"Filter",
"HybridSearch",
"JiraQuery",
"JiraSource",
"Pinecone",
"RagCorpus",
"RagFile",
"RagManagedDb",
"RagResource",
"RagRetrievalConfig",
"Retrieval",
"SharePointSource",
"SharePointSources",
Expand Down
Loading

0 comments on commit c52e3e4

Please sign in to comment.