server : fix logprobs, make it OAI-compatible #10783

Merged
21 commits merged on Dec 19, 2024
Changes from 2 commits
39 changes: 32 additions & 7 deletions examples/server/README.md
@@ -343,6 +343,10 @@ node index.js

### POST `/completion`: Given a `prompt`, it returns the predicted completion.

> [!IMPORTANT]
>
> This endpoint is **not** OAI-compatible

*Options:*

`prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true:
@@ -448,27 +452,48 @@ These words will not be included in the completion, so make sure to add them to

- Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.

- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item describes one generated token and contains a nested `top_logprobs` array:

```json
{
  "content": "<the generated completion text>",
  ...
  "completion_probabilities": [
    {
      "id": <token id>,
      "prob": float,
      "token": "<the generated token>",
      "bytes": [int, int, ...],
      "top_logprobs": [
        {
          "id": <token id>,
          "prob": float,
          "token": "<token text>",
          "bytes": [int, int, ...]
        },
        {
          "id": <token id>,
          "prob": float,
          "token": "<token text>",
          "bytes": [int, int, ...]
        },
        ...
      ]
    },
    {
      "id": <token id>,
      "prob": float,
      "token": "<the generated token>",
      "bytes": [int, int, ...],
      "top_logprobs": [
        ...
      ]
    },
    ...
  ]
}
```

Note that each `top_logprobs` array contains at most `n_probs` elements.
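For reference, here is a minimal request that asks the server to report these probabilities; only `n_probs` needs to be set (the prompt and values are illustrative):

```json
{
  "prompt": "Building a website can be done in 10 simple steps:",
  "n_predict": 8,
  "n_probs": 3
}
```

With `n_probs: 3`, every generated token gets its own entry in `completion_probabilities`, and each entry lists up to three candidates in `top_logprobs`.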

- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
151 changes: 89 additions & 62 deletions examples/server/server.cpp
@@ -342,6 +342,11 @@ struct server_task {
}
}

if (params.sampling.n_probs > 0 && params.cache_prompt) {
SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs);
params.cache_prompt = false;
}

std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
params.oaicompat_model = json_value(data, "model", model_name);

@@ -416,6 +421,7 @@ inline std::string stop_type_to_str(stop_type type) {

struct completion_token_output {
llama_token tok;
float prob;
std::string text_to_send;
struct token_prob {
llama_token tok;
@@ -427,25 +433,41 @@ struct completion_token_output {
json to_json() const {
json probs_for_token = json::array();
for (const auto & p : probs) {
std::string tok_str(p.tok_str);
tok_str.resize(validate_utf8(tok_str));
probs_for_token.push_back(json {
{"id", p.tok},
{"token", tok_str},
{"bytes", str_to_bytes(p.tok_str)},
{"logprob", p.prob},
});
}
return probs_for_token;
}

static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
json out = json::array();
for (const auto & it : probs) {
std::string tok_str(it.text_to_send);
tok_str.resize(validate_utf8(tok_str));
out.push_back(json {
{"id", it.tok},
{"token", tok_str},
{"logprob", it.prob},
{"bytes", str_to_bytes(it.text_to_send)},
{"top_logprobs", it.to_json()},
});
}
return out;
}

static std::vector<unsigned char> str_to_bytes(const std::string & str) {
std::vector<unsigned char> bytes;
for (unsigned char c : str) {
bytes.push_back(c);
}
return bytes;
}
};
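For illustration, `str_to_bytes` simply exposes the raw UTF-8 bytes of the token text, so clients can still recover the exact bytes even when the displayed `token` string has been truncated by `validate_utf8`. A serialized entry would look roughly like this (id and probability left as placeholders, token text assumed):

```json
{ "id": <token id>, "token": "Hello", "logprob": <float>, "bytes": [72, 101, 108, 108, 111] }
```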

struct server_task_result_cmpl_final : server_task_result {
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
{"tokens_cached", n_tokens_cached},
{"timings", timings.to_json()},
};
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
return res;
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}

json choice = json{
{"finish_reason", finish_reason},
{"index", 0},
{"message", json{
{"content", content},
{"role", "assistant"}
}
}};

if (!stream && probs_output.size() > 0) {
choice["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json(probs_output)},
};
}

std::time_t t = std::time(0);

json res = json {
{"choices", choices},
{"choices", json::array({choice})},
{"created", t},
{"model", oaicompat_model},
{"object", "chat.completion"},
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}

json choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
json choice = json{
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
};

json ret = json {
{"choices", choices},
{"choices", json::array({choice})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
int32_t n_decoded;
int32_t n_prompt_tokens;

completion_token_output prob_output;
result_timings timings;

// OAI-compat fields
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
if (timings.prompt_n > 0) {
res.push_back({"timings", timings.to_json()});
}
if (!prob_output.probs.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
}
return res;
}
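In other words, when streaming the non-OAI `/completion` endpoint, each partial result carries a one-element `completion_probabilities` array describing only the token generated in that chunk, roughly (placeholders):

```json
"completion_probabilities": [
  {
    "id": <token id>,
    "token": "<token text>",
    "logprob": <float>,
    "bytes": [int, int, ...],
    "top_logprobs": [ ... ]
  }
]
```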
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
}});
}

GGML_ASSERT(choices.size() >= 1);

if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output})},
};
}

json ret = json {
{"choices", choices},
{"created", t},
@@ -951,7 +989,6 @@ struct server_slot {

// stats
size_t n_sent_text = 0; // number of sent text characters
size_t n_sent_token_probs = 0;

int64_t t_start_process_prompt;
int64_t t_start_generation;
@@ -973,7 +1010,6 @@ struct server_slot {
stopping_word = "";
n_past = 0;
n_sent_text = 0;
n_sent_token_probs = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;

generated_token_probs.clear();
@@ -1713,34 +1749,15 @@ struct server_context {

bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
slot.sampled = result.tok;

// search stop word and delete it
slot.generated_text += token_str;
slot.has_next_token = true;

// check if there is incomplete UTF-8 character at the end
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -1869,6 +1886,29 @@ struct server_context {
return slot.has_next_token; // continue
}

void populate_token_probs(const server_slot & slot, completion_token_output & result) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size;

// set prob for the sampled token
for (size_t i = 0; i < max_probs; ++i) {
if (result.tok == cur_p->data[i].id) {
result.prob = cur_p->data[i].p;
break;
}
}

// set probs for the top n tokens
for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
auto tok_id = cur_p->data[i].id;
result.probs.push_back({
tok_id,
tokens_to_output_formatted_string(ctx, tok_id),
cur_p->data[i].p,
});
}
}

void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(task.id, error, type);
}
@@ -1906,17 +1946,7 @@ struct server_context {

// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
res->prob_output = tkn; // copy the token probs
}

// populate timings if this is final response or timings_per_token is enabled
@@ -2747,17 +2777,12 @@ struct server_context {
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
result.prob = 1.0f; // set later

if (slot.params.sampling.n_probs > 0) {
populate_token_probs(slot, result);
}

if (!process_token(result, slot)) {
@@ -2841,7 +2866,9 @@ struct server_context {
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
result.prob = 1.0f; // set later
@ngxson (Collaborator, Author) commented on Dec 11, 2024:

Here is the branch for speculative decoding. I'm not sure how I can get token probs here. Could you give me some clues? @ggerganov

(Or we can skip this for now if it's complicated)

@ggerganov (Owner) replied:

Think we need to update common_sampler_sample_and_accept_n() to return the probs. But let's fix this later.


if (!process_token(result, slot)) {
// release slot because of stop condition