From d10b887e9791274efd4a13272dba53094fd4f06c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johan=20Aires=20Rast=C3=A9n?=
Date: Fri, 3 May 2024 14:45:53 +0200
Subject: [PATCH] add_special option for server tokenize endpoint

---
 examples/server/README.md                     |  2 +-
 examples/server/server.cpp                    |  3 +-
 examples/server/tests/features/server.feature | 13 +++++++-
 examples/server/tests/features/steps/steps.py | 33 ++++++++++++++++++++++++++++++---
 4 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index b96a4444a2bd3b..961c96721813a8 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -319,7 +319,7 @@ Notice that each `probs` is an array of length `n_probs`.
 
     `content`: Set the text to tokenize.
 
-    Note that a special `BOS` token is never inserted.
+    `add_special`: Boolean indicating if special tokens, e.g. `BOS`, should be inserted. Default: `false`
 
 - **POST** `/detokenize`: Convert tokens to text.
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f60530cf3db561..c921168747560f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3632,7 +3632,8 @@ int main(int argc, char ** argv) {
 
             std::vector<llama_token> tokens;
             if (body.count("content") != 0) {
-                tokens = ctx_server.tokenize(body["content"], false);
+                const bool add_special = json_value(body, "add_special", false);
+                tokens = ctx_server.tokenize(body["content"], add_special);
             }
             const json data = format_tokenizer_response(tokens);
             return res.set_content(data.dump(), "application/json; charset=utf-8");
diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index 646a4e49d0d56f..c2b54b840b85a0 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -91,7 +91,18 @@ Feature: llama.cpp server
       """
       What is the capital of France ?
       """
-    Then tokens can be detokenize
+    Then tokens can be detokenized
+    And tokens do not begin with BOS
+
+  Scenario: Tokenize w/ BOS
+    Given adding special tokens
+    When tokenizing:
+      """
+      What is the capital of Germany?
+      """
+    Then tokens begin with BOS
+    Given first token is removed
+    Then tokens can be detokenized
 
   Scenario: Models available
     Given available models
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index b8dbef21d1b768..1ab8cf102989eb 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -656,21 +656,29 @@ async def all_embeddings_are_generated(context):
     assert_embeddings(context.tasks_result.pop().pop())
 
 
+@step('adding special tokens')
+def step_tokenize_set_add_special(context):
+    context.tokenize_add_special = True
+
+
 @step('tokenizing')
 @async_run_until_complete
 async def step_tokenize(context):
     context.tokenized_text = context_text(context)
     async with aiohttp.ClientSession() as session:
+        tokenize_args = {
+            "content": context.tokenized_text,
+        }
+        if getattr(context, 'tokenize_add_special', None) is not None:
+            tokenize_args['add_special'] = context.tokenize_add_special
         async with session.post(f'{context.base_url}/tokenize',
-                                json={
-                                    "content": context.tokenized_text,
-                                }) as response:
+                                json=tokenize_args) as response:
             assert response.status == 200
             tokenize_json = await response.json()
             context.tokens = tokenize_json['tokens']
 
 
-@step('tokens can be detokenize')
+@step('tokens can be detokenized')
 @async_run_until_complete
 async def step_detokenize(context):
     assert len(context.tokens) > 0
@@ -685,6 +693,21 @@ async def step_detokenize(context):
     assert context.tokenized_text == detokenize_json['content'].strip()
 
 
+@step('tokens begin with BOS')
+def step_tokens_begin_with_bos(context):
+    assert context.tokens[0] == context.bos
+
+
+@step('tokens do not begin with BOS')
+def step_tokens_do_not_begin_with_bos(context):
+    assert context.tokens[0] != context.bos
+
+
+@step('first token is removed')
+def step_remove_first_token(context):
+    context.tokens = context.tokens[1:]
+
+
 @step('an OPTIONS request is sent from {origin}')
 @async_run_until_complete
 async def step_options_request(context, origin):
@@ -1289,4 +1312,6 @@ def server_log(in_stream, out_stream):
         thread_stderr = threading.Thread(target=server_log, args=(context.server_process.stderr, sys.stderr))
         thread_stderr.start()
 
+    context.bos = 1
+
     print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}")
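A quick usage note for reviewers: the sketch below shows how a client would exercise the new option once this patch is applied. It is illustrative only, not part of the patch. It assumes a server already listening on `http://localhost:8080` with a model whose `BOS` token id is `1` (the same assumption the tests hard-code as `context.bos = 1`), and it uses the `requests` library purely for convenience.

```python
# Minimal sketch (not part of the patch): exercising the new add_special
# option against a llama.cpp server assumed to be running on localhost:8080.
import requests

BASE_URL = "http://localhost:8080"  # assumption: server already started
TEXT = "What is the capital of France ?"

# Default behaviour is unchanged: no special tokens are inserted.
plain = requests.post(f"{BASE_URL}/tokenize", json={"content": TEXT}).json()

# With add_special=true the model's BOS token is prepended
# (for models that use one; its id is assumed to be 1 here).
special = requests.post(
    f"{BASE_URL}/tokenize",
    json={"content": TEXT, "add_special": True},
).json()

assert special["tokens"][0] == 1                 # BOS leads the sequence
assert special["tokens"][1:] == plain["tokens"]  # rest is identical

# Round trip: stripping BOS and detokenizing recovers the original text.
detok = requests.post(
    f"{BASE_URL}/detokenize", json={"tokens": special["tokens"][1:]}
).json()
assert detok["content"].strip() == TEXT
```

Omitting `add_special` keeps the previous behaviour, so existing clients are unaffected; this mirrors `step_tokenize` above, which only sends the key when the `adding special tokens` step has run.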