Handle max batch size for embeddings. #83

Merged · 1 commit · Oct 19, 2023
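
For context, a rough sketch of the user-facing effect of this change (illustrative only; the `configure()` call, model name, and input sizes below are assumptions, not taken from this PR): `generate_embeddings` can now be called with more than 100 texts, and the SDK splits the input into batches and merges the per-batch responses into a single result.

```python
import google.generativeai as genai

# Hypothetical usage sketch; the API key and model name are placeholders.
genai.configure(api_key="...")

texts = [f"document {i}" for i in range(250)]  # more than the 100-per-request limit
result = genai.generate_embeddings(model="models/embedding-gecko-001", text=texts)

# The SDK now issues multiple BatchEmbedTextRequest calls under the hood and
# returns a single flat list with one embedding per input text.
assert len(result["embedding"]) == 250
```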
google/generativeai/text.py (37 changes: 31 additions & 6 deletions)
@@ -15,8 +15,9 @@
from __future__ import annotations

import dataclasses
-from collections.abc import Sequence
-from typing import Iterable, overload
+from collections.abc import Iterable, Sequence
+import itertools
+from typing import Iterable, overload, TypeVar

import google.ai.generativelanguage as glm

@@ -28,6 +29,26 @@
from google.generativeai.types import safety_types

DEFAULT_TEXT_MODEL = "models/text-bison-001"
+EMBEDDING_MAX_BATCH_SIZE = 100
+
+try:
+    # python 3.12+
+    _batched = itertools.batched  # type: ignore
+except AttributeError:
+    T = TypeVar("T")
+
+    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
+        if n < 1:
+            raise ValueError(f"Batch size `n` must be >1, got: {n}")
+        batch = []
+        for item in iterable:
+            batch.append(item)
+            if len(batch) == n:
+                yield batch
+                batch = []
+
+        if batch:
+            yield batch


def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
@@ -282,9 +303,13 @@ def generate_embeddings(
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        embedding_dict["embedding"] = embedding_dict["embedding"]["value"]
    else:
-        embedding_request = glm.BatchEmbedTextRequest(model=model, texts=text)
-        embedding_response = client.batch_embed_text(embedding_request)
-        embedding_dict = type(embedding_response).to_dict(embedding_response)
-        embedding_dict["embedding"] = [e["value"] for e in embedding_dict["embeddings"]]
+        result = {"embedding": []}
+        for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
+            # TODO(markdaoust): This could use an option for returning an iterator or wait-bar.
+            embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch)
+            embedding_response = client.batch_embed_text(embedding_request)
+            embedding_dict = type(embedding_response).to_dict(embedding_response)
+            result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"])
+        return result

    return embedding_dict
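
As a standalone illustration (not part of the diff) of the chunking behavior added above: the `_batched` fallback reproduces `itertools.batched` from Python 3.12+. The 237-text input below is an assumed example, paired with the 100-item limit:

```python
import itertools  # itertools.batched requires Python 3.12+

texts = [f"text {i}" for i in range(237)]  # illustrative input
sizes = [len(chunk) for chunk in itertools.batched(texts, 100)]

print(sizes)       # [100, 100, 37] -> three BatchEmbedTextRequest calls
print(sum(sizes))  # 237 -> one embedding per input text in the merged result
```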
tests/test_text.py (49 changes: 36 additions & 13 deletions)
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
+import math
import unittest
import unittest.mock as mock

@@ -61,7 +62,10 @@ def batch_embed_text(
            request: glm.EmbedTextRequest,
        ) -> glm.EmbedTextResponse:
            self.observed_requests.append(request)
-            return self.responses["batch_embed_text"]
+
+            return glm.BatchEmbedTextResponse(
+                embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts)
+            )

        @add_client_method
        def count_text_tokens(
@@ -120,27 +124,46 @@ def test_generate_embeddings(self, model, text):
    @parameterized.named_parameters(
        [
            dict(
-                testcase_name="basic_model",
+                testcase_name="small-2",
                model="models/chat-lamda-001",
                text=["Who are you?", "Who am I?"],
-            )
+            ),
+            dict(
+                testcase_name="even-batch",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 100,
+            ),
+            dict(
+                testcase_name="even-batch-plus-one",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 101,
+            ),
+            dict(
+                testcase_name="odd-batch",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 237,
+            ),
        ]
    )
    def test_generate_embeddings_batch(self, model, text):
-        self.responses["batch_embed_text"] = glm.BatchEmbedTextResponse(
-            embeddings=[
-                glm.Embedding(value=[1, 2, 3]),
-                glm.Embedding(value=[4, 5, 6]),
-            ]
-        )
-
        emb = text_service.generate_embeddings(model=model, text=text)

        self.assertIsInstance(emb, dict)
-        self.assertEqual(
-            self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
-        )
+
+        # Check first and last requests.
+        self.assertEqual(self.observed_requests[-1].model, model)
+        self.assertEqual(self.observed_requests[-1].texts[-1], text[-1])
+        self.assertEqual(self.observed_requests[0].texts[0], text[0])
+
+        # Check that the list has the right length.
        self.assertIsInstance(emb["embedding"][0], list)
+        self.assertLen(emb["embedding"], len(text))
+
+        # Check that the right number of requests were sent.
+        self.assertLen(
+            self.observed_requests,
+            math.ceil(len(text) / text_service.EMBEDDING_MAX_BATCH_SIZE),
+        )

    @parameterized.named_parameters(
        [
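
A small sketch (not part of the diff) of the request-count arithmetic that the new assertion checks, using the 100-item limit and the batch sizes exercised by the new test cases:

```python
import math

EMBEDDING_MAX_BATCH_SIZE = 100  # mirrors the constant added in text.py

# Expected number of batch_embed_text requests for each new test case size.
for n_texts in (2, 100, 101, 237):
    print(n_texts, "->", math.ceil(n_texts / EMBEDDING_MAX_BATCH_SIZE))
# 2 -> 1, 100 -> 1, 101 -> 2, 237 -> 3
```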