Skip to content

Commit

Permalink
feat: LLM - Added the count_tokens method to the preview `TextGenerationModel` and `TextEmbeddingModel` classes
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 570108703
  • Loading branch information
sararob authored and copybara-github committed Oct 2, 2023
1 parent 69a67f2 commit 6a2f2aa
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 6 deletions.
12 changes: 12 additions & 0 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ def test_text_generation(self):
stop_sequences=["# %%"],
).text

def test_text_generation_preview_count_tokens(self):
    """System test: count_tokens on the preview TextGenerationModel."""
    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    text_model = preview_language_models.TextGenerationModel.from_pretrained(
        "google/text-bison@001"
    )

    count_response = text_model.count_tokens(["How are you doing?"])

    # A non-empty prompt must yield non-zero counts for both fields.
    assert count_response.total_tokens
    assert count_response.total_billable_characters

@pytest.mark.asyncio
async def test_text_generation_model_predict_async(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
Expand Down
90 changes: 90 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
model as gca_model,
)

from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as prediction_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.types import (
prediction_service as gca_prediction_service_v1beta1,
)

import vertexai
from vertexai.preview import (
language_models as preview_language_models,
Expand Down Expand Up @@ -306,6 +313,11 @@ def reverse_string_2(s):""",
}
}

# Canned CountTokens API response values shared by the count_tokens unit tests.
_TEST_COUNT_TOKENS_RESPONSE = {
    "total_tokens": 5,
    "total_billable_characters": 25,
}


_TEST_TEXT_BISON_TRAINING_DF = pd.DataFrame(
{
Expand Down Expand Up @@ -1206,6 +1218,43 @@ def test_text_generation(self):
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
)

def test_text_generation_preview_count_tokens(self):
    """Unit test: count_tokens on the preview TextGenerationModel."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _TEXT_BISON_PUBLISHER_MODEL_DICT
        ),
    ):
        model = preview_language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    # Canned v1beta1 CountTokens response that the mocked client returns.
    fake_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
        total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
        total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
            "total_billable_characters"
        ],
    )

    with mock.patch.object(
        target=prediction_service_client_v1beta1.PredictionServiceClient,
        attribute="count_tokens",
        return_value=fake_count_tokens_response,
    ):
        response = model.count_tokens(["What is the best recipe for banana bread?"])

    assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
    assert (
        response.total_billable_characters
        == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
    )

def test_text_generation_ga(self):
"""Tests the text generation model."""
aiplatform.init(
Expand Down Expand Up @@ -2469,6 +2518,47 @@ def test_text_embedding(self):
== expected_embedding["statistics"]["truncated"]
)

def test_text_embedding_preview_count_tokens(self):
    """Unit test: count_tokens on the preview TextEmbeddingModel."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
        ),
    ):
        model = preview_language_models.TextEmbeddingModel.from_pretrained(
            "textembedding-gecko@001"
        )

    # Canned v1beta1 CountTokens response that the mocked client returns.
    fake_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
        total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
        total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
            "total_billable_characters"
        ],
    )

    with mock.patch.object(
        target=prediction_service_client_v1beta1.PredictionServiceClient,
        attribute="count_tokens",
        return_value=fake_count_tokens_response,
    ):
        response = model.count_tokens(["What is life?"])

    assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
    assert (
        response.total_billable_characters
        == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
    )

def test_text_embedding_ga(self):
"""Tests the text embedding model."""
aiplatform.init(
Expand Down
94 changes: 88 additions & 6 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ def _model_resource_name(self) -> str:
@dataclasses.dataclass
class _PredictionRequest:
    """A single-instance prediction request.

    Attributes:
        instance: JSON-serializable payload for the single instance to predict on.
        parameters: Optional model parameters to send with the prediction call.
    """

    instance: Dict[str, Any]
    parameters: Optional[Dict[str, Any]] = None


@dataclasses.dataclass
class _MultiInstancePredictionRequest:
    """A multi-instance prediction request.

    Attributes:
        instances: JSON-serializable payloads, one per instance to predict on.
        parameters: Optional model parameters shared by all instances in the call.
    """

    instances: List[Dict[str, Any]]
    parameters: Optional[Dict[str, Any]] = None

Expand Down Expand Up @@ -573,6 +575,62 @@ def tune_model(
return job


@dataclasses.dataclass
class CountTokensResponse:
    """The response from a count_tokens request.

    Attributes:
        total_tokens (int):
            The total number of tokens counted across all
            instances passed to the request.
        total_billable_characters (int):
            The total number of billable characters
            counted across all instances from the request.
    """

    total_tokens: int
    total_billable_characters: int
    # Raw CountTokensResponse proto from the service, kept for advanced inspection.
    _count_tokens_response: Any


class _CountTokensMixin(_LanguageModel):
    """Mixin for models that support the CountTokens API"""

    def count_tokens(
        self,
        prompts: List[str],
    ) -> CountTokensResponse:
        """Counts the tokens and billable characters for a given prompt.

        Note: this does not make a request to the model, it only counts the tokens
        in the request.

        Args:
            prompts (List[str]):
                Required. A list of prompts to ask the model. For example: ["What should I do today?", "How's it going?"]

        Returns:
            A `CountTokensResponse` object that contains the number of tokens
            in the text and the number of billable characters.
        """
        # Each prompt becomes one prediction-style instance payload.
        instances = [{"content": prompt} for prompt in prompts]

        # CountTokens is only exposed on the v1beta1 API surface.
        beta_client = self._endpoint._prediction_client.select_version("v1beta1")
        raw_response = beta_client.count_tokens(
            endpoint=self._endpoint_name,
            instances=instances,
        )

        return CountTokensResponse(
            total_tokens=raw_response.total_tokens,
            total_billable_characters=raw_response.total_billable_characters,
            _count_tokens_response=raw_response,
        )


@dataclasses.dataclass
class TuningEvaluationSpec:
"""Specification for model evaluation to perform during tuning.
Expand All @@ -587,6 +645,7 @@ class TuningEvaluationSpec:
tensorboard: Vertex Tensorboard where to write the evaluation metrics.
The Tensorboard must be in the same location as the tuning job.
"""

__module__ = "vertexai.language_models"

evaluation_data: str
Expand All @@ -605,6 +664,7 @@ class TextGenerationResponse:
Learn more about the safety attributes here:
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
"""

