diff --git a/google/generativeai/text.py b/google/generativeai/text.py
index d91d037c8..6f5f972b1 100644
--- a/google/generativeai/text.py
+++ b/google/generativeai/text.py
@@ -15,8 +15,9 @@
 from __future__ import annotations
 
 import dataclasses
-from collections.abc import Sequence
-from typing import Iterable, overload
+from collections.abc import Iterable, Sequence
+import itertools
+from typing import Iterable, overload, TypeVar
 
 import google.ai.generativelanguage as glm
 
@@ -28,6 +29,26 @@
 from google.generativeai.types import safety_types
 
 DEFAULT_TEXT_MODEL = "models/text-bison-001"
+EMBEDDING_MAX_BATCH_SIZE = 100
+
+try:
+    # python 3.12+
+    _batched = itertools.batched  # type: ignore
+except AttributeError:
+    T = TypeVar("T")
+
+    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
+        if n < 1:
+            raise ValueError(f"Batch size `n` must be >0, got: {n}")
+        batch = []
+        for item in iterable:
+            batch.append(item)
+            if len(batch) == n:
+                yield batch
+                batch = []
+
+        if batch:
+            yield batch
 
 
 def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
@@ -282,9 +303,13 @@ def generate_embeddings(
         embedding_dict = type(embedding_response).to_dict(embedding_response)
         embedding_dict["embedding"] = embedding_dict["embedding"]["value"]
     else:
-        embedding_request = glm.BatchEmbedTextRequest(model=model, texts=text)
-        embedding_response = client.batch_embed_text(embedding_request)
-        embedding_dict = type(embedding_response).to_dict(embedding_response)
-        embedding_dict["embedding"] = [e["value"] for e in embedding_dict["embeddings"]]
+        result = {"embedding": []}
+        for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
+            # TODO(markdaoust): This could use an option for returning an iterator or wait-bar.
+            embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch)
+            embedding_response = client.batch_embed_text(embedding_request)
+            embedding_dict = type(embedding_response).to_dict(embedding_response)
+            result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"])
+        return result
 
     return embedding_dict
diff --git a/tests/test_text.py b/tests/test_text.py
index 0d2636b06..180811deb 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import math
 import unittest
 import unittest.mock as mock
 
@@ -61,7 +62,10 @@ def batch_embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
             self.observed_requests.append(request)
-            return self.responses["batch_embed_text"]
+
+            return glm.BatchEmbedTextResponse(
+                embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts)
+            )
 
         @add_client_method
         def count_text_tokens(
@@ -120,27 +124,47 @@ def test_generate_embeddings(self, model, text):
     @parameterized.named_parameters(
         [
             dict(
-                testcase_name="basic_model",
+                testcase_name="small-2",
                 model="models/chat-lamda-001",
                 text=["Who are you?", "Who am I?"],
-            )
+            ),
+            dict(
+                testcase_name="even-batch",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 100,
+            ),
+            dict(
+                testcase_name="even-batch-plus-one",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 101,
+            ),
+            dict(
+                testcase_name="odd-batch",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 237,
+            ),
         ]
     )
     def test_generate_embeddings_batch(self, model, text):
-        self.responses["batch_embed_text"] = glm.BatchEmbedTextResponse(
-            embeddings=[
-                glm.Embedding(value=[1, 2, 3]),
-                glm.Embedding(value=[4, 5, 6]),
-            ]
-        )
         emb = text_service.generate_embeddings(model=model, text=text)
         self.assertIsInstance(emb, dict)
-        self.assertEqual(
-            self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
-        )
+
+        # Check first and last requests.
+        self.assertEqual(self.observed_requests[-1].model, model)
+        self.assertEqual(self.observed_requests[-1].texts[-1], text[-1])
+        self.assertEqual(self.observed_requests[0].texts[0], text[0])
+
+        # Check that the list has the right length.
         self.assertIsInstance(emb["embedding"][0], list)
+        self.assertLen(emb["embedding"], len(text))
+
+        # Check that the right number of requests were sent.
+        self.assertLen(
+            self.observed_requests,
+            math.ceil(len(text) / text_service.EMBEDDING_MAX_BATCH_SIZE),
+        )
 
     @parameterized.named_parameters(
         [