From c29fa5d98fc1c1550c039e84ff2f5725818c2231 Mon Sep 17 00:00:00 2001 From: Qiao Wang Date: Wed, 4 Sep 2024 11:34:43 -0700 Subject: [PATCH] fix: Tokenizers - Fixed `Tokenizer.compute_tokens` PiperOrigin-RevId: 671042779 --- tests/system/vertexai/test_tokenization.py | 39 +++++++++++++++------- vertexai/tokenization/_tokenizers.py | 31 ++++++++--------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/tests/system/vertexai/test_tokenization.py b/tests/system/vertexai/test_tokenization.py index 9dacfc69d4..2795831a95 100644 --- a/tests/system/vertexai/test_tokenization.py +++ b/tests/system/vertexai/test_tokenization.py @@ -20,7 +20,10 @@ from nltk.corpus import udhr from google.cloud import aiplatform from vertexai.preview.tokenization import ( - get_tokenizer_for_model, + get_tokenizer_for_model as tokenizer_preview, +) +from vertexai.tokenization._tokenizers import ( + get_tokenizer_for_model as tokenizer_ga, ) from vertexai.generative_models import ( GenerativeModel, @@ -44,8 +47,10 @@ _CORPUS_LIB = [ udhr, ] +_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga] _MODEL_CORPUS_PARAMS = [ - (model_name, corpus_name, corpus_lib) + (get_tokenizer_for_model, model_name, corpus_name, corpus_lib) + for get_tokenizer_for_model in _VERSIONED_TOKENIZER for model_name in _MODELS for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB) ] @@ -125,11 +130,16 @@ def setup_method(self, api_endpoint_env_name): ) @pytest.mark.parametrize( - "model_name, corpus_name, corpus_lib", + "get_tokenizer_for_model, model_name, corpus_name, corpus_lib", _MODEL_CORPUS_PARAMS, ) def test_count_tokens_local( - self, model_name, corpus_name, corpus_lib, api_endpoint_env_name + self, + get_tokenizer_for_model, + model_name, + corpus_name, + corpus_lib, + api_endpoint_env_name, ): # The Gemini 1.5 flash model requires the model version # number suffix (001) in staging only @@ -145,11 +155,16 @@ def test_count_tokens_local( assert service_result.total_tokens == local_result.total_tokens @pytest.mark.parametrize( - "model_name, corpus_name, corpus_lib", + "get_tokenizer_for_model, model_name, corpus_name, corpus_lib", _MODEL_CORPUS_PARAMS, ) def test_compute_tokens( - self, model_name, corpus_name, corpus_lib, api_endpoint_env_name + self, + get_tokenizer_for_model, + model_name, + corpus_name, + corpus_lib, + api_endpoint_env_name, ): # The Gemini 1.5 flash model requires the model version # number suffix (001) in staging only @@ -171,7 +186,7 @@ def test_compute_tokens( _MODELS, ) def test_count_tokens_system_instruction(self, model_name): - tokenizer = get_tokenizer_for_model(model_name) + tokenizer = tokenizer_preview(model_name) model = GenerativeModel(model_name, system_instruction=["You are a chatbot."]) assert ( @@ -188,7 +203,7 @@ def test_count_tokens_system_instruction(self, model_name): def test_count_tokens_system_instruction_is_function_call(self, model_name): part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL)) - tokenizer = get_tokenizer_for_model(model_name) + tokenizer = tokenizer_preview(model_name) model = GenerativeModel(model_name, system_instruction=[part]) assert ( @@ -204,7 +219,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name): part = Part._from_gapic( gapic_content_types.Part(function_response=_FUNCTION_RESPONSE) ) - tokenizer = get_tokenizer_for_model(model_name) + tokenizer = tokenizer_preview(model_name) model = GenerativeModel(model_name, system_instruction=[part]) assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens @@ -218,7 +233,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name): _MODELS, ) def test_count_tokens_tool_is_function_declaration(self, model_name): - tokenizer = get_tokenizer_for_model(model_name) + tokenizer = tokenizer_preview(model_name) model = GenerativeModel(model_name) tool1 = Tool._from_gapic( gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1]) @@ -241,7 +256,7 @@ def test_count_tokens_tool_is_function_declaration(self, model_name): ) def test_count_tokens_content_is_function_call(self, model_name): part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL)) - tokenizer = get_tokenizer_for_model(model_name) + tokenizer = tokenizer_preview(model_name) model = GenerativeModel(model_name) assert tokenizer.count_tokens(part).total_tokens @@ -258,7 +273,7 @@ def test_count_tokens_content_is_function_response(self, model_name): part = Part._from_gapic( gapic_content_types.Part(function_response=_FUNCTION_RESPONSE) ) - tokenizer = get_tokenizer_for_model(model_name) + tokenizer = tokenizer_preview(model_name) model = GenerativeModel(model_name) assert tokenizer.count_tokens(part).total_tokens diff --git a/vertexai/tokenization/_tokenizers.py b/vertexai/tokenization/_tokenizers.py index 1a88c97d01..17e2b5b823 100644 --- a/vertexai/tokenization/_tokenizers.py +++ b/vertexai/tokenization/_tokenizers.py @@ -53,20 +53,6 @@ class TokensInfo: role: str = None -@dataclasses.dataclass(frozen=True) -class ComputeTokensResult: - tokens_info: Sequence[TokensInfo] - - -class PreviewComputeTokensResult(ComputeTokensResult): - def token_info_list(self) -> Sequence[TokensInfo]: - import warnings - - message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead." - warnings.warn(message, DeprecationWarning, stacklevel=2) - return self.tokens_info - - @dataclasses.dataclass(frozen=True) class ComputeTokensResult: """Represents token string pieces and ids output in compute_tokens function. @@ -78,11 +64,18 @@ class ComputeTokensResult: item represents each string instance. Each token info consists tokens list, token_ids list and a role. - token_info_list: the value in this field equal to tokens_info. """ tokens_info: Sequence[TokensInfo] - token_info_list: Sequence[TokensInfo] + + +class PreviewComputeTokensResult(ComputeTokensResult): + def token_info_list(self) -> Sequence[TokensInfo]: + import warnings + + message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead." + warnings.warn(message, DeprecationWarning, stacklevel=2) + return self.tokens_info @dataclasses.dataclass(frozen=True) @@ -169,7 +162,7 @@ def compute_tokens( role=role, ) ) - return ComputeTokensResult(token_info_list=token_infos, tokens_info=token_infos) + return ComputeTokensResult(tokens_info=token_infos) def _to_gapic_contents( @@ -539,7 +532,9 @@ def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult: class PreviewTokenizer(Tokenizer): def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult: - return PreviewComputeTokensResult(tokens_info=super().compute_tokens(contents)) + return PreviewComputeTokensResult( + tokens_info=super().compute_tokens(contents).tokens_info + ) def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer: