Skip to content

Commit

Permalink
Handle max batch size for embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust committed Oct 17, 2023
1 parent 923d372 commit 4e04929
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 18 deletions.
37 changes: 31 additions & 6 deletions google/generativeai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from __future__ import annotations

import dataclasses
from collections.abc import Sequence
from typing import Iterable, overload
from collections.abc import Iterable, Sequence
import itertools
from typing import Iterable, overload, TypeVar

import google.ai.generativelanguage as glm

Expand All @@ -28,6 +29,26 @@
from google.generativeai.types import safety_types

DEFAULT_TEXT_MODEL = "models/text-bison-001"
EMBEDDING_MAX_BATCH_SIZE = 100

try:
# python 3.12+
_batched = itertools.batched
except AttributeError:
T = TypeVar("T")

def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
if n < 1:
raise ValueError(f"Batch size `n` must be >1, got: {n}")
batch = []
for item in iterable:
batch.append(item)
if len(batch) == n:
yield batch
batch = []

if batch:
yield batch


def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
Expand Down Expand Up @@ -282,9 +303,13 @@ def generate_embeddings(
embedding_dict = type(embedding_response).to_dict(embedding_response)
embedding_dict["embedding"] = embedding_dict["embedding"]["value"]
else:
embedding_request = glm.BatchEmbedTextRequest(model=model, texts=text)
embedding_response = client.batch_embed_text(embedding_request)
embedding_dict = type(embedding_response).to_dict(embedding_response)
embedding_dict["embedding"] = [e["value"] for e in embedding_dict["embeddings"]]
result = {"embedding": []}
for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
# TODO(markdaoust): This could use an option for returning an iterator or wait-bar.
embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch)
embedding_response = client.batch_embed_text(embedding_request)
embedding_dict = type(embedding_response).to_dict(embedding_response)
result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"])
return result

return embedding_dict
48 changes: 36 additions & 12 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import unittest
import unittest.mock as mock

Expand Down Expand Up @@ -61,7 +62,10 @@ def batch_embed_text(
request: glm.EmbedTextRequest,
) -> glm.EmbedTextResponse:
self.observed_requests.append(request)
return self.responses["batch_embed_text"]

return glm.BatchEmbedTextResponse(
embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts)
)

@add_client_method
def count_text_tokens(
Expand Down Expand Up @@ -120,27 +124,47 @@ def test_generate_embeddings(self, model, text):
@parameterized.named_parameters(
[
dict(
testcase_name="basic_model",
testcase_name="small-2",
model="models/chat-lamda-001",
text=["Who are you?", "Who am I?"],
)
),
dict(
testcase_name="even-batch",
model="models/chat-lamda-001",
text=["Who are you?"] * 100,
),
dict(
testcase_name="even-batch-plus-one",
model="models/chat-lamda-001",
text=["Who are you?"] * 101,
),
dict(
testcase_name="odd-batch",
model="models/chat-lamda-001",
text=["Who are you?"] * 237,
),
]
)
def test_generate_embeddings_batch(self, model, text):
self.responses["batch_embed_text"] = glm.BatchEmbedTextResponse(
embeddings=[
glm.Embedding(value=[1, 2, 3]),
glm.Embedding(value=[4, 5, 6]),
]
)

emb = text_service.generate_embeddings(model=model, text=text)

self.assertIsInstance(emb, dict)
self.assertEqual(
self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
)

# Check first and last requests.
self.assertEqual(self.observed_requests[-1].model, model)
self.assertEqual(self.observed_requests[-1].texts[-1], text[-1])
self.assertEqual(self.observed_requests[0].texts[0], text[0])

# Check that the list has the right length.
self.assertIsInstance(emb["embedding"][0], list)
self.assertLen(emb["embedding"], len(text))

# Check that the right number of requests were sent.
self.assertLen(
self.observed_requests,
math.ceil(len(text) / text_service.EMBEDDING_MAX_BATCH_SIZE),
)

@parameterized.named_parameters(
[
Expand Down

0 comments on commit 4e04929

Please sign in to comment.