Add count_text_tokens, and expose operations. (google-gemini#76)
* Add count_text_tokens, and expose operations.

* Format and fix pytype errors.

* use get_base_model_name in create_tuned_model

* Resolve comments
MarkDaoust authored and markmcd committed Oct 30, 2023
1 parent a12fe9c commit 49d0441
Showing 8 changed files with 154 additions and 22 deletions.
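
The headline addition is a public token-counting helper for the text service. A minimal usage sketch, assuming the package has been configured with an API key and that the model name below is available to the caller:

import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # assumption: key management is up to the caller

# Count how many tokens a prompt consumes for a given text model.
tokens = genai.count_text_tokens(
    model="models/text-bison-001",
    prompt="Tell me a story about a magic backpack.",
)
print(tokens)  # e.g. {'token_count': 7}, the shape used in the tests below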
5 changes: 5 additions & 0 deletions google/generativeai/__init__.py
@@ -77,6 +77,7 @@

from google.generativeai.text import generate_text
from google.generativeai.text import generate_embeddings
from google.generativeai.text import count_text_tokens

from google.generativeai.models import list_models
from google.generativeai.models import list_tuned_models
@@ -89,6 +90,10 @@
from google.generativeai.models import update_tuned_model
from google.generativeai.models import delete_tuned_model

from google.generativeai.operations import list_operations
from google.generativeai.operations import get_operation


from google.generativeai.client import configure

__version__ = version.__version__
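The newly exported operations helpers make long-running tuning operations reachable from the package root. Their signatures are not part of this diff, so the call shapes below (argument-free listing, lookup by resource name) are assumptions, and the operation name is purely illustrative:

import google.generativeai as genai

# Assumption: list_operations() iterates over the project's tuning
# operations, and get_operation() fetches a single one by name.
for operation in genai.list_operations():
    print(operation)

# Hypothetical operation name, for illustration only.
operation = genai.get_operation("tunedModels/bipedal-pangolin-001/operations/abc123")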
2 changes: 1 addition & 1 deletion google/generativeai/discuss.py
@@ -565,7 +565,7 @@ def count_message_tokens(
messages: discuss_types.MessagesOptions | None = None,
model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
client: glm.DiscussServiceAsyncClient | None = None,
):
) -> discuss_types.TokenCount:
model = model_types.make_model_name(model)
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

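The only change here is the return annotation: count_message_tokens already returned a plain dict, and discuss_types.TokenCount (added below) now documents its shape. A sketch of a call, with an illustrative model name:

from google.generativeai import discuss

count = discuss.count_message_tokens(
    prompt="How many tokens is this?",
    model="models/chat-bison-001",
)
print(count["token_count"])  # an int, per the TokenCount TypedDict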
35 changes: 28 additions & 7 deletions google/generativeai/models.py
@@ -74,7 +74,7 @@ def get_base_model(name: model_types.BaseModelNameOptions, *, client=None) -> model_types.Model:

name = model_types.make_model_name(name)
if not name.startswith("models/"):
raise ValueError("Base model names must start with `models/`")
raise ValueError(f"Base model names must start with `models/`, got: {name}")

result = client.get_model(name=name)
result = type(result).to_dict(result)
@@ -112,6 +112,31 @@ def get_tuned_model(
return model_types.decode_tuned_model(result)


def get_base_model_name(
model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None
):
if isinstance(model, str):
if model.startswith("tunedModels/"):
model = get_model(model, client=client)
base_model = model.base_model
else:
base_model = model
elif isinstance(model, model_types.TunedModel):
base_model = model.base_model
elif isinstance(model, model_types.Model):
base_model = model.name
elif isinstance(model, glm.Model):
base_model = model.name
elif isinstance(model, glm.TunedModel):
base_model = getattr(model, "base_model", None)
if not base_model:
base_model = model.tuned_model_source.base_model
else:
raise TypeError(f"Cannot understand model: {model}")

return base_model


def _list_base_models_next_page(page_size, page_token, client):
"""Returns the next page of the base model or tuned model list."""
result = client.list_models(page_size=page_size, page_token=page_token)
@@ -270,18 +295,14 @@ def create_tuned_model(
client = get_default_model_client()

source_model_name = model_types.make_model_name(source_model)
base_model_name = get_base_model_name(source_model)
if source_model_name.startswith("models/"):
source_model = {"base_model": source_model_name}
elif source_model_name.startswith("tunedModels/"):
source_model = client.get_tuned_model(name=source_model_name)
base_model = source_model.base_model
if not base_model:
base_model = source_model.tuned_model_source.base_model

source_model = {
"tuned_model_source": {
"tuned_model": source_model_name,
"base_model": base_model,
"base_model": base_model_name,
}
}
else:
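get_base_model_name centralizes the base-model resolution that create_tuned_model previously inlined (see the simplification in the hunk above). A sketch of its resolution rules; glm is the google.ai.generativelanguage client package as used in the diff, the tuned-model values are illustrative, and resolving a "tunedModels/" string performs a get_model() call against the service:

import google.ai.generativelanguage as glm
from google.generativeai import models

# Base-model names pass through unchanged.
models.get_base_model_name("models/text-bison-001")  # -> "models/text-bison-001"

# A glm.TunedModel resolves via base_model, falling back to
# tuned_model_source.base_model for models tuned from other tuned models.
tuned = glm.TunedModel(
    name="tunedModels/bipedal-pangolin-001",
    base_model="models/text-bison-001",
)
models.get_base_model_name(tuned)  # -> "models/text-bison-001"

# A "tunedModels/..." string is fetched remotely first, then resolved
# the same way.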
18 changes: 18 additions & 0 deletions google/generativeai/text.py
@@ -24,6 +24,7 @@
from google.generativeai import string_utils
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types

DEFAULT_TEXT_MODEL = "models/text-bison-001"
@@ -217,6 +218,23 @@ def _generate_response(
return Completion(_client=client, **response)


def count_text_tokens(
model: model_types.AnyModelNameOptions,
prompt: str,
client: glm.TextServiceClient | None = None,
) -> text_types.TokenCount:
base_model = models.get_base_model_name(model)

if client is None:
client = get_default_text_client()

result = client.count_text_tokens(
glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt})
)

return type(result).to_dict(result)


@overload
def generate_embeddings(
model: model_types.BaseModelNameOptions,
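Because count_text_tokens resolves the base model before building the request, tuned-model names work transparently: the CountTextTokensRequest is always issued against the underlying base model. A sketch, with an illustrative tuned-model name:

import google.generativeai as genai

# One extra service round-trip happens here: the tuned model is fetched
# to find its base_model, then tokens are counted against that base model.
genai.count_text_tokens(
    model="tunedModels/bipedal-pangolin-001",
    prompt="Tell me a story about a magic backpack.",
)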
4 changes: 4 additions & 0 deletions google/generativeai/types/discuss_types.py
@@ -39,6 +39,10 @@
]


class TokenCount(TypedDict):
token_count: int


class MessageDict(TypedDict):
"""A dict representation of a `glm.Message`."""

2 changes: 1 addition & 1 deletion google/generativeai/types/model_types.py
@@ -243,7 +243,7 @@ def make_model_name(name: AnyModelNameOptions):
raise TypeError("Expected: str, Model, or TunedModel")

if not (name.startswith("models/") or name.startswith("tunedModels/")):
raise ValueError("Model names should start with `models/` or `tunedModels/`")
raise ValueError("Model names should start with `models/` or `tunedModels/`, got: {name}")

return name

4 changes: 4 additions & 0 deletions google/generativeai/types/text_types.py
@@ -26,6 +26,10 @@
__all__ = ["Completion"]


class TokenCount(TypedDict):
token_count: int


class EmbeddingDict(TypedDict):
embedding: list[float]

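TokenCount is a TypedDict rather than a class, so call sites keep receiving plain dicts while static checkers can verify the key and value types. A minimal illustration:

from google.generativeai.types import text_types

def report(count: text_types.TokenCount) -> None:
    # A type checker knows "token_count" exists and is an int.
    print(f"{count['token_count']} tokens")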
106 changes: 93 additions & 13 deletions tests/test_text.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy
import unittest
import unittest.mock as mock

@@ -22,6 +21,7 @@
from google.generativeai import text as text_service
from google.generativeai import client
from google.generativeai.types import safety_types
from google.generativeai.types import model_types
from absl.testing import absltest
from absl.testing import parameterized

@@ -31,8 +31,9 @@ def setUp(self):
self.client = unittest.mock.MagicMock()

client._client_manager.text_client = self.client
client._client_manager.model_client = self.client

self.observed_request = None
self.observed_requests = []

self.responses = {}

@@ -45,23 +46,37 @@ def add_client_method(f):
def generate_text(
request: glm.GenerateTextRequest,
) -> glm.GenerateTextResponse:
self.observed_request = request
self.observed_requests.append(request)
return self.responses["generate_text"]

@add_client_method
def embed_text(
request: glm.EmbedTextRequest,
) -> glm.EmbedTextResponse:
self.observed_request = request
self.observed_requests.append(request)
return self.responses["embed_text"]

@add_client_method
def batch_embed_text(
request: glm.EmbedTextRequest,
) -> glm.EmbedTextResponse:
self.observed_request = request
self.observed_requests.append(request)
return self.responses["batch_embed_text"]

@add_client_method
def count_text_tokens(
request: glm.CountTextTokensRequest,
) -> glm.CountTextTokensResponse:
self.observed_requests.append(request)
return self.responses["count_text_tokens"]

@add_client_method
def get_tuned_model(name) -> glm.TunedModel:
request = glm.GetTunedModelRequest(name=name)
self.observed_requests.append(request)
response = copy.copy(self.responses["get_tuned_model"])
return response

@parameterized.named_parameters(
[
dict(testcase_name="string", prompt="Hello how are"),
@@ -99,7 +114,7 @@ def test_generate_embeddings(self, model, text):
emb = text_service.generate_embeddings(model=model, text=text)

self.assertIsInstance(emb, dict)
self.assertEqual(self.observed_request, glm.EmbedTextRequest(model=model, text=text))
self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text))
self.assertIsInstance(emb["embedding"][0], float)

@parameterized.named_parameters(
@@ -123,8 +138,7 @@ def test_generate_embeddings_batch(self, model, text):

self.assertIsInstance(emb, dict)
self.assertEqual(
self.observed_request,
glm.BatchEmbedTextRequest(model=model, texts=text),
self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
)
self.assertIsInstance(emb["embedding"][0], list)

@@ -160,7 +174,7 @@ def test_generate_response(self, *, prompt, **kwargs):
complete = text_service.generate_text(prompt=prompt, **kwargs)

self.assertEqual(
self.observed_request,
self.observed_requests[-1],
glm.GenerateTextRequest(
model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs
),
@@ -188,15 +202,15 @@ def test_stop_string(self):
complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")

self.assertEqual(
self.observed_request,
self.observed_requests[-1],
glm.GenerateTextRequest(
model="models/text-bison-001",
prompt=glm.TextPrompt(text="Hello"),
stop_sequences=["stop"],
),
)
# Just make sure it made it into the request object.
self.assertEqual(self.observed_request.stop_sequences, ["stop"])
self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"])

@parameterized.named_parameters(
[
@@ -251,7 +265,7 @@ def test_safety_settings(self, safety_settings):
)

self.assertEqual(
self.observed_request.safety_settings[0].category,
self.observed_requests[-1].safety_settings[0].category,
safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
)

@@ -367,6 +381,72 @@ def test_candidate_citations(self):
6,
)

@parameterized.named_parameters(
[
dict(testcase_name="base-name", model="models/text-bison-001"),
dict(testcase_name="tuned-name", model="tunedModels/bipedal-pangolin-001"),
dict(
testcase_name="model",
model=model_types.Model(
name="models/text-bison-001",
base_model_id="text-bison-001",
version="001",
display_name="🦬",
description="🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬",
input_token_limit=8000,
output_token_limit=4000,
supported_generation_methods=["GenerateText"],
),
),
dict(
testcase_name="tuned_model",
model=model_types.TunedModel(
name="tunedModels/bipedal-pangolin-001",
base_model="models/text-bison-001",
),
),
dict(
testcase_name="glm_model",
model=glm.Model(
name="models/text-bison-001",
),
),
dict(
testcase_name="glm_tuned_model",
model=glm.TunedModel(
name="tunedModels/bipedal-pangolin-001",
base_model="models/text-bison-001",
),
),
dict(
testcase_name="glm_tuned_model_nested",
model=glm.TunedModel(
name="tunedModels/bipedal-pangolin-002",
tuned_model_source={
"tuned_model": "tunedModels/bipedal-pangolin-002",
"base_model": "models/text-bison-001",
},
),
),
]
)
def test_count_text_tokens(self, model):
self.responses["get_tuned_model"] = glm.TunedModel(
name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001"
)
self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7)

response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.")
self.assertEqual({"token_count": 7}, response)

should_look_up_model = isinstance(model, str) and model.startswith("tunedModels/")
if should_look_up_model:
self.assertLen(self.observed_requests, 2)
self.assertEqual(
self.observed_requests[0],
glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"),
)


if __name__ == "__main__":
absltest.main()
