Skip to content

Commit

Permalink
reserve space in decode_utf8
Browse files Browse the repository at this point in the history
This change makes llama_sample_grammar go 25% faster for llama2.

This change is a cherry-pick of ggerganov/llama.cpp@f837c3a
  • Loading branch information
MarcusDunn authored and jart committed Dec 1, 2023
1 parent 95703b6 commit e574488
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions llama.cpp/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6137,10 +6137,13 @@ struct llama_grammar_candidate {
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const char * src,
size_t n_src,
llama_partial_utf8 partial_start) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
const char * pos = src;
std::vector<uint32_t> code_points;
// common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
code_points.reserve(n_src + 1);
uint32_t value = partial_start.value;
int n_remain = partial_start.n_remain;

Expand Down Expand Up @@ -6191,6 +6194,13 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
}

static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
std::string src,
llama_partial_utf8 partial_start
) {
return decode_utf8(src.c_str(), src.size(), partial_start);
}

// returns true iff pos points to the end of one of the definitions of a rule
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
switch (pos->type) {
Expand Down Expand Up @@ -6840,7 +6850,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}
Expand Down Expand Up @@ -7047,7 +7057,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
const std::string piece = llama_token_to_piece(ctx, token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
Expand Down

0 comments on commit e574488

Please sign in to comment.