Skip to content

Commit

Permalink
feat: Adding Vertex AI Search Config for RAG corpuses to SDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700775020
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 27, 2024
1 parent 88ac48c commit d3d69d6
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 8 deletions.
40 changes: 40 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
JiraSource,
JiraQuery,
Weaviate,
VertexAiSearchConfig,
VertexVectorSearch,
VertexFeatureStore,
)
Expand All @@ -52,6 +53,7 @@
RagContexts,
RetrieveContextsResponse,
RagVectorDbConfig,
VertexAiSearchConfig as GapicVertexAiSearchConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.protobuf import timestamp_pb2
Expand Down Expand Up @@ -189,6 +191,44 @@
vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG,
)
TEST_PAGE_TOKEN = "test-page-token"
# Vertex AI Search Config
# Serving-config resource paths: one addresses an engine, the other a data
# store; both live under the same test collection.
TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/engines/test-engine/servingConfigs/test-serving-config"
TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/dataStores/test-datastore/servingConfigs/test-serving-config"
# GapicRagCorpus objects carrying the engine serving config.
TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
vertex_ai_search_config=GapicVertexAiSearchConfig(
serving_config=TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG,
),
)
# Same corpus name and display name, but with the data store serving config.
TEST_GAPIC_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
vertex_ai_search_config=GapicVertexAiSearchConfig(
serving_config=TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG,
),
)
# SDK-level VertexAiSearchConfig objects built from the two paths above.
TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE = VertexAiSearchConfig(
serving_config=TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG,
)
TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE = VertexAiSearchConfig(
serving_config=TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG,
)
# A serving_config string that follows neither resource path format above.
TEST_VERTEX_AI_SEARCH_CONFIG_INVALID = VertexAiSearchConfig(
serving_config="invalid-serving-config",
)
# Constructed with no arguments, i.e. without an explicit serving_config.
TEST_VERTEX_AI_SEARCH_CONFIG_EMPTY = VertexAiSearchConfig()

# SDK-level RagCorpus objects wrapping the engine / data store configs.
TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
vertex_ai_search_config=TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
)
TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
vertex_ai_search_config=TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
)

# RagFiles
TEST_PATH = "usr/home/my_file.txt"
Expand Down
141 changes: 141 additions & 0 deletions tests/unit/vertex_rag/test_rag_data_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,57 @@ def create_rag_corpus_mock_pinecone():
yield create_rag_corpus_mock_pinecone