__module__ = "vertexai.language_models"

text: str
Expand Down Expand Up @@ -761,7 +821,9 @@ def predict_streaming(
)

prediction_service_client = self._endpoint._prediction_client
for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
for (
prediction_dict
) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
prediction_service_client=prediction_service_client,
endpoint_name=self._endpoint_name,
instance=prediction_request.instance,
Expand Down Expand Up @@ -955,6 +1017,7 @@ class _PreviewTextGenerationModel(
_PreviewTunableTextModelMixin,
_PreviewModelWithBatchPredict,
_evaluatable_language_models._EvaluatableLanguageModel,
_CountTokensMixin,
):
# Do not add docstring so that it's inherited from the base class.
__name__ = "TextGenerationModel"
Expand Down Expand Up @@ -1094,6 +1157,7 @@ class TextEmbeddingInput:
Specifies that the embeddings will be used for clustering.
title: Optional identifier of the text content.
"""

__module__ = "vertexai.language_models"

text: str
Expand All @@ -1113,6 +1177,7 @@ class TextEmbeddingModel(_LanguageModel):
vector = embedding.values
print(len(vector))
"""

__module__ = "vertexai.language_models"

_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
Expand Down Expand Up @@ -1173,7 +1238,8 @@ def _parse_text_embedding_response(
_prediction_response=prediction_response,
)

def get_embeddings(self,
def get_embeddings(
self,
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
Expand Down Expand Up @@ -1207,7 +1273,8 @@ def get_embeddings(self,

return results

async def get_embeddings_async(self,
async def get_embeddings_async(
self,
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
Expand Down Expand Up @@ -1242,7 +1309,9 @@ async def get_embeddings_async(self,
return results


class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
class _PreviewTextEmbeddingModel(
TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
):
__name__ = "TextEmbeddingModel"
__module__ = "vertexai.preview.language_models"

Expand All @@ -1252,6 +1321,7 @@ class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
@dataclasses.dataclass
class TextEmbeddingStatistics:
"""Text embedding statistics."""

__module__ = "vertexai.language_models"

token_count: int
Expand All @@ -1261,6 +1331,7 @@ class TextEmbeddingStatistics:
@dataclasses.dataclass
class TextEmbedding:
"""Text embedding vector and statistics."""

__module__ = "vertexai.language_models"

values: List[float]
Expand All @@ -1271,6 +1342,7 @@ class TextEmbedding:
@dataclasses.dataclass
class InputOutputTextPair:
"""InputOutputTextPair represents a pair of input and output texts."""

__module__ = "vertexai.language_models"

input_text: str
Expand All @@ -1285,6 +1357,7 @@ class ChatMessage:
content: Content of the message.
author: Author of the message.
"""

__module__ = "vertexai.language_models"

content: str
Expand Down Expand Up @@ -1362,6 +1435,7 @@ class ChatModel(_ChatModelBase, _TunableChatModelMixin):
chat.send_message("Do you know any cool events this weekend?")
"""

__module__ = "vertexai.language_models"

_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
Expand All @@ -1388,6 +1462,7 @@ class CodeChatModel(_ChatModelBase):
code_chat.send_message("Please help write a function to calculate the min of two numbers")
"""

__module__ = "vertexai.language_models"

_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml"
Expand Down Expand Up @@ -1739,7 +1814,9 @@ def send_message_streaming(

full_response_text = ""

for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
for (
prediction_dict
) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
prediction_service_client=prediction_service_client,
endpoint_name=self._model._endpoint_name,
instance=prediction_request.instance,
Expand Down Expand Up @@ -1770,6 +1847,7 @@ class ChatSession(_ChatSessionBase):
Within a chat session, the model keeps context and remembers the previous conversation.
"""

__module__ = "vertexai.language_models"

def __init__(
Expand Down Expand Up @@ -1802,6 +1880,7 @@ class CodeChatSession(_ChatSessionBase):
Within a code chat session, the model keeps context and remembers the previous converstion.
"""

__module__ = "vertexai.language_models"

def __init__(
Expand Down Expand Up @@ -1924,6 +2003,7 @@ class CodeGenerationModel(_LanguageModel):
prefix="def reverse_string(s):",
))
"""

__module__ = "vertexai.language_models"

_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
Expand Down Expand Up @@ -2074,7 +2154,9 @@ def predict_streaming(
)

prediction_service_client = self._endpoint._prediction_client
for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
for (
prediction_dict
) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
prediction_service_client=prediction_service_client,
endpoint_name=self._endpoint_name,
instance=prediction_request.instance,
Expand Down
Loading

0 comments on commit 6a2f2aa

Please sign in to comment.