Commit

token count includes ids

LostRuins committed Dec 3, 2023
1 parent 0ca814e commit 6570a20

Showing 5 changed files with 26 additions and 9 deletions.
11 changes: 8 additions & 3 deletions expose.cpp
@@ -194,7 +194,7 @@ extern "C"
         return gpttype_generate(inputs, output);
     }

-    const char* new_token(int idx) {
+    const char * new_token(int idx) {
         if (generated_tokens.size() <= idx || idx < 0) return nullptr;

         return generated_tokens[idx].c_str();
@@ -232,9 +232,14 @@ extern "C"
         return gpttype_generate_abort();
     }

-    int token_count(const char * input)
+    static std::vector<int> toks; //just share a static object for token counting
+    token_count_outputs token_count(const char * input)
     {
         std::string inputstr = input;
-        return gpttype_token_count(inputstr);
+        token_count_outputs output;
+        toks = gpttype_get_token_arr(inputstr);
+        output.count = toks.size();
+        output.ids = toks.data(); //this may be slightly unsafe
+        return output;
     }
 }
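Because the exported token_count now returns a token_count_outputs whose ids pointer aliases the static toks vector above, the pointed-to memory is only valid until the next call. Below is a minimal ctypes sketch of a consumer, assuming the koboldcpp shared library has already been loaded and configured as in koboldcpp.py further down; the count_tokens helper name is illustrative, not part of this commit.

import ctypes

class token_count_outputs(ctypes.Structure):
    # Mirrors the struct added in expose.h: a count plus a pointer into
    # library-owned memory that is reused on the next token_count call.
    _fields_ = [("count", ctypes.c_int),
                ("ids", ctypes.POINTER(ctypes.c_int))]

def count_tokens(handle, text):
    # 'handle' is assumed to be a ctypes.CDLL for the koboldcpp library with
    # handle.token_count.restype = token_count_outputs already set.
    raw = handle.token_count(text.encode("UTF-8"))
    # Copy the ids out immediately; the backing std::vector is static and
    # will be overwritten by the next token_count call.
    return [raw.ids[i] for i in range(raw.count)]
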
5 changes: 5 additions & 0 deletions expose.h
@@ -83,6 +83,11 @@ struct generation_outputs
     int status = -1;
     char text[32768]; //32kb should be enough for any response
 };
+struct token_count_outputs
+{
+    int count = 0;
+    int * ids; //we'll just use shared memory for this one, bit of a hack
+};

 extern std::string executable_path;
 extern std::string lora_filename;
4 changes: 2 additions & 2 deletions gpttype_adapter.cpp
@@ -1390,7 +1390,7 @@ bool gpttype_generate_abort()
     return true;
 }

-int gpttype_token_count(const std::string & input)
+std::vector<int> gpttype_get_token_arr(const std::string & input)
 {
     if(debugmode==1)
     {
@@ -1403,7 +1403,7 @@ int gpttype_token_count(const std::string & input)
     {
         printf("\nTokens Counted: %d\n",tokcount);
     }
-    return tokcount;
+    return toks;
 }

 const std::string & gpttype_get_pending_output()
13 changes: 10 additions & 3 deletions koboldcpp.py
@@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("text", ctypes.c_char * 32768)]

+class token_count_outputs(ctypes.Structure):
+    _fields_ = [("count", ctypes.c_int),
+                ("ids", ctypes.POINTER(ctypes.c_int))]
+
 handle = None

 def getdirpath():
@@ -218,7 +222,7 @@ def init_library():
     handle.get_total_gens.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
-    handle.token_count.restype = ctypes.c_int
+    handle.token_count.restype = token_count_outputs
     handle.get_pending_output.restype = ctypes.c_char_p

 def load_model(model_filename):
@@ -729,8 +733,11 @@ def do_POST(self):
         try:
             genparams = json.loads(body)
             countprompt = genparams.get('prompt', "")
-            count = handle.token_count(countprompt.encode("UTF-8"))
-            response_body = (json.dumps({"value": count}).encode())
+            rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
+            countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
+            # the above protects the server in case the count limit got corrupted
+            countdata = [rawcountdata.ids[i] for i in range(countlimit)]
+            response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())

         except Exception as e:
             utfprint("Count Tokens - Body Error: " + str(e))
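With the handler change above, the token-count response now carries the per-token ids alongside the count. A rough client-side sketch follows; the exact route is not shown in this diff, so the URL below is a placeholder to replace with the server's actual token-count endpoint.

import json
import urllib.request

URL = "http://localhost:5001/api/extra/tokencount"  # placeholder endpoint, adjust to your server

payload = json.dumps({"prompt": "Hello world"}).encode("UTF-8")
req = urllib.request.Request(URL, data=payload,
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    data = json.loads(resp.read())

# "value" is the token count, "ids" the per-token id list added by this commit.
print(data["value"], data["ids"])
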
2 changes: 1 addition & 1 deletion model_adapter.h
@@ -68,7 +68,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
 bool gpttype_generate_abort();
 const std::string & gpttype_get_pending_output();
-int gpttype_token_count(const std::string & input);
+std::vector<int> gpttype_get_token_arr(const std::string & input);

 void timer_start();
 double timer_check();
