Skip to content

Commit 2efc7e4

Browse files
committed
common : add option to sort sampling candidates by probability
ggml-ci
1 parent e665a46 commit 2efc7e4

File tree

8 files changed

+34
-11
lines changed

8 files changed

+34
-11
lines changed

common/sampling.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
426426

427427
// helpers
428428

429-
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
430-
return &gsmpl->cur_p;
429+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
430+
auto * res = &gsmpl->cur_p;
431+
432+
if (do_sort && !res->sorted) {
433+
// remember the selected token before sorting
434+
const llama_token id = res->data[res->selected].id;
435+
436+
std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
437+
return a.p > b.p;
438+
});
439+
440+
// restore the selected token after sorting
441+
for (size_t i = 0; i < res->size; ++i) {
442+
if (res->data[i].id == id) {
443+
res->selected = i;
444+
break;
445+
}
446+
}
447+
448+
res->sorted = true;
449+
}
450+
451+
return res;
431452
}
432453

433454
llama_token common_sampler_last(const struct common_sampler * gsmpl) {

common/sampling.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
8686
// helpers
8787

8888
// access the internal list of current candidate tokens
89-
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
89+
// if do_sort == true, the candidates will be sorted (in descending order of probability) in case they are not already sorted
90+
// if do_sort == false, the candidates *might* not be sorted. use the .sorted flag of the result to determine that
91+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
9092

9193
// get the last accepted token
9294
llama_token common_sampler_last(const struct common_sampler * gsmpl);

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
317317

318318
common_sampler_sample(smpl, ctx_dft, 0, true);
319319

320-
const auto * cur_p = common_sampler_get_candidates(smpl);
320+
const auto * cur_p = common_sampler_get_candidates(smpl, true);
321321

322322
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
323323
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
244244
// stochastic verification
245245
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
246246

247-
auto & dist_tgt = *common_sampler_get_candidates(smpl);
247+
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
248248

249249
float p_tgt = 0.0f;
250250
float p_dft = 0.0f;
@@ -493,7 +493,7 @@ int main(int argc, char ** argv) {
493493

494494
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
495495

496-
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
496+
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
497497

498498
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
499499
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",

include/llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ extern "C" {
198198
llama_token_data * data;
199199
size_t size;
200200
int64_t selected; // this is the index in the data array (i.e. not the token id)
201-
bool sorted;
201+
bool sorted; // note: do not assume the data is sorted - always check this flag
202202
} llama_token_data_array;
203203

204204
typedef bool (*llama_progress_callback)(float progress, void * user_data);

src/llama-sampling.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1200,7 +1200,6 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
12001200
return;
12011201
}
12021202

1203-
// in case it's not sorted/recalculated yet
12041203
llama_sampler_softmax_impl(cur_p, &ctx->buf_sort);
12051204

12061205
int pos_last = 0;

tools/server/server.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,11 +2485,12 @@ struct server_context {
24852485
return slot.has_next_token; // continue
24862486
}
24872487

2488-
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
2488+
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
24892489
size_t n_probs = slot.params.sampling.n_probs;
24902490
size_t n_vocab = llama_vocab_n_tokens(vocab);
2491+
24912492
if (post_sampling) {
2492-
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
2493+
const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
24932494
const size_t max_probs = cur_p->size;
24942495

24952496
// set probability for sampled token

tools/tts/tts.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
895895

896896
codes.push_back(new_token_id);
897897

898-
const auto * cands = common_sampler_get_candidates(smpl[i]);
898+
const auto * cands = common_sampler_get_candidates(smpl[i], false);
899899

900900
// is it an end of generation? -> mark the stream as finished
901901
if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {

0 commit comments

Comments
 (0)