@pytest.fixture
def create_rag_corpus_mock_vertex_ai_engine_search_config():
    """Patch ``VertexRagDataServiceClient.create_rag_corpus`` (engine case).

    The patched method returns a mocked long-running operation that reports
    itself as done and whose result is the GapicRagCorpus carrying the engine
    serving config.

    Yields:
        The mock that replaces ``create_rag_corpus`` while the fixture is
        active.
    """
    finished_lro = mock.Mock(ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = (
        test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
    )
    with mock.patch.object(
        VertexRagDataServiceClient,
        "create_rag_corpus",
        return_value=finished_lro,
    ) as patched_create:
        yield patched_create


@pytest.fixture
def create_rag_corpus_mock_vertex_ai_datastore_search_config():
    """Patch ``VertexRagDataServiceClient.create_rag_corpus`` (data store case).

    The patched method returns a mocked long-running operation that reports
    itself as done and whose result is the GapicRagCorpus carrying the data
    store serving config.

    Yields:
        The mock that replaces ``create_rag_corpus`` while the fixture is
        active.
    """
    finished_lro = mock.Mock(ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = (
        test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG
    )
    with mock.patch.object(
        VertexRagDataServiceClient,
        "create_rag_corpus",
        return_value=finished_lro,
    ) as patched_create:
        yield patched_create


@pytest.fixture
def update_rag_corpus_mock_vertex_ai_engine_search_config():
    """Patch ``VertexRagDataServiceClient.update_rag_corpus`` (engine case).

    The patched method returns a mocked long-running operation that reports
    itself as done and whose result is the GapicRagCorpus carrying the engine
    serving config.

    Yields:
        The mock that replaces ``update_rag_corpus`` while the fixture is
        active.
    """
    finished_lro = mock.Mock(ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = (
        test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
    )
    with mock.patch.object(
        VertexRagDataServiceClient,
        "update_rag_corpus",
        return_value=finished_lro,
    ) as patched_update:
        yield patched_update


@pytest.fixture
def update_rag_corpus_mock_weaviate():
with mock.patch.object(
Expand Down Expand Up @@ -280,6 +331,9 @@ def rag_corpus_eq(returned_corpus, expected_corpus):
assert returned_corpus.name == expected_corpus.name
assert returned_corpus.display_name == expected_corpus.display_name
assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db)
assert returned_corpus.vertex_ai_search_config.__eq__(
expected_corpus.vertex_ai_search_config
)


def rag_file_eq(returned_file, expected_file):
Expand Down Expand Up @@ -373,6 +427,70 @@ def test_create_corpus_pinecone_success(self):

rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE)

@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_engine_search_config")
def test_create_corpus_vais_engine_search_config_success(self):
    """Create a corpus with an engine serving config and compare the result."""
    consts = test_rag_constants_preview
    created_corpus = rag.create_corpus(
        display_name=consts.TEST_CORPUS_DISPLAY_NAME,
        vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
    )

    expected_corpus = consts.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
    rag_corpus_eq(created_corpus, expected_corpus)

@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_datastore_search_config")
def test_create_corpus_vais_datastore_search_config_success(self):
    """Create a corpus with a data store serving config and compare the result."""
    consts = test_rag_constants_preview
    created_corpus = rag.create_corpus(
        display_name=consts.TEST_CORPUS_DISPLAY_NAME,
        vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
    )

    expected_corpus = consts.TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG
    rag_corpus_eq(created_corpus, expected_corpus)

def test_create_corpus_vais_datastore_search_config_with_vector_db_failure(self):
    """Expect ValueError when create_corpus gets a search config and a vector_db."""
    consts = test_rag_constants_preview
    with pytest.raises(ValueError) as exc_info:
        rag.create_corpus(
            display_name=consts.TEST_CORPUS_DISPLAY_NAME,
            vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
            vector_db=consts.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
        )
    # Checked after the context manager exits, once the exception is captured.
    exc_info.match("Only one of vertex_ai_search_config or vector_db can be set.")

def test_create_corpus_vais_datastore_search_config_with_embedding_model_config_failure(
    self,
):
    """Expect ValueError when a search config and an embedding model config are both given."""
    consts = test_rag_constants_preview
    with pytest.raises(ValueError) as exc_info:
        rag.create_corpus(
            display_name=consts.TEST_CORPUS_DISPLAY_NAME,
            vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
            embedding_model_config=consts.TEST_EMBEDDING_MODEL_CONFIG,
        )
    # Checked after the context manager exits, once the exception is captured.
    exc_info.match(
        "Only one of vertex_ai_search_config or embedding_model_config can be set."
    )

def test_set_vertex_ai_search_config_with_invalid_serving_config_failure(self):
    """Expect ValueError when the serving_config string has an unsupported format."""
    consts = test_rag_constants_preview
    with pytest.raises(ValueError) as exc_info:
        rag.create_corpus(
            display_name=consts.TEST_CORPUS_DISPLAY_NAME,
            vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_INVALID,
        )
    # Checked after the context manager exits, once the exception is captured.
    exc_info.match(
        "serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
    )

def test_set_vertex_ai_search_config_with_empty_serving_config_failure(self):
    """Expect ValueError when the search config was built without a serving_config."""
    consts = test_rag_constants_preview
    with pytest.raises(ValueError) as exc_info:
        rag.create_corpus(
            display_name=consts.TEST_CORPUS_DISPLAY_NAME,
            vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_EMPTY,
        )
    # Checked after the context manager exits, once the exception is captured.
    exc_info.match("serving_config must be set.")

@pytest.mark.usefixtures("rag_data_client_preview_mock_exception")
def test_create_corpus_failure(self):
with pytest.raises(RuntimeError) as e:
Expand Down Expand Up @@ -462,6 +580,29 @@ def test_update_corpus_failure(self):
)
e.match("Failed in RagCorpus update due to")

@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_ai_engine_search_config")
def test_update_corpus_vais_engine_search_config_success(self):
    """Update a corpus with an engine serving config and compare the result."""
    consts = test_rag_constants_preview
    updated_corpus = rag.update_corpus(
        corpus_name=consts.TEST_RAG_CORPUS_RESOURCE_NAME,
        display_name=consts.TEST_CORPUS_DISPLAY_NAME,
        vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
    )

    expected_corpus = consts.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
    rag_corpus_eq(updated_corpus, expected_corpus)

def test_update_corpus_vais_datastore_search_config_with_vector_db_failure(self):
    """Expect ValueError when update_corpus gets a search config and a vector_db."""
    consts = test_rag_constants_preview
    with pytest.raises(ValueError) as exc_info:
        rag.update_corpus(
            corpus_name=consts.TEST_RAG_CORPUS_RESOURCE_NAME,
            display_name=consts.TEST_CORPUS_DISPLAY_NAME,
            vertex_ai_search_config=consts.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
            vector_db=consts.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
        )
    # Checked after the context manager exits, once the exception is captured.
    exc_info.match("Only one of vertex_ai_search_config or vector_db can be set.")

@pytest.mark.usefixtures("rag_data_client_preview_mock")
def test_get_corpus_success(self):
rag_corpus = rag.get_corpus(
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
SharePointSources,
SlackChannel,
SlackChannelsSource,
VertexAiSearchConfig,
VertexFeatureStore,
VertexVectorSearch,
Weaviate,
Expand All @@ -76,6 +77,7 @@
"SharePointSources",
"SlackChannel",
"SlackChannelsSource",
"VertexAiSearchConfig",
"VertexFeatureStore",
"VertexRagStore",
"VertexVectorSearch",
Expand Down
50 changes: 42 additions & 8 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
RagManagedDb,
SharePointSources,
SlackChannelsSource,
VertexAiSearchConfig,
VertexFeatureStore,
VertexVectorSearch,
Weaviate,
Expand All @@ -64,6 +65,7 @@ def create_corpus(
vector_db: Optional[
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
] = None,
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Expand All @@ -87,6 +89,9 @@ def create_corpus(
embedding_model_config: The embedding model config.
vector_db: The vector db config of the RagCorpus. If unspecified, the
default database Spanner is used.
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
Note: embedding_model_config or vector_db cannot be set if
vertex_ai_search_config is specified.
Returns:
RagCorpus.
Raises:
Expand All @@ -103,10 +108,25 @@ def create_corpus(
embedding_model_config=embedding_model_config,
rag_corpus=rag_corpus,
)
_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)

if vertex_ai_search_config and embedding_model_config:
raise ValueError(
"Only one of vertex_ai_search_config or embedding_model_config can be set."
)

if vertex_ai_search_config and vector_db:
raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.")

if vertex_ai_search_config:
_gapic_utils.set_vertex_ai_search_config(
vertex_ai_search_config=vertex_ai_search_config,
rag_corpus=rag_corpus,
)
else:
_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)

request = CreateRagCorpusRequest(
parent=parent,
Expand Down Expand Up @@ -134,6 +154,7 @@ def update_corpus(
RagManagedDb,
]
] = None,
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
) -> RagCorpus:
"""Updates a RagCorpus resource.
Expand Down Expand Up @@ -161,6 +182,10 @@ def update_corpus(
description will not be updated.
vector_db: The vector db config of the RagCorpus. If not provided, the
vector db will not be updated.
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
If not provided, the Vertex AI Search config will not be updated.
Note: embedding_model_config or vector_db cannot be set if
vertex_ai_search_config is specified.
Returns:
RagCorpus.
Expand All @@ -180,10 +205,19 @@ def update_corpus(
else:
rag_corpus = GapicRagCorpus(name=corpus_name)

_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)
if vertex_ai_search_config and vector_db:
raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.")

if vertex_ai_search_config:
_gapic_utils.set_vertex_ai_search_config(
vertex_ai_search_config=vertex_ai_search_config,
rag_corpus=rag_corpus,
)
else:
_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)

request = UpdateRagCorpusRequest(
rag_corpus=rag_corpus,
Expand Down
Loading

0 comments on commit d3d69d6

Please sign in to comment.