diff --git a/expose.cpp b/expose.cpp
index aeb7066f14a42..68167d79b08dc 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -194,7 +194,7 @@ extern "C"
         return gpttype_generate(inputs, output);
     }
 
-    const char* new_token(int idx) {
+    const char * new_token(int idx) {
         if (generated_tokens.size() <= idx || idx < 0)
             return nullptr;
         return generated_tokens[idx].c_str();
@@ -232,9 +232,14 @@ extern "C"
         return gpttype_generate_abort();
     }
 
-    int token_count(const char * input)
+    static std::vector<int> toks; //just share a static object for token counting
+    token_count_outputs token_count(const char * input)
     {
         std::string inputstr = input;
-        return gpttype_token_count(inputstr);
+        token_count_outputs output;
+        toks = gpttype_get_token_arr(inputstr);
+        output.count = toks.size();
+        output.ids = toks.data(); //this may be slightly unsafe
+        return output;
     }
 }
diff --git a/expose.h b/expose.h
index 3e17778d73b25..25e855224a3cc 100644
--- a/expose.h
+++ b/expose.h
@@ -83,6 +83,11 @@ struct generation_outputs
     int status = -1;
     char text[32768]; //32kb should be enough for any response
 };
+struct token_count_outputs
+{
+    int count = 0;
+    int * ids; //we'll just use shared memory for this one, bit of a hack
+};
 
 extern std::string executable_path;
 extern std::string lora_filename;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index fb12bb4f6f2bf..d6837cfdf726d 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1390,7 +1390,7 @@ bool gpttype_generate_abort()
     return true;
 }
 
-int gpttype_token_count(const std::string & input)
+std::vector<int> gpttype_get_token_arr(const std::string & input)
 {
     if(debugmode==1)
     {
@@ -1403,7 +1403,7 @@ int gpttype_token_count(const std::string & input)
     {
         printf("\nTokens Counted: %d\n",tokcount);
     }
-    return tokcount;
+    return toks;
 }
 
 const std::string & gpttype_get_pending_output()
diff --git a/koboldcpp.py b/koboldcpp.py
index 2471103cc9497..188abd277e1c3 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("text", ctypes.c_char * 32768)]
 
+class token_count_outputs(ctypes.Structure):
+    _fields_ = [("count", ctypes.c_int),
+                ("ids", ctypes.POINTER(ctypes.c_int))]
+
 handle = None
 
 def getdirpath():
@@ -218,7 +222,7 @@ def init_library():
     handle.get_total_gens.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
-    handle.token_count.restype = ctypes.c_int
+    handle.token_count.restype = token_count_outputs
     handle.get_pending_output.restype = ctypes.c_char_p
 
 def load_model(model_filename):
@@ -729,8 +733,11 @@ def do_POST(self):
             try:
                 genparams = json.loads(body)
                 countprompt = genparams.get('prompt', "")
-                count = handle.token_count(countprompt.encode("UTF-8"))
-                response_body = (json.dumps({"value": count}).encode())
+                rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
+                countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
+                # the above protects the server in case the count limit got corrupted
+                countdata = [rawcountdata.ids[i] for i in range(countlimit)]
+                response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())
             except Exception as e:
                 utfprint("Count Tokens - Body Error: " + str(e))
 
diff --git a/model_adapter.h b/model_adapter.h
index 65536c6d4a673..7180d4f849444 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -68,7 +68,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
 bool gpttype_generate_abort();
 const std::string & gpttype_get_pending_output();
-int gpttype_token_count(const std::string & input);
+std::vector<int> gpttype_get_token_arr(const std::string & input);
 
 void timer_start();
 double timer_check();
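A minimal caller-side sketch (not part of the patch) of how the new token_count ABI might be consumed through ctypes. The library filename and prompt here are illustrative assumptions; the key point is that ids aliases the library's static std::vector<int>, so the caller should copy the ids out before the next token_count call overwrites them.

# Hypothetical usage sketch; assumes the shared library was built as ./koboldcpp.so
# and a model has already been loaded via load_model.
import ctypes

class token_count_outputs(ctypes.Structure):
    _fields_ = [("count", ctypes.c_int),
                ("ids", ctypes.POINTER(ctypes.c_int))]

handle = ctypes.CDLL("./koboldcpp.so")            # assumed library name
handle.token_count.restype = token_count_outputs  # struct returned by value

raw = handle.token_count("Hello world".encode("UTF-8"))
# Copy immediately: ids points into shared static storage that the next
# token_count call will overwrite.
ids = [raw.ids[i] for i in range(raw.count)] if raw.count > 0 else []
print(raw.count, ids)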