Skip to content

Commit

Permalink
feat: Switch Python generateContent to call Unary API endpoint
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604369375
  • Loading branch information
matthew29tang authored and copybara-github committed Feb 5, 2024
1 parent 3f817f4 commit 9a19545
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
62 changes: 56 additions & 6 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,56 @@ def mock_stream_generate_content(
yield response


def mock_generate_content(
    self,
    request: gapic_prediction_service_types.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> gapic_prediction_service_types.GenerateContentResponse:
    """Mock for PredictionServiceClient.generate_content (unary endpoint).

    Inspects the request to decide which canned response to return:
      * If tools are attached and no function_response has been sent yet,
        reply with a function-call part so the client issues the call.
      * If a function_response is already present, reply with the
        post-function-call text.
      * If the conversation has more than one content (a continued chat),
        reply with follow-up text; otherwise reply with the default text.

    Args:
        request: The GenerateContentRequest built by the SDK.
        model: Unused; present to mirror the real GAPIC signature.
        contents: Unused; present to mirror the real GAPIC signature.

    Returns:
        A single GenerateContentResponse (unary, not a stream).
    """
    is_continued_chat = len(request.contents) > 1
    has_tools = bool(request.tools)

    if has_tools:
        # Only the first part of each content is checked; the mock
        # conversation never puts a function_response elsewhere.
        has_function_response = any(
            "function_response" in content.parts[0] for content in request.contents
        )
        needs_function_call = not has_function_response
        if needs_function_call:
            response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
        else:
            response_part_struct = _RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT
    elif is_continued_chat:
        response_part_struct = {"text": "Other planets may have different sky color."}
    else:
        response_part_struct = _RESPONSE_TEXT_PART_STRUCT

    return gapic_prediction_service_types.GenerateContentResponse(
        candidates=[
            gapic_content_types.Candidate(
                index=0,
                content=gapic_content_types.Content(
                    # Model currently does not identify itself
                    # role="model",
                    parts=[
                        gapic_content_types.Part(response_part_struct),
                    ],
                ),
                finish_reason=gapic_content_types.Candidate.FinishReason.STOP,
                safety_ratings=[
                    gapic_content_types.SafetyRating(rating)
                    for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
                ],
                citation_metadata=gapic_content_types.CitationMetadata(
                    citations=[
                        gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
                    ]
                ),
            ),
        ],
    )


@pytest.mark.usefixtures("google_auth_mock")
class TestGenerativeModels:
"""Unit tests for the generative models."""
Expand All @@ -178,8 +228,8 @@ def teardown_method(self):

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
new=mock_stream_generate_content,
attribute="generate_content",
new=mock_generate_content,
)
def test_generate_content(self):
model = generative_models.GenerativeModel("gemini-pro")
Expand Down Expand Up @@ -212,8 +262,8 @@ def test_generate_content_streaming(self):

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
new=mock_stream_generate_content,
attribute="generate_content",
new=mock_generate_content,
)
def test_chat_send_message(self):
model = generative_models.GenerativeModel("gemini-pro")
Expand All @@ -225,8 +275,8 @@ def test_chat_send_message(self):

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
new=mock_stream_generate_content,
attribute="generate_content",
new=mock_generate_content,
)
def test_chat_function_calling(self):
get_current_weather_func = generative_models.FunctionDeclaration(
Expand Down
20 changes: 2 additions & 18 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,15 +431,7 @@ def _generate_content(
safety_settings=safety_settings,
tools=tools,
)
# generate_content is not available
# gapic_response = self._prediction_client.generate_content(request=request)
gapic_response = None
stream = self._prediction_client.stream_generate_content(request=request)
for gapic_chunk in stream:
if gapic_response:
_append_gapic_response(gapic_response, gapic_chunk)
else:
gapic_response = gapic_chunk
gapic_response = self._prediction_client.generate_content(request=request)
return self._parse_response(gapic_response)

async def _generate_content_async(
Expand Down Expand Up @@ -473,17 +465,9 @@ async def _generate_content_async(
safety_settings=safety_settings,
tools=tools,
)
# generate_content is not available
# gapic_response = await self._prediction_async_client.generate_content(request=request)
gapic_response = None
stream = await self._prediction_async_client.stream_generate_content(
gapic_response = await self._prediction_async_client.generate_content(
request=request
)
async for gapic_chunk in stream:
if gapic_response:
_append_gapic_response(gapic_response, gapic_chunk)
else:
gapic_response = gapic_chunk
return self._parse_response(gapic_response)

def _generate_content_streaming(
Expand Down

0 comments on commit 9a19545

Please sign in to comment.