From 14884a3d2df3c254ef322c5d7ecacb8e9f83b859 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 4 Oct 2023 13:20:36 -0700 Subject: [PATCH] Add count_text_tokens, and expose operations. --- google/generativeai/__init__.py | 5 ++ google/generativeai/models.py | 19 +++++- google/generativeai/text.py | 16 +++++ google/generativeai/types/model_types.py | 2 +- tests/test_text.py | 75 ++++++++++++++++++++---- 5 files changed, 104 insertions(+), 13 deletions(-) diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 452fc997c..74026804a 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -77,6 +77,7 @@ from google.generativeai.text import generate_text from google.generativeai.text import generate_embeddings +from google.generativeai.text import count_text_tokens from google.generativeai.models import list_models from google.generativeai.models import list_tuned_models @@ -89,6 +90,10 @@ from google.generativeai.models import update_tuned_model from google.generativeai.models import delete_tuned_model +from google.generativeai.operations import list_operations +from google.generativeai.operations import get_operation + + from google.generativeai.client import configure __version__ = version.__version__ diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 34413ecdc..73a82f6e2 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -76,7 +76,7 @@ def get_base_model( name = model_types.make_model_name(name) if not name.startswith("models/"): - raise ValueError("Base model names must start with `models/`") + raise ValueError(f"Base model names must start with `models/`, got: {name}") result = client.get_model(name=name) result = type(result).to_dict(result) @@ -114,6 +114,23 @@ def get_tuned_model( return model_types.decode_tuned_model(result) +def get_base_model_name( + model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None +): + if isinstance(model, str): + if model.startswith("tunedModels/"): + model = get_model(model, client=client) + base_model = model.base_model + else: + base_model = model + elif isinstance(model, model_types.TunedModel): + base_model = model.base_model + elif isinstance(model, model_types.Model): + base_model = model.name + + return base_model + + def _list_base_models_next_page(page_size, page_token, client): """Returns the next page of the base model or tuned model list.""" result = client.list_models(page_size=page_size, page_token=page_token) diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 8782d910f..079b5174a 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -23,6 +23,7 @@ from google.generativeai.client import get_default_text_client from google.generativeai.types import text_types from google.generativeai.types import model_types +from google.generativeai import models from google.generativeai.types import safety_types DEFAULT_TEXT_MODEL = "models/text-bison-001" @@ -217,6 +218,21 @@ def _generate_response( return Completion(_client=client, **response) +def count_text_tokens( + model: model_types.ModelNameOptions, + prompt: str, + client: glm.TextServiceClient = None, +): + base_model = models.get_base_model_name(model) + + if client is None: + client = get_default_text_client() + + result = client.count_text_tokens(glm.CountTextTokensRequest(model=base_model, prompt={'text':prompt})) + + return type(result).to_dict(result) + + @overload def generate_embeddings( model: model_types.BaseModelNameOptions, diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 906bd3458..d1c6b063b 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -239,7 +239,7 @@ def make_model_name(name: AnyModelNameOptions): raise TypeError("Expected: str, Model, or TunedModel") if not (name.startswith("models/") or name.startswith("tunedModels/")): - raise ValueError("Model names should start with `models/` or `tunedModels/`") + raise ValueError("Model names should start with `models/` or `tunedModels/`, got: {name}") return name diff --git a/tests/test_text.py b/tests/test_text.py index 76805b03a..5c9a39665 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import collections +import copy import os import unittest import unittest.mock as mock @@ -22,6 +23,7 @@ from google.generativeai import text as text_service from google.generativeai import client from google.generativeai.types import safety_types +from google.generativeai.types import model_types from absl.testing import absltest from absl.testing import parameterized @@ -31,8 +33,9 @@ def setUp(self): self.client = unittest.mock.MagicMock() client.default_text_client = self.client + client.default_model_client = self.client - self.observed_request = None + self.observed_requests = [] self.responses = {} @@ -45,23 +48,38 @@ def add_client_method(f): def generate_text( request: glm.GenerateTextRequest, ) -> glm.GenerateTextResponse: - self.observed_request = request + self.observed_requests.append(request) return self.responses["generate_text"] @add_client_method def embed_text( request: glm.EmbedTextRequest, ) -> glm.EmbedTextResponse: - self.observed_request = request + self.observed_requests.append(request) return self.responses["embed_text"] @add_client_method def batch_embed_text( request: glm.EmbedTextRequest, ) -> glm.EmbedTextResponse: - self.observed_request = request + self.observed_requests.append(request) return self.responses["batch_embed_text"] + @add_client_method + def count_text_tokens( + request: glm.CountTextTokensRequest + ) -> glm.CountTextTokensResponse: + self.observed_requests.append(request) + return self.responses["count_text_tokens"] + + @add_client_method + def get_tuned_model( + name + ) -> glm.TunedModel: + request = glm.GetTunedModelRequest(name=name) + self.observed_requests.append(request) + response = copy.copy(self.responses["get_tuned_model"]) + return response @parameterized.named_parameters( [ dict(testcase_name="string", prompt="Hello how are"), @@ -102,7 +120,7 @@ def test_generate_embeddings(self, model, text): self.assertIsInstance(emb, dict) self.assertEqual( - self.observed_request, glm.EmbedTextRequest(model=model, text=text) + self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text) ) self.assertIsInstance(emb["embedding"][0], float) @@ -124,7 +142,7 @@ def test_generate_embeddings_batch(self, model, text): self.assertIsInstance(emb, dict) self.assertEqual( - self.observed_request, glm.BatchEmbedTextRequest(model=model, texts=text) + self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text) ) self.assertIsInstance(emb["embedding"][0], list) @@ -160,7 +178,7 @@ def test_generate_response(self, *, prompt, **kwargs): complete = text_service.generate_text(prompt=prompt, **kwargs) self.assertEqual( - self.observed_request, + self.observed_requests[-1], glm.GenerateTextRequest( model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), @@ -190,7 +208,7 @@ def test_stop_string(self): complete = text_service.generate_text(prompt="Hello", stop_sequences="stop") self.assertEqual( - self.observed_request, + self.observed_requests[-1], glm.GenerateTextRequest( model="models/text-bison-001", prompt=glm.TextPrompt(text="Hello"), @@ -198,7 +216,7 @@ def test_stop_string(self): ), ) # Just make sure it made it into the request object. - self.assertEqual(self.observed_request.stop_sequences, ["stop"]) + self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"]) @parameterized.named_parameters( [ @@ -253,7 +271,7 @@ def test_safety_settings(self, safety_settings): ) self.assertEqual( - self.observed_request.safety_settings[0].category, + self.observed_requests[-1].safety_settings[0].category, safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) @@ -370,6 +388,41 @@ def test_candidate_citations(self): 6, ) + @parameterized.named_parameters( + [ + dict( + testcase_name="base-name", + model = 'models/text-bison-001' + ), + dict( + testcase_name="tuned-name", + model="tunedModels/bipedal-pangolin-001" + ), + dict( + testcase_name="model", + model = model_types.Model(name="models/text-bison-001", + base_model_id=None, version=None, display_name=None, + description=None, input_token_limit=None, output_token_limit=None, + supported_generation_methods=None) + ), + dict( + testcase_name="tuned_model", + model = model_types.TunedModel(name="tunedModels/bipedal-pangolin-001", base_model='models/text-bison-001') + ), + ] + ) + def test_count_message_tokens(self, model): + self.responses["get_tuned_model"] = glm.TunedModel( + name="tunedModels/bipedal-pangolin-001", base_model='models/text-bison-001' + ) + self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7) + + response = text_service.count_text_tokens(model, 'Tell me a story about a magic backpack.') + self.assertEqual({'token_count':7}, response) + if len(self.observed_requests) > 1: + self.assertEqual(self.observed_requests[0], glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001")) + + if __name__ == "__main__": absltest.main()