Skip to content

Commit

Permalink
feat: Add support for ranking field in rag_retrieval_config for rag_s…
Browse files Browse the repository at this point in the history
…tore creation.

PiperOrigin-RevId: 700842367
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 28, 2024
1 parent b7f9492 commit 6faa1d0
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
RagFile,
RagResource,
RagRetrievalConfig,
Ranking,
RankService,
LlmRanker,
SharePointSource,
SharePointSources,
SlackChannelsSource,
Expand Down Expand Up @@ -590,3 +593,16 @@
top_k=2,
filter=Filter(vector_distance_threshold=0.5, vector_similarity_threshold=0.5),
)
TEST_RAG_RETRIEVAL_RANKING_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
ranking=Ranking(rank_service=RankService(model_name="test-rank-service")),
)
TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
ranking=Ranking(
rank_service=RankService(model_name="test-rank-service"),
llm_ranker=LlmRanker(model_name="test-llm-ranker"),
),
)
29 changes: 29 additions & 0 deletions tests/unit/vertex_rag/test_rag_store_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ def test_retrieval_tool_similarity_config_success(self):
)
)

def test_retrieval_tool_ranking_config_success(self):
with pytest.warns(DeprecationWarning):
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_corpora=[
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_RANKING_CONFIG,
),
)
)

def test_retrieval_tool_invalid_name(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand Down Expand Up @@ -166,3 +179,19 @@ def test_retrieval_tool_invalid_config_filter(self):
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)

def test_retrieval_tool_invalid_ranking_config_filter(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG,
)
)
)
e.match(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
6 changes: 6 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
RagManagedDb,
RagResource,
RagRetrievalConfig,
Ranking,
RankService,
LlmRanker,
SharePointSource,
SharePointSources,
SlackChannel,
Expand All @@ -72,6 +75,9 @@
"RagManagedDb",
"RagResource",
"RagRetrievalConfig",
"Ranking",
"RankService",
"LlmRanker",
"Retrieval",
"SharePointSource",
"SharePointSources",
Expand Down
25 changes: 25 additions & 0 deletions vertexai/preview/rag/rag_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,31 @@ def __init__(
api_retrival_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)
# Check if both rank_service and llm_ranker are specified.
if (
rag_retrieval_config.ranking
and rag_retrieval_config.ranking.rank_service
and rag_retrieval_config.ranking.rank_service.model_name
and rag_retrieval_config.ranking.llm_ranker
and rag_retrieval_config.ranking.llm_ranker.model_name
):
raise ValueError(
"Only one of rank_service or llm_ranker can be specified"
" at a time in rag_retrieval_config."
)
# Set rank_service to config value if specified
if (
rag_retrieval_config.ranking
and rag_retrieval_config.ranking.rank_service
):
api_retrival_config.ranking.rank_service.model_name = (
rag_retrieval_config.ranking.rank_service.model_name
)
# Set llm_ranker to config value if specified
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
api_retrival_config.ranking.llm_ranker.model_name = (
rag_retrieval_config.ranking.llm_ranker.model_name
)

if rag_resources:
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
Expand Down

0 comments on commit 6faa1d0

Please sign in to comment.