Add count_text_tokens, and expose operations.
MarkDaoust committed Oct 4, 2023
1 parent 2a33920 commit 14884a3
Showing 5 changed files with 104 additions and 13 deletions.
5 changes: 5 additions & 0 deletions google/generativeai/__init__.py
@@ -77,6 +77,7 @@

from google.generativeai.text import generate_text
from google.generativeai.text import generate_embeddings
from google.generativeai.text import count_text_tokens

from google.generativeai.models import list_models
from google.generativeai.models import list_tuned_models
@@ -89,6 +90,10 @@
from google.generativeai.models import update_tuned_model
from google.generativeai.models import delete_tuned_model

from google.generativeai.operations import list_operations
from google.generativeai.operations import get_operation


from google.generativeai.client import configure

__version__ = version.__version__
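Taken together, these import changes expose the new helpers at the top of the package. A minimal usage sketch, assuming an API key is configured; the model name and prompt are illustrative, and the exact signatures of the operations helpers are not shown in this diff:

import google.generativeai as genai

genai.configure(api_key="...")  # assumed: a valid API key is available

# New: count tokens without running generation.
info = genai.count_text_tokens(
    model="models/text-bison-001",
    prompt="Tell me a story about a magic backpack.",
)
print(info)  # e.g. {'token_count': 7}

# New: long-running operations are now reachable from the top level.
for op in genai.list_operations():  # signature assumed, not shown in this diff
    print(op)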
19 changes: 18 additions & 1 deletion google/generativeai/models.py
@@ -76,7 +76,7 @@ def get_base_model(

name = model_types.make_model_name(name)
if not name.startswith("models/"):
raise ValueError("Base model names must start with `models/`")
raise ValueError(f"Base model names must start with `models/`, got: {name}")

result = client.get_model(name=name)
result = type(result).to_dict(result)
@@ -114,6 +114,23 @@ def get_tuned_model(
return model_types.decode_tuned_model(result)


def get_base_model_name(
model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None
):
"""Resolves a model name, `Model`, or `TunedModel` to the underlying base model name."""
if isinstance(model, str):
if model.startswith("tunedModels/"):
model = get_model(model, client=client)
base_model = model.base_model
else:
base_model = model
elif isinstance(model, model_types.TunedModel):
base_model = model.base_model
elif isinstance(model, model_types.Model):
base_model = model.name
else:
raise TypeError(f"Expected: str, Model, or TunedModel, got: {type(model)}")

return base_model


def _list_base_models_next_page(page_size, page_token, client):
"""Returns the next page of the base model or tuned model list."""
result = client.list_models(page_size=page_size, page_token=page_token)
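For reference, the new get_base_model_name helper resolves each accepted input shape as follows; a sketch using the illustrative names from the tests:

from google.generativeai import models

# Plain base-model names pass through unchanged.
models.get_base_model_name("models/text-bison-001")  # -> "models/text-bison-001"

# Tuned-model names trigger a get_model() lookup and return its base_model.
models.get_base_model_name("tunedModels/bipedal-pangolin-001")  # -> e.g. "models/text-bison-001"

# Model and TunedModel instances are resolved locally, with no API call.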
16 changes: 16 additions & 0 deletions google/generativeai/text.py
@@ -23,6 +23,7 @@
from google.generativeai.client import get_default_text_client
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types

DEFAULT_TEXT_MODEL = "models/text-bison-001"
@@ -217,6 +218,21 @@ def _generate_response(
return Completion(_client=client, **response)


def count_text_tokens(
model: model_types.AnyModelNameOptions,
prompt: str,
client: glm.TextServiceClient | None = None,
):
"""Counts the tokens in `prompt` using the tokenizer of the model's base model."""
base_model = models.get_base_model_name(model)

if client is None:
client = get_default_text_client()

result = client.count_text_tokens(
glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt})
)

return type(result).to_dict(result)


@overload
def generate_embeddings(
model: model_types.BaseModelNameOptions,
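Under the hood, count_text_tokens builds a glm.CountTextTokensRequest against the resolved base model. A rough sketch of the equivalent raw client call (the client construction is assumed; the wrapper normally obtains one via get_default_text_client):

import google.ai.generativelanguage as glm

client = glm.TextServiceClient()  # assumed: default credentials are configured
request = glm.CountTextTokensRequest(
    model="models/text-bison-001",
    prompt={"text": "Tell me a story about a magic backpack."},
)
result = client.count_text_tokens(request)
print(type(result).to_dict(result))  # e.g. {'token_count': 7}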
2 changes: 1 addition & 1 deletion google/generativeai/types/model_types.py
@@ -239,7 +239,7 @@ def make_model_name(name: AnyModelNameOptions):
raise TypeError("Expected: str, Model, or TunedModel")

if not (name.startswith("models/") or name.startswith("tunedModels/")):
raise ValueError("Model names should start with `models/` or `tunedModels/`")
raise ValueError("Model names should start with `models/` or `tunedModels/`, got: {name}")

return name

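With the sharpened message, the offending name now appears in the error. An illustrative sketch of both paths (the invalid name is hypothetical):

from google.generativeai.types import model_types

model_types.make_model_name("models/text-bison-001")  # -> "models/text-bison-001"
model_types.make_model_name("text-bison-001")  # raises ValueError mentioning the bad name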
75 changes: 64 additions & 11 deletions tests/test_text.py
@@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import os
import unittest
import unittest.mock as mock
@@ -22,6 +23,7 @@
from google.generativeai import text as text_service
from google.generativeai import client
from google.generativeai.types import safety_types
from google.generativeai.types import model_types
from absl.testing import absltest
from absl.testing import parameterized

@@ -31,8 +33,9 @@ def setUp(self):
self.client = unittest.mock.MagicMock()

client.default_text_client = self.client
client.default_model_client = self.client

self.observed_request = None
self.observed_requests = []

self.responses = {}

@@ -45,23 +48,38 @@ def add_client_method(f):
def generate_text(
request: glm.GenerateTextRequest,
) -> glm.GenerateTextResponse:
self.observed_request = request
self.observed_requests.append(request)
return self.responses["generate_text"]

@add_client_method
def embed_text(
request: glm.EmbedTextRequest,
) -> glm.EmbedTextResponse:
self.observed_request = request
self.observed_requests.append(request)
return self.responses["embed_text"]

@add_client_method
def batch_embed_text(
request: glm.EmbedTextRequest,
) -> glm.EmbedTextResponse:
self.observed_request = request
self.observed_requests.append(request)
return self.responses["batch_embed_text"]

@add_client_method
def count_text_tokens(
request: glm.CountTextTokensRequest,
) -> glm.CountTextTokensResponse:
self.observed_requests.append(request)
return self.responses["count_text_tokens"]

@add_client_method
def get_tuned_model(name) -> glm.TunedModel:
request = glm.GetTunedModelRequest(name=name)
self.observed_requests.append(request)
response = copy.copy(self.responses["get_tuned_model"])
return response

@parameterized.named_parameters(
[
dict(testcase_name="string", prompt="Hello how are"),
@@ -102,7 +120,7 @@ def test_generate_embeddings(self, model, text):

self.assertIsInstance(emb, dict)
self.assertEqual(
self.observed_request, glm.EmbedTextRequest(model=model, text=text)
self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text)
)
self.assertIsInstance(emb["embedding"][0], float)

@@ -124,7 +142,7 @@ def test_generate_embeddings_batch(self, model, text):

self.assertIsInstance(emb, dict)
self.assertEqual(
self.observed_request, glm.BatchEmbedTextRequest(model=model, texts=text)
self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
)
self.assertIsInstance(emb["embedding"][0], list)

@@ -160,7 +178,7 @@ def test_generate_response(self, *, prompt, **kwargs):
complete = text_service.generate_text(prompt=prompt, **kwargs)

self.assertEqual(
self.observed_request,
self.observed_requests[-1],
glm.GenerateTextRequest(
model="models/text-bison-001",
prompt=glm.TextPrompt(text=prompt),
@@ -190,15 +208,15 @@ def test_stop_string(self):
complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")

self.assertEqual(
self.observed_request,
self.observed_requests[-1],
glm.GenerateTextRequest(
model="models/text-bison-001",
prompt=glm.TextPrompt(text="Hello"),
stop_sequences=["stop"],
),
)
# Just make sure it made it into the request object.
self.assertEqual(self.observed_request.stop_sequences, ["stop"])
self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"])

@parameterized.named_parameters(
[
@@ -253,7 +271,7 @@ def test_safety_settings(self, safety_settings):
)

self.assertEqual(
self.observed_request.safety_settings[0].category,
self.observed_requests[-1].safety_settings[0].category,
safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
)

@@ -370,6 +388,41 @@ def test_candidate_citations(self):
6,
)

@parameterized.named_parameters(
[
dict(testcase_name="base-name", model="models/text-bison-001"),
dict(testcase_name="tuned-name", model="tunedModels/bipedal-pangolin-001"),
dict(
testcase_name="model",
model=model_types.Model(
name="models/text-bison-001",
base_model_id=None,
version=None,
display_name=None,
description=None,
input_token_limit=None,
output_token_limit=None,
supported_generation_methods=None,
),
),
dict(
testcase_name="tuned_model",
model=model_types.TunedModel(
name="tunedModels/bipedal-pangolin-001",
base_model="models/text-bison-001",
),
),
]
)
def test_count_text_tokens(self, model):
self.responses["get_tuned_model"] = glm.TunedModel(
name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001"
)
self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7)

response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.")
self.assertEqual({"token_count": 7}, response)
if len(self.observed_requests) > 1:
self.assertEqual(
self.observed_requests[0],
glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"),
)



if __name__ == "__main__":
absltest.main()
